├── .gitignore
├── DejaVuSansMono-Bold.ttf
├── LICENSE.txt
├── README.md
├── convert.py
├── datasets.py
├── eval.py
├── img
│   ├── DualDecoderArch.png
│   └── tokenization.png
├── metric.py
├── models.py
├── parallel.py
├── prepare_data.py
├── requirements.txt
├── train_dual_decoder.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
--------------------------------------------------------------------------------
/DejaVuSansMono-Bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ibm-aur-nlp/EDD/e6bb2bd509d7c89cdf1fb9608f9db9a044413bed/DejaVuSansMono-Bold.ttf
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright 2021 IBM
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image-based table recognition: data, model, evaluation
2 |
3 | ## Task
4 |
5 | Converting table images into HTML code
6 |
7 | ## Dataset
8 |
9 | [PubTabNet](https://github.com/ibm-aur-nlp/PubTabNet) contains over 500k table
10 | images annotated with the corresponding HTML representation.
11 |
12 | ## Model
13 |
14 | Encoder-Dual-Decoder (EDD)
15 |
16 | ![EDD architecture](img/DualDecoderArch.png "EDD architecture")
17 |
18 | ## Evaluation
19 |
20 | **T**ree-**E**dit-**D**istance-based **S**imilarity (TEDS)
21 |
22 | `TEDS(T_1, T_2) = 1 - EditDistance(T_1, T_2) / max(|T_1|, |T_2|)`, where `EditDistance(T_1, T_2)` is the tree edit distance between `T_1` and `T_2`, and `|T|` is the number of nodes in `T`.
23 |
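
A score can be computed between two rendered tables with `similarity_eval_html` in `metric.py`, which is also what `eval.py` uses internally. A minimal sketch (the two HTML strings are placeholders):

```
from lxml import html
from metric import similarity_eval_html

# Placeholder predicted and ground-truth tables; in practice these come from
# the model output and the PubTabNet annotation, respectively.
pred_html = '<html><body><table><tr><td>1</td><td>2</td></tr></table></body></html>'
true_html = '<html><body><table><tr><td>1</td><td>3</td></tr></table></body></html>'

# Returns a score in [0, 1]; 1 means the two tables are identical.
score = similarity_eval_html(html.fromstring(pred_html), html.fromstring(true_html))
print(score)
```
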
24 | ## Installation
25 |
26 | Please use a Python 3 (>=3.6) environment.
27 |
28 | `pip install -r requirements.txt`
29 |
30 | ## Training and testing on PubTabNet
31 |
32 | ### Prepare data
33 |
34 | Download PubTabNet and extract the files into the following file structure
35 | ```
36 | {DATA_DIR}
37 | |
38 | -- train
39 | |
40 | -- PMCXXXXXXX.png
41 | -- ...
42 | -- val
43 | |
44 | -- PMCXXXXXXX.png
45 | -- ...
46 | -- test
47 | |
48 | -- PMCXXXXXXX.png
49 | -- ...
50 | -- PubTabNet_2.0.0.jsonl
51 | ```
52 |
53 | Prepare data for training
54 | ```
55 | python prepare_data.py \
56 | --annotation {DATA_DIR}/PubTabNet_2.0.0.jsonl \
57 | --image_dir {DATA_DIR} \
58 | --out_dir {TRAIN_DATA_DIR}
59 | ```
60 |
61 | The following files will be generated in {TRAIN_DATA_DIR}:
62 | ```
63 | - TRAIN_IMAGES_{POSTFIX}.h5 # Training images
64 | - TRAIN_TAGS_{POSTFIX}.json # Training structural tokens
65 | - TRAIN_TAGLENS_{POSTFIX}.json # Length of training structural tokens
66 | - TRAIN_CELLS_{POSTFIX}.json # Training cell tokens
67 | - TRAIN_CELLLENS_{POSTFIX}.json # Length of training cell tokens
68 | - TRAIN_CELLBBOXES_{POSTFIX}.json # Training cell bboxes
69 | - VAL.json # Validation ground truth
70 | - WORDMAP_{POSTFIX}.json # Vocab
71 | ```
72 | where `{POSTFIX}` is `PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size`
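
The prepared files can be sanity-checked with a few lines of Python (a minimal sketch; `{TRAIN_DATA_DIR}` is a placeholder for the actual output directory):

```
import json
import os

data_dir = '{TRAIN_DATA_DIR}'  # placeholder: replace with the actual output directory
postfix = 'PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size'

# Vocabularies for structural tokens and cell tokens
with open(os.path.join(data_dir, 'WORDMAP_%s.json' % postfix), 'r') as j:
    word_map = json.load(j)
print(len(word_map['word_map_tag']), 'structural tokens')
print(len(word_map['word_map_cell']), 'cell tokens')

# Lengths of the encoded structural token sequences, one entry per training sample
with open(os.path.join(data_dir, 'TRAIN_TAGLENS_%s.json' % postfix), 'r') as j:
    taglens = json.load(j)
print(len(taglens), 'training samples')
```
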
73 | ### Train tag decoder
74 |
75 | Use a larger learning rate (0.001) for the first 10 epochs
76 | ```
77 | python train_dual_decoder.py \
78 | --out_dir {CHECKPOINT_DIR} \
79 | --data_folder {TRAIN_DATA_DIR} \
80 | --data_name PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size \
81 | --epochs 10 \
82 | --batch_size 10 \
83 | --fine_tune_encoder \
84 | --encoder_lr 0.001 \
85 | --fine_tune_tag_decoder \
86 | --tag_decoder_lr 0.001 \
87 | --tag_loss_weight 1.0 \
88 | --cell_decoder_lr 0.001 \
89 | --cell_loss_weight 0.0 \
90 | --tag_embed_dim 16 \
91 | --cell_embed_dim 80 \
92 | --encoded_image_size 28 \
93 | --decoder_cell LSTM \
94 | --tag_attention_dim 256 \
95 | --cell_attention_dim 256 \
96 | --tag_decoder_dim 256 \
97 | --cell_decoder_dim 512 \
98 | --cell_decoder_type 1 \
99 | --cnn_stride '{"tag":1, "cell":1}' \
100 | --resume
101 | ```
102 |
103 | Use a smaller learning rate (0.0001) for another 3 epochs
104 | ```
105 | python train_dual_decoder.py \
106 | --out_dir {CHECKPOINT_DIR} \
107 | --data_folder {TRAIN_DATA_DIR} \
108 | --data_name PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size \
109 | --epochs 13 \
110 | --batch_size 10 \
111 | --fine_tune_encoder \
112 | --encoder_lr 0.0001 \
113 | --fine_tune_tag_decoder \
114 | --tag_decoder_lr 0.0001 \
115 | --tag_loss_weight 1.0 \
116 | --cell_decoder_lr 0.001 \
117 | --cell_loss_weight 0.0 \
118 | --tag_embed_dim 16 \
119 | --cell_embed_dim 80 \
120 | --encoded_image_size 28 \
121 | --decoder_cell LSTM \
122 | --tag_attention_dim 256 \
123 | --cell_attention_dim 256 \
124 | --tag_decoder_dim 256 \
125 | --cell_decoder_dim 512 \
126 | --cell_decoder_type 1 \
127 | --cnn_stride '{"tag":1, "cell":1}' \
128 | --resume
129 | ```
130 |
131 | ### Train dual decoders
132 |
133 | **NOTE**:
134 | - Occasionally a randomly composed batch is too large and exceeds GPU memory. When this happens, simply re-execute the training command; training will resume from the latest checkpoint.
135 | - Training dual decoders requires 2 V100 GPUs.
136 |
137 | Use a larger learning rate (0.001) for the first 10 epochs of dual-decoder training
138 | ```
139 | python train_dual_decoder.py \
140 | --checkpoint {CHECKPOINT_DIR}/PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size/checkpoint_12.pth.tar \
141 | --out_dir {CHECKPOINT_DIR}/cell_decoder \
142 | --data_folder {TRAIN_DATA_DIR} \
143 | --data_name PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size \
144 | --epochs 23 \
145 | --batch_size 8 \
146 | --fine_tune_encoder \
147 | --encoder_lr 0.001 \
148 | --fine_tune_tag_decoder \
149 | --tag_decoder_lr 0.001 \
150 | --tag_loss_weight 0.5 \
151 | --cell_decoder_lr 0.001 \
152 | --cell_loss_weight 0.5 \
153 | --tag_embed_dim 16 \
154 | --cell_embed_dim 80 \
155 | --encoded_image_size 28 \
156 | --decoder_cell LSTM \
157 | --tag_attention_dim 256 \
158 | --cell_attention_dim 256 \
159 | --tag_decoder_dim 256 \
160 | --cell_decoder_dim 512 \
161 | --cell_decoder_type 1 \
162 | --cnn_stride '{"tag":1, "cell":1}' \
163 | --resume \
164 | --predict_content
165 | ```
166 |
167 | Use a smaller learning rate (0.0001) for another 2 epochs
168 | ```
169 | python train_dual_decoder.py \
170 | --out_dir {CHECKPOINT_DIR}/cell_decoder \
171 | --data_folder {TRAIN_DATA_DIR} \
172 | --data_name PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size \
173 | --epochs 25 \
174 | --batch_size 8 \
175 | --fine_tune_encoder \
176 | --encoder_lr 0.0001 \
177 | --fine_tune_tag_decoder \
178 | --tag_decoder_lr 0.0001 \
179 | --tag_loss_weight 0.5 \
180 | --cell_decoder_lr 0.0001 \
181 | --cell_loss_weight 0.5 \
182 | --tag_embed_dim 16 \
183 | --cell_embed_dim 80 \
184 | --encoded_image_size 28 \
185 | --decoder_cell LSTM \
186 | --tag_attention_dim 256 \
187 | --cell_attention_dim 256 \
188 | --tag_decoder_dim 256 \
189 | --cell_decoder_dim 512 \
190 | --cell_decoder_type 1 \
191 | --cnn_stride '{"tag":1, "cell":1}' \
192 | --resume \
193 | --predict_content
194 | ```
195 |
196 |
197 | ### Inference
198 |
199 | Get validation performance
200 | ```
201 | python eval.py \
202 | --image_folder {DATA_DIR}/val \
203 | --result_json {RESULT_DIR}/RESULT_FILE.json \
204 | --gt {TRAIN_DATA_DIR}/VAL.json \
205 | --model {CHECKPOINT_DIR}/cell_decoder/PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size/checkpoint_24.pth.tar \
206 | --word_map {TRAIN_DATA_DIR}/WORDMAP_PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size.json \
207 | --image_size 448 \
208 | --dual_decoder \
209 | --beam_size '{"tag":3, "cell":3}' \
210 | --max_steps '{"tag":1800, "cell":600}'
211 | ```
212 | This will save the TEDS score of every validation sample in `{RESULT_DIR}/RESULT_FILE.json` in the following format:
213 | ```
214 | {
215 | 'PMCXXXXXXX.png': float,
216 | }
217 | ```
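
The per-sample scores can then be aggregated, e.g. (a minimal sketch, assuming the result file written by the command above):

```
import json
import numpy as np

# Path used in --result_json above ({RESULT_DIR} is a placeholder)
with open('{RESULT_DIR}/RESULT_FILE.json', 'r') as fp:
    teds = json.load(fp)

# Overall mean TEDS; eval.py also prints per-type (simple/complex) means directly
print('Mean TEDS over %d samples: %.3f' % (len(teds), np.mean(list(teds.values()))))
```
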
218 |
219 | Get testing performance
220 | ```
221 | python eval.py \
222 | --image_folder {DATA_DIR}/test \
223 | --result_json {RESULT_DIR}/RESULT_FILE.json \
224 | --model {CHECKPOINT_DIR}/cell_decoder/PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size/checkpoint_24.pth.tar \
225 | --word_map {TRAIN_DATA_DIR}/WORDMAP_PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size.json \
226 | --image_size 448 \
227 | --dual_decoder \
228 | --beam_size '{"tag":3, "cell":3}' \
229 | --max_steps '{"tag":1800, "cell":600}'
230 | ```
231 | This will save the inference result (HTML code) of every testing sample in `{RESULT_DIR}/RESULT_FILE.json` in the following format:
232 | ```
233 | {
234 | 'PMCXXXXXXX.png': str,
235 | }
236 | ```
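
Each value is the predicted HTML document (or an empty string if conversion failed), so an individual prediction can be written to disk and inspected in a browser (a minimal sketch; `PMCXXXXXXX.png` is a placeholder key):

```
import json

with open('{RESULT_DIR}/RESULT_FILE.json', 'r') as fp:  # path used in --result_json above
    predictions = json.load(fp)

# Dump the predicted HTML for one test image to a file for visual inspection
with open('PMCXXXXXXX.html', 'w') as fp:
    fp.write(predictions['PMCXXXXXXX.png'])
```
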
237 | The JSON file can be compared against the ground truth using the code [here](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src). The ground truth of the test set has been kept secret.
238 |
239 | ## Cite us
240 |
241 | ```
242 | @article{zhong2019image,
243 | title={Image-based table recognition: data, model, and evaluation},
244 | author={Zhong, Xu and ShafieiBavani, Elaheh and Yepes, Antonio Jimeno},
245 | journal={arXiv preprint arXiv:1911.10683},
246 | year={2019}
247 | }
248 | ```
249 |
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import json
4 | import torchvision.transforms as transforms
5 | import skimage.transform
6 | import argparse
7 | from PIL import Image, ImageDraw, ImageFont
8 | from utils import image_rescale, image_resize
9 | from metric import format_html
10 | import os
11 | from glob import glob
12 | from tqdm import tqdm
13 | import shutil
14 | import textwrap
15 |
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 |
18 | def caption_image_beam_search(encoder, decoder, image_path, word_map,
19 | image_size=448, max_steps=400, beam_size=3,
20 | vis_att=False):
21 | """
22 | Reads an image and captions it with beam search.
23 |
24 | :param encoder: encoder model
25 | :param decoder: decoder model
26 | :param image_path: path to image
27 | :param word_map: word map
28 | :param beam_size: number of sequences to consider at each decode-step
29 | :return: caption, weights for visualization
30 | """
31 | # Read image and process
32 | img = image_rescale(image_path, image_size, False)
33 | img = img / 255.
34 | img = torch.FloatTensor(img)
35 | normalize = transforms.Normalize(mean=[0.94247851, 0.94254675, 0.94292611],
36 | std=[0.17910956, 0.17940403, 0.17931663])
37 | transform = transforms.Compose([normalize])
38 | image = transform(img).to(device) # (3, image_size, image_size)
39 |
40 | # Encode
41 | image = image.unsqueeze(0) # (1, 3, image_size, image_size)
42 | encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim)
43 |
44 | return decoder.inference(encoder_out, word_map, max_steps, beam_size, return_attention=vis_att)
45 |
46 | def visualize_result(image_path, res, rev_word_map, smooth=True, image_size=448):
47 | """
48 | Visualizes caption with weights at every word.
49 |
50 | Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb
51 |
52 | :param image_path: path to image that has been captioned
53 | :param res: result of inference model
54 | :param rev_word_map: reverse word mapping, i.e. ix2word
55 | :param smooth: smooth weights?
56 | """
57 |
58 | def vis_attention(c, image, alpha, smooth, image_size, original_size, x_offset=0, cap=None):
59 | alpha = np.array(alpha)
60 | if smooth:
61 | alpha = skimage.transform.pyramid_expand(alpha, upscale=image_size / alpha.shape[0], sigma=4)
62 | else:
63 | alpha = alpha.repeat(image_size // alpha.shape[0], axis=0).repeat(image_size // alpha.shape[0], axis=1)
64 | if cap is None:
65 | alpha = (alpha - np.min(alpha)) / (np.max(alpha) - np.min(alpha))
66 | else:
67 | alpha *= 1 / cap
68 | alpha[alpha > 1.] = 1.
69 | alpha *= 255.
70 | alpha = alpha.astype('uint8')
71 | alpha = Image.fromarray(alpha)
72 | image = image.convert("RGBA")
73 | alpha = alpha.convert("RGBA")
74 | new_img = Image.blend(image, alpha, 0.6)
75 | new_img = new_img.resize(original_size, Image.LANCZOS)
76 | if c:
77 | font = ImageFont.truetype("DejaVuSansMono-Bold.ttf", 24)
78 | # font = ImageFont.truetype(os.environ["DATA_DIR"] + "/Table2HTML/dejavu/DejaVuSansMono-Bold.ttf", 24)
79 | lines = textwrap.wrap(c, width=25)
80 | w, h = font.getsize(lines[0])
81 | H = h * len(lines)
82 | y_text = original_size[1] / 2 - H / 2
83 | draw = ImageDraw.Draw(new_img)
84 | for line in lines:
85 | w, h = font.getsize(line)
86 | draw.text(((original_size[0] - w) / 2 + x_offset, y_text), line, (255, 255, 255), font=font)
87 | y_text += h
88 | return new_img
89 |
90 | if len(res) == 2:
91 | tags, cells = res
92 | elif len(res) == 4:
93 | tags, tag_alphas, cells, cell_alphas = res
94 | with open(image_path.replace('.png', '.html'), 'w') as fp:
95 | fp.write(format_html(tags, rev_word_map['tag'], cells, rev_word_map['cell']))
96 |
97 | if len(res) == 4:
98 | image, original_size = image_resize(image_path, image_size, False)
99 | folder = image_path[:-4]
100 | if os.path.exists(folder):
101 | shutil.rmtree(folder)
102 | os.makedirs(folder)
103 | os.makedirs(os.path.join(folder, 'structure'))
104 | os.makedirs(os.path.join(folder, 'cells'))
105 |
106 | for ind, (c, alpha) in enumerate(zip(tags[1:], tag_alphas[1:]), 1):
107 | if ind <= 50 or len(tags[1:]) - ind <= 50:
108 | new_img = vis_attention(rev_word_map['tag'][c], image, alpha, smooth, image_size, original_size, cap=None)
109 | new_img.save(os.path.join(folder, 'structure', '%03d.png' % (ind)), "PNG")
110 |
111 | for j, (cell, alphas) in enumerate(zip(cells, cell_alphas)):
112 | if cell is not None:
113 | # for ind, (c, alpha) in enumerate(zip(cell[1:], alphas[1:]), 1):
114 | # # if ind <= 5 or len(cell[1:]) - ind <= 5:
115 | # new_img = vis_attention(rev_word_map['cell'][c], image, alpha, smooth, image_size, original_size)
116 | # new_img.save(os.path.join(folder, 'cells', '%03d_%03d.png' % (j, ind)), "PNG")
117 | new_img = vis_attention(''.join([rev_word_map['cell'][c] for c in cell[1:-1]]),
118 | image,
119 | np.mean(alphas[1:-1], axis=0) if len(alphas[1:-1]) else np.mean(alphas[1:], axis=0),
120 | smooth, image_size, original_size,
121 | x_offset=50 if j % 3 == 0 and j > 0 else 0,
122 | cap=None)
123 | new_img.save(os.path.join(folder, 'cells', '%03d.png' % (j)), "PNG")
124 |
125 |
126 | if __name__ == '__main__':
127 | parser = argparse.ArgumentParser(description='Inference on given images')
128 |
129 | parser.add_argument('--input', '-i', help='path to image')
130 | parser.add_argument('--model', '-m', help='path to model')
131 | parser.add_argument('--word_map', '-wm', help='path to word map JSON')
132 | parser.add_argument('--image_size', '-is', default=448, type=int, help='target size of image rescaling')
133 | parser.add_argument('--beam_size', '-b', default={"tag": 3, "cell": 3}, type=json.loads, help='beam size for beam search')
134 | parser.add_argument('--max_steps', '-ms', default=400, type=json.loads, help='max output steps of decoder')
135 | parser.add_argument('--dont_smooth', dest='smooth', action='store_false', help='do not smooth alpha overlay')
136 | parser.add_argument('--vis_attention', dest='vis_attention', action='store_true', help='visualize attention')
137 |
138 | args = parser.parse_args()
139 |
140 | # Load model
141 | checkpoint = torch.load(args.model)
142 | decoder = checkpoint['decoder']
143 | decoder = decoder.to(device)
144 | decoder.eval()
145 | encoder = checkpoint['encoder']
146 | encoder = encoder.to(device)
147 | encoder.eval()
148 |
149 | # Load word map (word2ix)
150 | with open(args.word_map, 'r') as j:
151 | word_map = json.load(j)
152 | rev_word_map = {'tag': {v: k for k, v in word_map['word_map_tag'].items()},
153 | 'cell': {v: k for k, v in word_map['word_map_cell'].items()}}
154 |
155 | if os.path.isfile(args.input):
156 | # Encode, decode with attention and beam search
157 | res = caption_image_beam_search(encoder, decoder, args.input, word_map, args.image_size, args.max_steps, args.beam_size, args.vis_attention)
158 | if res is None:
159 | print('No complete sequence is generated')
160 | else:
161 | # Visualize caption and attention of best sequence
162 | visualize_result(args.input, res, rev_word_map, args.smooth, args.image_size)
163 | elif os.path.exists(args.input):
164 | images = glob(os.path.join(args.input, '*.png')) + glob(os.path.join(args.input, '*.jpg'))
165 | for image in tqdm(images):
166 | # Encode, decode with attention and beam search
167 | try:
168 | res = caption_image_beam_search(encoder, decoder, image, word_map, args.image_size, args.max_steps, args.beam_size, args.vis_attention)
169 | except Exception as e:
170 | print(e)
171 | res = None
172 | if res is not None:
173 | # Visualize caption and attention of best sequence
174 | visualize_result(image, res, rev_word_map, args.smooth, args.image_size)
175 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import h5py
3 | import json
4 | import os
5 | import numpy as np
6 | import random
7 |
8 | class TableDatasetEvenLength(object):
9 | """
10 | Data loader for training the baseline encoder-decoder model (WYGIWYS, Deng et al. 2017)
11 | """
12 |
13 | def __init__(self, data_folder, data_name, batch_size, transform=None):
14 | # Open hdf5 file where images are stored
15 | f = os.path.join(data_folder, 'TRAIN_IMAGES_' + data_name + '.hdf5')
16 | self.h = h5py.File(f, 'r')
17 |
18 | self.imgs = self.h['images']
19 |
20 | # Load encoded tables (completely into memory)
21 | with open(os.path.join(data_folder, 'TRAIN_TABLES_' + data_name + '.json'), 'r') as j:
22 | self.tables = json.load(j)
23 |
24 | # Load table lengths (completely into memory)
25 | with open(os.path.join(data_folder, 'TRAIN_TABLELENS_' + data_name + '.json'), 'r') as j:
26 | self.tablelens = json.load(j)
27 |
28 | # PyTorch transformation pipeline for the image (normalizing, etc.)
29 | self.transform = transform
30 | self.batch_size = batch_size
31 | self.batch_id = 0
32 |
33 | def shuffle(self):
34 | self.batch_id = 0
35 | self.batches = [[]]
36 | len_dict = dict()
37 | # Split samples into groups by table lengths
38 | for i, l in enumerate(self.tablelens):
39 | if l in len_dict:
40 | len_dict[l].append(i)
41 | else:
42 | len_dict[l] = [i]
43 | # Fill with long samples first, so that the samples do not need to be sorted before training
44 | lens = sorted(list(len_dict.keys()), key=lambda x: -x)
45 | # Shuffle each group
46 | for l in lens:
47 | random.shuffle(len_dict[l])
48 | # Generate batches
49 | for l in lens:
50 | k = 0
51 | # Fill previous incomplete batch
52 | if len(self.batches[-1]) < self.batch_size:
53 | deficit = min(len(len_dict[l]), self.batch_size - len(self.batches[-1]))
54 | self.batches[-1] += len_dict[l][k:k + deficit]
55 | k = deficit
56 | # Generate complete batches
57 | while len(len_dict[l]) - k >= self.batch_size:
58 | self.batches.append(len_dict[l][k:k + self.batch_size])
59 | k += self.batch_size
60 | # Create an incomplete batch with left overs
61 | if k < len(len_dict[l]):
62 | self.batches.append(len_dict[l][k:])
63 | # Shuffle the order of batches
64 | random.shuffle(self.batches)
65 |
66 | def __iter__(self):
67 | return self
68 |
69 | def __next__(self):
70 | if self.batch_id < len(self.batches):
71 | samples = self.batches[self.batch_id]
72 | image_size = self.imgs[samples[0]].shape
73 | imgs = torch.zeros(len(samples), image_size[0], image_size[1], image_size[2], dtype=torch.float)
74 | table_size = len(self.tables[samples[0]])
75 | tables = torch.zeros(len(samples), table_size, dtype=torch.long)
76 | tablelens = torch.zeros(len(samples), 1, dtype=torch.long)
77 | for i, sample in enumerate(samples):
78 | img = torch.FloatTensor(self.imgs[sample] / 255.)
79 | if self.transform is not None:
80 | imgs[i] = self.transform(img)
81 | else:
82 | imgs[i] = img
83 | tables[i] = torch.LongTensor(self.tables[sample])
84 | tablelens[i] = torch.LongTensor([self.tablelens[sample]])
85 | self.batch_id += 1
86 | return imgs, tables, tablelens
87 | else:
88 | raise StopIteration()
89 |
90 | def __len__(self):
91 | return len(self.batches)
92 |
93 | class TagCellDataset(object):
94 | """
95 | Data loader for training encoder-dual-decoder model
96 | """
97 |
98 | def __init__(self, data_folder, data_name, split, batch_size, mode='all', transform=None):
99 | """
100 | :param data_folder: folder where data files are stored
101 | :param data_name: base name of processed datasets
102 | :param split: split, one of 'TRAIN', 'VAL', or 'TEST'
103 | :param batch_size: batch size
104 | :param mode: 'tag', 'tag+cell', 'tag+bbox', or 'tag+cell+bbox'
105 | :param transform: image transform pipeline
106 | """
107 |
108 | assert split in {'TRAIN', 'VAL', 'TEST'}
109 | assert mode in {'tag', 'tag+cell', 'tag+bbox', 'tag+cell+bbox'}
110 |
111 | self.split = split
112 | self.mode = mode
113 | self.batch_size = batch_size
114 |
115 | # Open hdf5 file where images are stored
116 | f = os.path.join(data_folder, self.split + '_IMAGES_' + data_name + '.hdf5')
117 | self.h = h5py.File(f, 'r')
118 | self.imgs = self.h['images']
119 |
120 | # Load encoded tags (completely into memory)
121 | with open(os.path.join(data_folder, self.split + '_TAGS_' + data_name + '.json'), 'r') as j:
122 | self.tags = json.load(j)
123 |
124 | # Load tag lengths (completely into memory)
125 | with open(os.path.join(data_folder, self.split + '_TAGLENS_' + data_name + '.json'), 'r') as j:
126 | self.taglens = json.load(j)
127 |
128 | # Load cell lengths (completely into memory)
129 | with open(os.path.join(data_folder, self.split + '_CELLLENS_' + data_name + '.json'), 'r') as j:
130 | self.celllens = json.load(j)
131 |
132 | if 'cell' in self.mode:
133 | # Load encoded cell tokens (completely into memory)
134 | with open(os.path.join(data_folder, self.split + '_CELLS_' + data_name + '.json'), 'r') as j:
135 | self.cells = json.load(j)
136 |
137 | if 'bbox' in self.mode:
138 | # Load encoded tags (completely into memory)
139 | with open(os.path.join(data_folder, self.split + '_CELLBBOXES_' + data_name + '.json'), 'r') as j:
140 | self.cellbboxes = json.load(j)
141 |
142 | # PyTorch transformation pipeline for the image (normalizing, etc.)
143 | self.transform = transform
144 |
145 | # Total number of datapoints
146 | self.dataset_size = len(self.tags)
147 | self.ind = np.array(range(self.dataset_size))
148 | self.pointer = 0
149 |
150 | def shuffle(self):
151 | self.ind = np.random.permutation(self.dataset_size)
152 | self.pointer = 0
153 |
154 | def __iter__(self):
155 | return self
156 |
157 | def __getitem__(self, i):
158 | img = torch.FloatTensor(self.imgs[i])
159 | tags = self.tags[i]
160 | taglens = self.taglens[i]
161 | cells = self.cells[i]
162 | celllens = self.celllens[i]
163 | image_size = self.imgsizes[i]
164 | return img, tags, taglens, cells, celllens, image_size
165 |
166 | def __next__(self):
167 | if self.pointer < self.dataset_size:
168 | if self.dataset_size - self.pointer >= self.batch_size:
169 | step = self.batch_size
170 | samples = self.ind[self.pointer:self.pointer + step]
171 | else:
172 | step = self.dataset_size - self.pointer
173 | lack = self.batch_size - step
174 | samples = np.hstack((self.ind[self.pointer:self.pointer + step], np.array(range(lack))))
175 | image_size = self.imgs[samples[0]].shape
176 | imgs = torch.zeros(len(samples), image_size[0], image_size[1], image_size[2], dtype=torch.float)
177 | max_tag_len = max([self.taglens[sample] for sample in samples])
178 | tags = torch.zeros(len(samples), max_tag_len, dtype=torch.long)
179 | taglens = torch.zeros(len(samples), 1, dtype=torch.long)
180 | num_cells = torch.zeros(len(samples), 1, dtype=torch.long)
181 | if 'cell' in self.mode:
182 | cells = []
183 | celllens = []
184 | if 'bbox' in self.mode:
185 | cellbboxes = []
186 |
187 | for i, sample in enumerate(samples):
188 | img = torch.FloatTensor(self.imgs[sample] / 255.)
189 | if self.transform is not None:
190 | imgs[i] = self.transform(img)
191 | else:
192 | imgs[i] = img
193 | tags[i] = torch.LongTensor(self.tags[sample][:max_tag_len])
194 | taglens[i] = torch.LongTensor([self.taglens[sample]])
195 | num_cells[i] = len(self.celllens[sample])
196 | if 'cell' in self.mode:
197 | max_cell_len = max(self.celllens[sample])
198 | cells.append(torch.LongTensor(self.cells[sample])[:, :max_cell_len])
199 | celllens.append(torch.LongTensor(self.celllens[sample]))
200 | if 'bbox' in self.mode:
201 | cellbboxes.append(torch.FloatTensor(self.cellbboxes[sample]))
202 |
203 | self.pointer += step
204 | output = (imgs, tags, taglens, num_cells)
205 | if 'cell' in self.mode:
206 | output += (cells, celllens)
207 | if 'bbox' in self.mode:
208 | output += (cellbboxes,)
209 | return output
210 | else:
211 | raise StopIteration()
212 |
213 | def __len__(self):
214 | return int(np.ceil(self.dataset_size / self.batch_size))
215 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import json
4 | import torchvision.transforms as transforms
5 | import argparse
6 | import os
7 | from tqdm import tqdm
8 | import sys
9 | import time
10 | from utils import image_rescale
11 | from metric import format_html, similarity_eval_html
12 | from lxml import html
13 | import numpy as np
14 | from glob import glob
15 | import traceback
16 |
17 |
18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19 |
20 | def convert_table_beam_search(encoder, decoder, image_path, word_map, rev_word_map,
21 | image_size=448, max_steps=400, beam_size=3,
22 | dual_decoder=True):
23 | """
24 | Reads an image and captions it with beam search.
25 |
26 | :param encoder: encoder model
27 | :param decoder: decoder model
28 | :param image_path: path to image
29 | :param word_map: word map
30 | :param max_steps: max number of decoding steps
31 | :param beam_size: number of sequences to consider at each decode-step
32 | :param dual_decoder: if the model has dual decoders
33 | :return: HTML code of input table image
34 | """
35 | # Read image and process
36 | img = image_rescale(image_path, image_size, False)
37 | img = img / 255.
38 | img = torch.FloatTensor(img)
39 | normalize = transforms.Normalize(mean=[0.94247851, 0.94254675, 0.94292611],
40 | std=[0.17910956, 0.17940403, 0.17931663])
41 | transform = transforms.Compose([normalize])
42 | image = transform(img).to(device) # (3, image_size, image_size)
43 |
44 | # Encode
45 | image = image.unsqueeze(0) # (1, 3, image_size, image_size)
46 | encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim)
47 |
48 | res = decoder.inference(encoder_out, word_map, max_steps, beam_size, return_attention=False)
49 | if res is not None:
50 | if dual_decoder:
51 | if len(res) == 2:
52 | html_string = format_html(res[0], rev_word_map['tag'], res[1], rev_word_map['cell'])
53 | else:
54 | html_string = format_html(res[0], rev_word_map['tag'])
55 | else:
56 | html_string = format_html(res, rev_word_map)
57 | else:
58 | html_string = ''
59 | return html_string
60 |
61 |
62 | if __name__ == '__main__':
63 | parser = argparse.ArgumentParser(description='Evaluation of table2html conversion models')
64 |
65 | parser.add_argument('--image_folder', type=str, help='path to image folder')
66 | parser.add_argument('--result_json', type=str, help='path to save results (json)')
67 | parser.add_argument('--model', help='path to model')
68 | parser.add_argument('--word_map', help='path to word map JSON')
69 | parser.add_argument('--gt', default=None, type=str, help='path to ground truth')
70 | parser.add_argument('--image_size', default=448, type=int, help='target size of image rescaling')
71 | parser.add_argument('--dual_decoder', default=False, dest='dual_decoder', action='store_true', help='the decoder is a dual decoder')
72 | parser.add_argument('--beam_size', default={"tag": 3, "cell": 3}, type=json.loads, help='beam size for beam search')
73 | parser.add_argument('--max_steps', default={"tag": 1800, "cell": 600}, type=json.loads, help='max output steps of decoder')
74 |
75 | args = parser.parse_args()
76 |
77 | # Wait until model file exists
78 | if not os.path.isfile(args.model):
79 | while not os.path.isfile(args.model):
80 | print('Model not found, retry in 10 minutes', file=sys.stderr)
81 | sys.stderr.flush()
82 | time.sleep(600)
83 | # Make sure model file is saved completely
84 | time.sleep(10)
85 | # Load model
86 | checkpoint = torch.load(args.model)
87 |
88 | decoder = checkpoint['decoder']
89 | decoder = decoder.to(device)
90 | decoder.eval()
91 | encoder = checkpoint['encoder']
92 | encoder = encoder.to(device)
93 | encoder.eval()
94 |
95 | # Load word map (word2ix)
96 | with open(args.word_map, 'r') as j:
97 | word_map = json.load(j)
98 |
99 | if args.dual_decoder:
100 | rev_word_map = {'tag': {v: k for k, v in word_map['word_map_tag'].items()},
101 | 'cell': {v: k for k, v in word_map['word_map_cell'].items()}}
102 | else:
103 | rev_word_map = {v: k for k, v in word_map.items()} # ix2word
104 |
105 | # Load ground truth
106 | if args.gt is not None:
107 | with open(args.gt, 'r') as j:
108 | gt = json.load(j)
109 |
110 | normalize = transforms.Normalize(mean=[0.94247851, 0.94254675, 0.94292611],
111 | std=[0.17910956, 0.17940403, 0.17931663])
112 | transform = transforms.Compose([normalize])
113 |
114 | if args.gt is None:
115 | # Ground truth of the test set is not provided. To evaluate test performance,
116 | # please do not specify the ground truth file, and all png images in
117 | # image_folder will be converted. Conversion results are saved in a json,
118 | # which can be uploaded to our evaluation service (coming soon) for
119 | # evaluation.
120 | HTML = dict()
121 | images = glob(os.path.join(args.image_folder, '*.png'))
122 | for filename in tqdm(images):
123 | try:
124 | html_pred = convert_table_beam_search(
125 | encoder, decoder, filename, word_map, rev_word_map,
126 | args.image_size, args.max_steps, args.beam_size,
127 | args.dual_decoder)
128 | except Exception as e:
129 | traceback.print_exc()
130 | html_pred = ''
131 | HTML[os.path.basename(filename)] = html_pred
132 | if not os.path.exists(os.path.dirname(args.result_json)):
133 | os.makedirs(os.path.dirname(args.result_json))
134 | with open(args.result_json, 'w') as fp:
135 | json.dump(HTML, fp)
136 | else:
137 | # Ground truth of the validation set is provided. Please specify the ground
138 | # truth file, and the TEDS scores on simple, complex, and all table
139 | # samples will be computed.
140 | TEDS = dict()
141 | for filename, attributes in tqdm(gt.items()):
142 | try:
143 | html_pred = convert_table_beam_search(
144 | encoder, decoder,
145 | os.path.join(args.image_folder, filename),
146 | word_map, rev_word_map,
147 | args.image_size, args.max_steps, args.beam_size,
148 | args.dual_decoder)
149 | if html_pred:
150 | TEDS[filename] = similarity_eval_html(html.fromstring(html_pred), html.fromstring(attributes['html']))
151 | else:
152 | TEDS[filename] = 0.
153 | except Exception as e:
154 | traceback.print_exc()
155 | TEDS[filename] = 0.
156 |
157 | simple = [TEDS[filename] for filename, attributes in gt.items() if attributes['type'] == 'simple']
158 | complex = [TEDS[filename] for filename, attributes in gt.items() if attributes['type'] == 'complex']
159 | total = [TEDS[filename] for filename, attributes in gt.items()]
160 |
161 | print('TEDS of %d simple tables: %.3f' % (len(simple), np.mean(simple)))
162 | print('TEDS of %d complex tables: %.3f' % (len(complex), np.mean(complex)))
163 | print('TEDS of %d all tables: %.3f' % (len(total), np.mean(total)))
164 |
165 | if not os.path.exists(os.path.dirname(args.result_json)):
166 | os.makedirs(os.path.dirname(args.result_json))
167 | with open(args.result_json, 'w') as fp:
168 | json.dump(TEDS, fp)
169 |
--------------------------------------------------------------------------------
/img/DualDecoderArch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ibm-aur-nlp/EDD/e6bb2bd509d7c89cdf1fb9608f9db9a044413bed/img/DualDecoderArch.png
--------------------------------------------------------------------------------
/img/tokenization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ibm-aur-nlp/EDD/e6bb2bd509d7c89cdf1fb9608f9db9a044413bed/img/tokenization.png
--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
1 | import distance
2 | from apted import APTED, Config
3 | from apted.helpers import Tree
4 | from lxml import html
5 | from collections import deque
6 | from parallel import parallel_process
7 | import numpy as np
8 | import subprocess
9 | import re
10 | import os
11 | import sys
12 | from html import escape
13 |
14 | class TableTree(Tree):
15 | def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
16 | self.tag = tag
17 | self.colspan = colspan
18 | self.rowspan = rowspan
19 | self.content = content
20 | self.children = list(children)
21 |
22 | def bracket(self):
23 | """Show tree using brackets notation"""
24 | if self.tag == 'td':
25 | result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \
26 | (self.tag, self.colspan, self.rowspan, self.content)
27 | else:
28 | result = '"tag": %s' % self.tag
29 | for child in self.children:
30 | result += child.bracket()
31 | return "{{{}}}".format(result)
32 |
33 | class CustomConfig(Config):
34 | @staticmethod
35 | def maximum(*sequences):
36 | """Get maximum possible value
37 | """
38 | return max(map(len, sequences))
39 |
40 | def normalized_distance(self, *sequences):
41 | """Get distance from 0 to 1
42 | """
43 | return float(distance.levenshtein(*sequences)) / self.maximum(*sequences)
44 |
45 | def rename(self, node1, node2):
46 | """Compares attributes of trees"""
47 | if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
48 | return 1.
49 | if node1.tag == 'td':
50 | if node1.content or node2.content:
51 | return self.normalized_distance(node1.content, node2.content)
52 | return 0.
53 |
54 | def tokenize(node):
55 | ''' Tokenizes table cells
56 | '''
57 | global __tokens__
58 | __tokens__.append('<%s>' % node.tag)
59 | if node.text is not None:
60 | __tokens__ += list(node.text)
61 | for n in node.getchildren():
62 | tokenize(n)
63 | if node.tag != 'unk':
64 | __tokens__.append('</%s>' % node.tag)
65 | if node.tag != 'td' and node.tail is not None:
66 | __tokens__ += list(node.tail)
67 |
68 | def format_html(tags, rev_word_map_tags, cells=None, rev_word_map_cells=None):
69 | ''' Formats html code from raw model output
70 | '''
71 | HTML = [rev_word_map_tags[ind] for ind in tags[1:-1]]
72 | if cells is not None:
73 | to_insert = [i for i, tag in enumerate(HTML) if tag in ('<td>', '>')]
74 | for i, cell in zip(to_insert[::-1], cells[::-1]):
75 | if cell is not None:
76 | cell = [rev_word_map_cells[ind] for ind in cell[1:-1]]
77 | cell = ''.join([escape(token) if len(token) == 1 else token for token in cell])
78 | HTML.insert(i + 1, cell)
79 |
80 | HTML = '''<html>
81 | <head>
82 | <meta charset="UTF-8">
83 | <style>
84 | table, th, td {
85 |   border: 1px solid black;
86 |   font-size: 10px;
87 | }
88 | </style>
89 | </head>
90 | <body>
91 | <table frame="hsides" rules="groups" width="100%%">
92 | %s
93 | </table>
94 | </body>
95 | </html>''' % ''.join(HTML)
96 | return HTML
97 |
98 | def tree_convert_html(node, convert_cell=False, parent=None):
99 | ''' Converts HTML tree to the format required by apted
100 | '''
101 | global __tokens__
102 | if node.tag == 'td':
103 | if convert_cell:
104 | __tokens__ = []
105 | tokenize(node)
106 | cell = __tokens__[1:-1].copy()
107 | else:
108 | cell = []
109 | new_node = TableTree(node.tag,
110 | int(node.attrib.get('colspan', '1')),
111 | int(node.attrib.get('rowspan', '1')),
112 | cell, *deque())
113 | else:
114 | new_node = TableTree(node.tag, None, None, None, *deque())
115 | if parent is not None:
116 | parent.children.append(new_node)
117 | if node.tag != 'td':
118 | for n in node.getchildren():
119 | tree_convert_html(n, convert_cell, new_node)
120 | if parent is None:
121 | return new_node
122 |
123 | def similarity_eval_html(pred, true, structure_only=False):
124 | ''' Computes TEDS score between the prediction and the ground truth of a
125 | given sample
126 | '''
127 | if pred.xpath('body/table') and true.xpath('body/table'):
128 | pred = pred.xpath('body/table')[0]
129 | true = true.xpath('body/table')[0]
130 | n_nodes_pred = len(pred.xpath(".//*"))
131 | n_nodes_true = len(true.xpath(".//*"))
132 | tree_pred = tree_convert_html(pred, convert_cell=not structure_only)
133 | tree_true = tree_convert_html(true, convert_cell=not structure_only)
134 | n_nodes = max(n_nodes_pred, n_nodes_true)
135 | distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
136 | return 1.0 - (float(distance) / n_nodes)
137 | else:
138 | return 0.0
139 |
140 | def TEDS_wraper(prediction, ground_truth, filename=None):
141 | if prediction:
142 | return similarity_eval_html(
143 | html.fromstring(prediction),
144 | html.fromstring(ground_truth)
145 | )
146 | else:
147 | return 0.
148 |
149 | def TEDS(gt, pred, n_jobs=8):
150 | ''' Computes TEDS scores for an evaluation set
151 | '''
152 | assert n_jobs > 0 and isinstance(n_jobs, int), 'n_jobs must be positive integer'
153 | inputs = [{'filename': filename, 'prediction': pred.get(filename, ''), 'ground_truth': attributes['html']} for filename, attributes in gt.items()]
154 | scores = parallel_process(inputs, TEDS_wraper, use_kwargs=True, n_jobs=n_jobs, front_num=1)
155 | scores = {i['filename']: score for i, score in zip(inputs, scores)}
156 | return scores
157 |
158 | def html2xml(html_code, out_path):
159 | if not html_code:
160 | return
161 | root = html.fromstring(html_code)
162 | if root.xpath('body/table'):
163 | table = root.xpath('body/table')[0]
164 | cells = []
165 | multi_row_cells = []
166 | row_pt = 0
167 | for row in table.iter('tr'):
168 | row_skip = np.inf
169 | col_pt = 0
170 | for cell in row.getchildren():
171 | # Skip cells expanded from previous rows
172 | multi_row_cells = sorted(multi_row_cells, key=lambda x: x['start-col'])
173 | for c in multi_row_cells:
174 | if 'end-col' in c:
175 | if c['start-row'] <= row_pt <= c['end-row'] and c['start-col'] <= col_pt <= c['end-col']:
176 | col_pt += c['end-col'] - c['start-col'] + 1
177 | else:
178 | if c['start-row'] <= row_pt <= c['end-row'] and c['start-col'] == col_pt:
179 | col_pt += 1
180 | # Generate new cell
181 | new_cell = {'start-row': row_pt,
182 | 'start-col': col_pt,
183 | 'content': html.tostring(cell, method='text', encoding='utf-8').decode('utf-8')}
184 | # Handle multi-row/col cells
185 | if int(cell.attrib.get('colspan', '1')) > 1:
186 | new_cell['end-col'] = col_pt + int(cell.attrib['colspan']) - 1
187 | if int(cell.attrib.get('rowspan', '1')) > 1:
188 | new_cell['end-row'] = row_pt + int(cell.attrib['rowspan']) - 1
189 | multi_row_cells.append(new_cell)
190 | if new_cell['content']:
191 | cells.append(new_cell)
192 | row_skip = min(row_skip, int(cell.attrib.get('rowspan', '1')))
193 | col_pt += int(cell.attrib.get('colspan', '1'))
194 | row_pt += row_skip if not np.isinf(row_skip) else 1
195 | multi_row_cells = [cell for cell in multi_row_cells if row_pt <= cell['end-row']]
196 | with open(out_path, 'w') as fp:
197 | fp.write('\n')
198 | fp.write('\n')
199 | fp.write(' \n')
200 | fp.write(' \n')
201 | for i, cell in enumerate(cells):
202 | attributes = ' '.join(['%s=\'%d\'' % (key, value) for key, value in cell.items() if key != 'content'])
203 | fp.write(' \n' % (i, attributes))
204 | fp.write(' %s\n' % escape(cell['content']))
205 | fp.write(' | \n')
206 | fp.write(' \n')
207 | fp.write(' \n')
208 | fp.write('')
209 |
210 | def relation_metric(pred, gt, thresholds=None):
211 | if thresholds is None:
212 | thresholds = np.linspace(0.6, 0.95, 8)
213 | precisions = []
214 | recalls = []
215 | f1scores = []
216 | for threshold in thresholds:
217 | try:
218 | result = subprocess.check_output(['java', '-jar', 'dataset-tools-fat-lib.jar', '-str', gt, pred, '-threshold%f' % threshold])
219 | result = result.split(b'\n')[-2].decode('utf-8')
220 | try:
221 | precision = float(re.search(r'Precision[^=]*= ([0-9.]*)', result).group(1))
222 | except ValueError:
223 | print(ValueError, file=sys.stderr)
224 | precision = 0.0
225 | try:
226 | recall = float(re.search(r'Recall[^=]*= ([0-9.]*)', result).group(1))
227 | except ValueError:
228 | print(ValueError, file=sys.stderr)
229 | recall = 0.0
230 | f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0. else 0.
231 | precisions.append(precision)
232 | recalls.append(recall)
233 | f1scores.append(f1)
234 | except Exception as e:
235 | print(os.path.basename(pred), file=sys.stderr)
236 | print(e, file=sys.stderr)
237 | precisions.append(0.)
238 | recalls.append(0.)
239 | f1scores.append(0.)
240 | return np.mean(precisions), np.mean(recalls), np.mean(f1scores)
241 |
242 |
243 | if __name__ == '__main__':
244 | from paramiko import SSHClient
245 |
246 | html_pred = '/Users/peterzhong/Downloads/table2html/Tag+Cell/PMC5059900_003_02.html'
247 | with open(html_pred, 'r') as fp:
248 | pred = html.parse(fp).getroot()
249 | filename = os.path.basename(html_pred).split('.')[0]
250 |
251 | ssh = SSHClient()
252 | ssh.load_system_host_keys()
253 | ssh.connect('dccxl003.pok.ibm.com', username='peterz')
254 | sftp_client = ssh.open_sftp()
255 | with sftp_client.open('/dccstor/ddig/peter/Medline_paper_annotator/data/table_norm/htmls/%s.html' % (filename)) as remote_file:
256 | true = html.parse(remote_file).getroot()
257 | true_table = html.Element("table")
258 | for n in true.xpath('body')[0].getchildren():
259 | true_table.append(n)
260 | true.xpath('body')[0].append(true_table)
261 | print(similarity_eval_html(pred, true))
262 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | '''
2 | Implementation of encoder-dual-decoder model
3 | '''
4 | import torch
5 | from torch import nn
6 | import torchvision
7 | from torchvision.models.resnet import BasicBlock, conv1x1
8 | import torch.nn.functional as F
9 | from torch.nn.utils.rnn import pack_padded_sequence
10 | from utils import *
11 | import time
12 | import sys
13 |
14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15 |
16 | def resnet_block(stride=1):
17 | layers = []
18 | downsample = nn.Sequential(
19 | conv1x1(256, 512, stride),
20 | nn.BatchNorm2d(512),
21 | )
22 | layers.append(BasicBlock(256, 512, stride, downsample))
23 | layers.append(BasicBlock(512, 512, 1))
24 | return nn.Sequential(*layers)
25 |
26 | def repackage_hidden(h):
27 | """Wraps hidden states in new Tensors, to detach them from their history."""
28 | if isinstance(h, torch.Tensor):
29 | return h.detach()
30 | else:
31 | return tuple(repackage_hidden(v) for v in h)
32 |
33 | class Encoder(nn.Module):
34 | """
35 | Encoder.
36 | """
37 |
38 | def __init__(self, encoded_image_size=14, use_RNN=False, rnn_size=512, last_layer_stride=2):
39 | super(Encoder, self).__init__()
40 | self.enc_image_size = encoded_image_size
41 | self.use_RNN = use_RNN
42 | self.rnn_size = rnn_size
43 |
44 | resnet = torchvision.models.resnet18(pretrained=False) # ImageNet ResNet-18
45 |
46 | # Remove linear and pool layers (since we're not doing classification)
47 | # Also remove the last CNN layer for higher resolution feature map
48 | modules = list(resnet.children())[:-3]
49 | if last_layer_stride is not None:
50 | modules.append(resnet_block(stride=last_layer_stride))
51 |
52 | # Change stride of max pooling layer for higher resolution feature map
53 | # modules[3] = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
54 |
55 | self.resnet = nn.Sequential(*modules)
56 |
57 | # Resize image to fixed size to allow input images of variable size
58 | self.adaptive_pool = nn.AdaptiveAvgPool2d((self.enc_image_size, self.enc_image_size))
59 |
60 | if self.use_RNN:
61 | self.RNN = nn.LSTM(512, self.rnn_size, bias=True, batch_first=True) # LSTM that transforms the image features
62 | self.init_h = nn.Linear(512, self.rnn_size) # linear layer to find initial hidden state of LSTM
63 | self.init_c = nn.Linear(512, self.rnn_size) # linear layer to find initial cell state of LSTM
64 | self.fine_tune()
65 |
66 | def init_hidden_state(self, encoder_out):
67 | """
68 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
69 |
70 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
71 | :return: hidden state, cell state
72 | """
73 | mean_encoder_out = encoder_out.mean(dim=1)
74 | h = self.init_h(mean_encoder_out).unsqueeze(0) # (batch_size*encoded_image_size, rnn_size)
75 | c = self.init_c(mean_encoder_out).unsqueeze(0)
76 | return h, c
77 |
78 | def forward(self, images):
79 | """
80 | Forward propagation.
81 |
82 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
83 | :return: encoded images
84 | """
85 | batch_size = images.size(0)
86 | out = self.resnet(images) # (batch_size, 512, image_size/32, image_size/32)
87 | out = self.adaptive_pool(out) # (batch_size, 512, encoded_image_size, encoded_image_size)
88 | out = out.permute(0, 2, 3, 1) # (batch_size, encoded_image_size, encoded_image_size, 512)
89 | if self.use_RNN:
90 | out = out.contiguous().view(-1, self.enc_image_size, 512) # (batch_size*encoded_image_size, encoded_image_size, 512)
91 | h = self.init_hidden_state(out)
92 | out, h = self.RNN(out, h) # (batch_size*encoded_image_size, encoded_image_size, 512)
93 | out = out.view(batch_size, self.enc_image_size, self.enc_image_size, self.rnn_size).contiguous()
94 | return out
95 |
96 | def fine_tune(self, fine_tune=True):
97 | """
98 | Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder.
99 |
100 | :param fine_tune: Allow?
101 | """
102 | for p in self.resnet.parameters():
103 | p.requires_grad = fine_tune
104 |
105 | class Attention(nn.Module):
106 | """
107 | Attention Network.
108 | """
109 |
110 | def __init__(self, encoder_dim, decoder_dim, attention_dim):
111 | """
112 | :param encoder_dim: feature size of encoded images
113 | :param decoder_dim: size of decoder's RNN
114 | :param attention_dim: size of the attention network
115 | """
116 | super(Attention, self).__init__()
117 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image
118 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output
119 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed
120 | self.relu = nn.ReLU()
121 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights
122 |
123 | def forward(self, encoder_out, decoder_hidden):
124 | """
125 | Forward propagation.
126 |
127 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
128 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
129 | :return: attention weighted encoding, weights
130 | """
131 | att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim)
132 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim)
133 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels)
134 | alpha = self.softmax(att) # (batch_size, num_pixels)
135 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_dim)
136 |
137 | return attention_weighted_encoding, alpha
138 |
139 | class CellAttention(nn.Module):
140 | """
141 | Attention Network.
142 | """
143 |
144 | def __init__(self, encoder_dim, tag_decoder_dim, language_dim, attention_dim):
145 | """
146 | :param encoder_dim: feature size of encoded images
147 | :param tag_decoder_dim: size of tag decoder's RNN
148 | :param language_dim: size of language model's RNN
149 | :param attention_dim: size of the attention network
150 | """
151 | super(CellAttention, self).__init__()
152 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image
153 | self.tag_decoder_att = nn.Linear(tag_decoder_dim, attention_dim) # linear layer to transform tag decoder output
154 | self.language_att = nn.Linear(language_dim, attention_dim) # linear layer to transform language models output
155 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed
156 | self.relu = nn.ReLU()
157 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights
158 |
159 | def forward(self, encoder_out, decoder_hidden, language_out):
160 | """
161 | Forward propagation.
162 |
163 | :param encoder_out: encoded images, a tensor of dimension (1, num_pixels, encoder_dim)
164 | :param decoder_hidden: tag decoder output, a tensor of dimension [(num_cells, tag_decoder_dim)]
165 | :param language_out: language model output, a tensor of dimension (num_cells, language_dim)
166 | :return: attention weighted encoding, weights
167 | """
168 | att1 = self.encoder_att(encoder_out) # (1, num_pixels, attention_dim)
169 | att2 = self.tag_decoder_att(decoder_hidden) # (num_cells, tag_decoder_dim)
170 | att3 = self.language_att(language_out) # (num_cells, attention_dim)
171 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1) + att3.unsqueeze(1))).squeeze(2) # (num_cells, num_pixels)
172 | alpha = self.softmax(att) # (num_cells, num_pixels)
173 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (num_cells, encoder_dim)
174 |
175 | return attention_weighted_encoding, alpha
176 |
177 | class DecoderWithAttention(nn.Module):
178 | """
179 | Decoder.
180 | """
181 |
182 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, decoder_cell=nn.LSTMCell, encoder_dim=512, dropout=0.5):
183 | """
184 | :param attention_dim: size of attention network
185 | :param embed_dim: embedding size
186 | :param decoder_dim: size of decoder's RNN
187 | :param vocab_size: size of vocabulary
188 | :param encoder_dim: feature size of encoded images
189 | :param dropout: dropout
190 | """
191 | super(DecoderWithAttention, self).__init__()
192 |
193 | assert decoder_cell.__name__ in ('GRUCell', 'LSTMCell'), 'decoder_cell must be either nn.LSTMCell or nn.GRUCell'
194 | self.encoder_dim = encoder_dim
195 | self.attention_dim = attention_dim
196 | self.embed_dim = embed_dim
197 | self.decoder_dim = decoder_dim
198 | self.vocab_size = vocab_size
199 | self.dropout = dropout
200 |
201 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network
202 |
203 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer
204 | self.dropout = nn.Dropout(p=self.dropout)
205 | self.decode_step = decoder_cell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoding LSTMCell
206 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell
207 | if isinstance(self.decode_step, nn.LSTMCell):
208 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell
209 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate
210 | self.sigmoid = nn.Sigmoid()
211 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary
212 | self.init_weights() # initialize some layers with the uniform distribution
213 |
214 | def init_weights(self):
215 | """
216 | Initializes some parameters with values from the uniform distribution, for easier convergence.
217 | """
218 | self.embedding.weight.data.uniform_(-0.1, 0.1)
219 | self.fc.bias.data.fill_(0)
220 | self.fc.weight.data.uniform_(-0.1, 0.1)
221 |
222 | def load_pretrained_embeddings(self, embeddings):
223 | """
224 | Loads embedding layer with pre-trained embeddings.
225 |
226 | :param embeddings: pre-trained embeddings
227 | """
228 | self.embedding.weight = nn.Parameter(embeddings)
229 |
230 | def fine_tune_embeddings(self, fine_tune=True):
231 | """
232 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings).
233 |
234 | :param fine_tune: Allow?
235 | """
236 | for p in self.embedding.parameters():
237 | p.requires_grad = fine_tune
238 |
239 | def init_hidden_state(self, encoder_out):
240 | """
241 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
242 |
243 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
244 | :return: hidden state, cell state
245 | """
246 | mean_encoder_out = encoder_out.mean(dim=1)
247 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim)
248 | if isinstance(self.decode_step, nn.LSTMCell):
249 | c = self.init_c(mean_encoder_out)
250 | return h, c
251 | else:
252 | return h
253 |
254 | def inference(self, encoder_out, word_map, max_steps=400, beam_size=5, return_attention=False):
255 | """
256 | Inference on test images with beam search
257 | """
258 | enc_image_size = encoder_out.size(1)
259 | encoder_dim = encoder_out.size(3)
260 |
261 | # Flatten encoding
262 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim)
263 | num_pixels = encoder_out.size(1)
264 |
265 | k = beam_size
266 | vocab_size = len(word_map)
267 |
268 | # We'll treat the problem as having a batch size of k
269 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim)
270 |
271 | # Tensor to store top k previous words at each step; now they're just <start>
272 | k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device) # (k, 1)
273 |
274 | # Tensor to store top k sequences; now they're just <start>
275 | seqs = k_prev_words # (k, 1)
276 |
277 | # Tensor to store top k sequences' scores; now they're just 0
278 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1)
279 |
280 | # Tensor to store top k sequences' alphas; now they're just 1s
281 | if return_attention:
282 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size)
283 |
284 | # Lists to store completed sequences, their alphas and scores
285 | complete_seqs = list()
286 | if return_attention:
287 | complete_seqs_alpha = list()
288 | complete_seqs_scores = list()
289 |
290 | # Start decoding
291 | step = 1
292 |
293 | if isinstance(self.decode_step, nn.LSTMCell):
294 | h, c = self.init_hidden_state(encoder_out)
295 | else:
296 | h = self.init_hidden_state(encoder_out)
297 |
298 | # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
299 | while True:
300 | embeddings = self.embedding(k_prev_words).squeeze(1) # (s, embed_dim)
301 | if return_attention:
302 | awe, alpha = self.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels)
303 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size)
304 | else:
305 | awe, _ = self.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels)
306 |
307 | gate = self.sigmoid(self.f_beta(h)) # gating scalar, (s, encoder_dim)
308 | awe = gate * awe
309 |
310 | if isinstance(self.decode_step, nn.LSTMCell):
311 | h, c = self.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim)
312 | else:
313 | h = self.decode_step(torch.cat([embeddings, awe], dim=1), h) # (s, decoder_dim)
314 |
315 | h = repackage_hidden(h)
316 | if isinstance(self.decode_step, nn.LSTMCell):
317 | c = repackage_hidden(c)
318 |
319 | scores = self.fc(h) # (s, vocab_size)
320 | scores = F.log_softmax(scores, dim=1)
321 |
322 | # Add
323 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)
324 |
325 | # For the first step, all k points will have the same scores (since same k previous words, h, c)
326 | if step == 1:
327 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
328 | else:
329 | # Unroll and find top scores, and their unrolled indices
330 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)
331 |
332 | # Convert unrolled indices to actual indices of scores
333 | prev_word_inds = top_k_words // vocab_size # (s)
334 | next_word_inds = top_k_words % vocab_size # (s)
335 |
336 | # Add new words to sequences, alphas
337 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)
338 | if return_attention:
339 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
340 | dim=1) # (s, step+1, enc_image_size, enc_image_size)
341 |
342 | # Which sequences are incomplete (didn't reach <end>)?
343 | incomplete_inds = []
344 | complete_inds = []
345 | for ind, next_word in enumerate(next_word_inds):
346 | if next_word == word_map['<end>']:
347 | complete_inds.append(ind)
348 | else:
349 | incomplete_inds.append(ind)
350 |
351 | # Set aside complete sequences
352 | if len(complete_inds) > 0:
353 | complete_seqs.extend(seqs[complete_inds].tolist())
354 | if return_attention:
355 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
356 | complete_seqs_scores.extend(top_k_scores[complete_inds])
357 | k -= len(complete_inds) # reduce beam length accordingly
358 |
359 | # Proceed with incomplete sequences
360 | if k == 0:
361 | break
362 |
363 | # Break if things have been going on too long
364 | if step > max_steps:
365 | # If no complete sequence is generated, finish the incomplete
366 | # sequences with <end>
367 | if not complete_seqs_scores:
368 | complete_seqs = seqs.tolist()
369 | for i in range(len(complete_seqs)):
370 | complete_seqs[i].append(word_map['<end>'])
371 | if return_attention:
372 | complete_seqs_alpha = seqs_alpha.tolist()
373 | complete_seqs_scores = top_k_scores.tolist()
374 | break
375 |
376 | seqs = seqs[incomplete_inds]
377 | if return_attention:
378 | seqs_alpha = seqs_alpha[incomplete_inds]
379 | h = h[prev_word_inds[incomplete_inds]]
380 | if isinstance(self.decode_step, nn.LSTMCell):
381 | c = c[prev_word_inds[incomplete_inds]]
382 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
383 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
384 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
385 |
386 | step += 1
387 | i = complete_seqs_scores.index(max(complete_seqs_scores))
388 | seq = complete_seqs[i]
389 | if return_attention:
390 | alphas = complete_seqs_alpha[i]
391 | return seq, alphas
392 | else:
393 | return seq
394 |
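# Worked example (illustrative only) of the index bookkeeping in the beam search above: after
# flattening the (s, vocab_size) score matrix, a top-k index idx splits into
#   prev_word_inds = idx // vocab_size   (which beam the candidate extends)
#   next_word_inds = idx %  vocab_size   (which vocabulary token it appends)
# e.g. with vocab_size = 10, idx = 23 extends beam 2 with token 3.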
395 | def forward(self, encoder_out, encoded_captions, caption_lengths, h, c=None, begin_tokens=None):
396 | """
397 | Forward propagation.
398 |
399 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
400 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
401 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
402 | :return: scores for vocabulary, decode lengths, attention weights, hidden state (and cell state for an LSTMCell)
403 | """
404 | batch_size = encoder_out.size(0)
405 | encoder_dim = encoder_out.size(-1)
406 | vocab_size = self.vocab_size
407 |
408 | # Flatten image
409 | num_pixels = encoder_out.size(1)
410 |
411 | if begin_tokens is None:
412 | # Embedding
413 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim)
414 | # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
415 | # So, decoding lengths are actual lengths - 1
416 | decode_lengths = (caption_lengths - 1).tolist()
417 | else: # For TBPTT, use the end token of the previous sub-sequence as begin token instead of <start>
418 | embeddings = torch.cat([self.embedding(begin_tokens), self.embedding(encoded_captions)], dim=1)
419 | decode_lengths = caption_lengths.tolist()
420 |
421 | # Create tensors to hold word prediction scores and alphas
422 | predictions = torch.zeros(batch_size, decode_lengths[0], vocab_size).to(device)
423 | alphas = torch.zeros(batch_size, decode_lengths[0], num_pixels).to(device)
424 |
425 | # At each time-step, decode by
426 | # attention-weighing the encoder's output based on the decoder's previous hidden state output
427 | # then generate a new word in the decoder with the previous word and the attention weighted encoding
428 | for t in range(decode_lengths[0]):
429 | batch_size_t = sum([l > t for l in decode_lengths])
430 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
431 | h[:batch_size_t])
432 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim)
433 | attention_weighted_encoding = gate * attention_weighted_encoding
434 | if isinstance(self.decode_step, nn.LSTMCell):
435 | h, c = self.decode_step(
436 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
437 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim)
438 | else:
439 | h = self.decode_step(
440 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
441 | h[:batch_size_t]) # (batch_size_t, decoder_dim)
442 | predictions[:batch_size_t, t, :] = self.fc(self.dropout(h)) # (batch_size_t, vocab_size)
443 | alphas[:batch_size_t, t, :] = alpha
444 |
445 | return predictions, decode_lengths, alphas, h, c
446 |
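# Illustrative call pattern (hypothetical tensors) for the begin_tokens argument of forward():
# - first TBPTT chunk: the chunk already starts with <start>, so forward() is called without
#   begin_tokens and decodes caption_lengths - 1 steps (nothing is predicted at <end>);
# - later chunks: the last token of the previous chunk is passed as begin_tokens, its embedding
#   is prepended, and every position of the chunk becomes a prediction target.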
447 | def train_epoch(self, train_loader, encoder, criterion, encoder_optimizer, decoder_optimizer, epoch, args, step=None):
448 | """
449 | Performs one epoch's training.
450 |
451 | :param train_loader: DataLoader for training data
452 | :param encoder: encoder model
453 | :param decoder: decoder model
454 | :param criterion: loss layer
455 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning)
456 | :param decoder_optimizer: optimizer to update decoder's weights
457 | :param epoch: epoch number
458 | """
459 |
460 | self.train() # train mode (dropout and batchnorm is used)
461 | encoder.train()
462 |
463 | batch_time = AverageMeter() # forward prop. + back prop. time
464 | data_time = AverageMeter() # data loading time
465 | losses = AverageMeter() # loss (per word decoded)
466 | top1accs = AverageMeter() # top1 accuracy
467 |
468 | start = time.time()
469 |
470 | # Batches
471 | train_loader.shuffle()
472 | for i, (imgs, caps_sorted, caplens) in enumerate(train_loader):
473 | if step is not None:
474 | if i <= step:
475 | continue
476 | data_time.update(time.time() - start)
477 |
478 | # Move to GPU, if available
479 | imgs = imgs.to(device)
480 | caps_sorted = caps_sorted.to(device)
481 | caplens = caplens.to(device)
482 |
483 | # Forward prop.
484 | imgs = encoder(imgs)
485 | # Flatten image
486 | batch_size = imgs.size(0)
487 | encoder_dim = imgs.size(-1)
488 | imgs = imgs.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim)
489 | caplens = caplens.squeeze(1)
490 |
491 | # Sort input data by decreasing lengths
492 | # caplens, sort_ind = caplens.squeeze(1).sort(dim=0, descending=True)
493 | # imgs = imgs[sort_ind]
494 | # caps_sorted = caps[sort_ind]
495 |
496 | # Initialize LSTM state
497 | if isinstance(self.decode_step, nn.LSTMCell):
498 | h, c = self.init_hidden_state(imgs) # (batch_size, decoder_dim)
499 | else:
500 | h = self.init_hidden_state(imgs) # (batch_size, decoder_dim)
501 | c = None
502 |
503 | max_cap_length = max(caplens.tolist())
504 | # TBPTT
505 | j = 0
506 | while j < max_cap_length:
507 | if j == 0:
508 | # bptt tokens after <start>
509 | sub_seq_len = min(args.bptt + 1, max_cap_length - j)
510 | else:
511 | sub_seq_len = min(args.bptt, max_cap_length - j)
512 | # Do not leave too short tails (less than 10 tokens)
513 | short_tail = (caplens - (j + sub_seq_len) < 10) & (caplens - (j + sub_seq_len) > 0)
514 | if short_tail.any():
515 | sub_seq_len += max((caplens - (j + sub_seq_len))[short_tail].tolist())
516 |
517 | sub_seq_caplens = caplens - j
518 | sub_seq_caplens[sub_seq_caplens > sub_seq_len] = sub_seq_len
519 | batch_size_t = (sub_seq_caplens > 0).sum().item()
520 | sub_seq_caplens = sub_seq_caplens[:batch_size_t]
521 | sub_seq_cap = caps_sorted[:batch_size_t, j:j + sub_seq_len]
522 |
523 | h = repackage_hidden(h)
524 | if isinstance(self.decode_step, nn.LSTMCell):
525 | c = repackage_hidden(c)
526 |
527 | decoder_optimizer.zero_grad()
528 | if encoder_optimizer is not None:
529 | encoder_optimizer.zero_grad()
530 | if j == 0:
531 | scores, decode_lengths, alphas, h, c = self(
532 | imgs[:batch_size_t],
533 | sub_seq_cap,
534 | sub_seq_caplens,
535 | h,
536 | c)
537 | # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
538 | targets = sub_seq_cap[:, 1:]
539 | else:
540 | scores, decode_lengths, alphas, h, c = self(
541 | imgs[:batch_size_t],
542 | sub_seq_cap,
543 | sub_seq_caplens,
544 | h,
545 | c,
546 | caps_sorted[:batch_size_t, j - 1].unsqueeze(1))
547 | targets = sub_seq_cap
548 |
549 | # Remove timesteps that we didn't decode at, or are pads
550 | # pack_padded_sequence is an easy trick to do this
551 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0]
552 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0]
553 |
554 | # Calculate loss
555 | loss = criterion(scores, targets)
556 |
557 | # Add doubly stochastic attention regularization
558 | loss += args.alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
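# Descriptive note (not in the original source): this is the doubly stochastic attention
# penalty of "Show, Attend and Tell": it pushes each pixel's attention weights to sum to
# roughly 1 over the decoded steps, adding alpha_c times the mean over batch and pixels of
# (1 - sum_t alpha_{t,p})^2.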
559 |
560 | # Back prop.
561 | if j + sub_seq_len < max_cap_length:
562 | loss.backward(retain_graph=True)
563 | else:
564 | loss.backward()
565 |
566 | # Clip gradients
567 | if args.grad_clip is not None:
568 | clip_gradient(decoder_optimizer, args.grad_clip)
569 | if encoder_optimizer is not None:
570 | clip_gradient(encoder_optimizer, args.grad_clip)
571 |
572 | # Update weights
573 | decoder_optimizer.step()
574 | if encoder_optimizer is not None:
575 | encoder_optimizer.step()
576 |
577 | # Keep track of metrics
578 | top1 = accuracy(scores, targets, 1)
579 | losses.update(loss.item(), sum(decode_lengths))
580 | top1accs.update(top1, sum(decode_lengths))
581 | j += sub_seq_len
582 | batch_time.update(time.time() - start)
583 |
584 | start = time.time()
585 |
586 | # Print status
587 | if i % args.print_freq == 0:
588 | print('Epoch: [{0}][{1}/{2}]\t'
589 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
590 | 'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
591 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
592 | 'Top-1 Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(epoch, i, len(train_loader),
593 | batch_time=batch_time,
594 | data_time=data_time, loss=losses,
595 | top1=top1accs), file=sys.stderr)
596 | sys.stderr.flush()
597 |
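# Worked example (illustrative, assuming args.bptt = 100) of the TBPTT chunking in train_epoch
# above, for a longest caption of 235 tokens:
#   chunk 1: positions [0, 101)   (bptt + 1, i.e. <start> plus the first 100 steps)
#   chunk 2: positions [101, 201)
#   chunk 3: positions [201, 235)
# If a remaining tail were shorter than 10 tokens, the current chunk would be stretched to
# absorb it (the short_tail branch), and repackage_hidden detaches the state between chunks so
# gradients never flow across chunk boundaries.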
598 | class DecoderWithAttentionAndLanguageModel(nn.Module):
599 | '''
600 | Stacked 2-layer LSTM with Attention model. First LSTM is a language model, second LSTM is a decoder.
601 | See "Recursive Recurrent Nets with Attention Modeling for OCR in the Wild"
602 | '''
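# Per-timestep data flow of this class (descriptive sketch matching forward() below):
#   token embedding --> language-model cell : h_LM = decode_step_LM(embedding, h_LM)
#   h_LM            --> attention over image: awe, alpha = attention(encoder_out, h_LM)
#   gated awe       --> prediction cell     : h_pred = decode_step_pred(gate * awe, h_pred)
#   h_pred          --> fc                  : scores over the vocabulary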
603 | def __init__(self, attention_dim, embed_dim, language_dim, decoder_dim, vocab_size,
604 | decoder_cell=nn.LSTMCell, encoder_dim=512, dropout=0.5):
605 | """
606 | :param attention_dim: size of attention network
607 | :param embed_dim: embedding size
608 | :param language_dim: size of language model's RNN
609 | :param decoder_dim: size of decoder's RNN
610 | :param vocab_size: size of vocabulary
611 | :param encoder_dim: feature size of encoded images
612 | :param dropout: dropout
613 | """
614 | super(DecoderWithAttentionAndLanguageModel, self).__init__()
615 |
616 | self.encoder_dim = encoder_dim
617 | self.attention_dim = attention_dim
618 | self.embed_dim = embed_dim
619 | self.language_dim = language_dim
620 | self.decoder_dim = decoder_dim
621 | self.vocab_size = vocab_size
622 | self.dropout = dropout
623 |
624 | self.attention = Attention(encoder_dim, language_dim, attention_dim) # attention network
625 |
626 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer
627 |
628 | self.decode_step_LM = decoder_cell(embed_dim, language_dim, bias=True) # language model LSTMCell
629 |
630 | self.decode_step_pred = decoder_cell(encoder_dim, decoder_dim, bias=True) # decoding LSTMCell
631 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell
632 | if isinstance(self.decode_step_pred, nn.LSTMCell):
633 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell
634 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate
635 | self.sigmoid = nn.Sigmoid()
636 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary
637 | self.dropout = nn.Dropout(p=self.dropout)
638 | self.init_weights() # initialize some layers with the uniform distribution
639 |
640 | def init_weights(self):
641 | """
642 | Initializes some parameters with values from the uniform distribution, for easier convergence.
643 | """
644 | self.embedding.weight.data.uniform_(-0.1, 0.1)
645 | self.fc.bias.data.fill_(0)
646 | self.fc.weight.data.uniform_(-0.1, 0.1)
647 |
648 | def init_hidden_state(self, encoder_out):
649 | """
650 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
651 |
652 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
653 | :return: hidden state, cell state
654 | """
655 | batch_size = encoder_out.size(0)
656 | mean_encoder_out = encoder_out.mean(dim=1)
657 | h_LM = torch.zeros(batch_size, self.language_dim).to(device)
658 | h_pred = self.init_h(mean_encoder_out) # (batch_size, decoder_dim)
659 | if isinstance(self.decode_step_pred, nn.LSTMCell):
660 | c_LM = torch.zeros(batch_size, self.language_dim).to(device)
661 | c_pred = self.init_c(mean_encoder_out)
662 | return h_LM, c_LM, h_pred, c_pred
663 | else:
664 | return h_LM, h_pred
665 |
666 | def inference(self, encoder_out, word_map, max_steps=400, beam_size=5):
667 | """
668 | Inference on test images with beam search
669 | """
670 | enc_image_size = encoder_out.size(1)
671 | encoder_dim = encoder_out.size(3)
672 |
673 | # Flatten encoding
674 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim)
675 | num_pixels = encoder_out.size(1)
676 |
677 | k = beam_size
678 | vocab_size = len(word_map)
679 |
680 | # We'll treat the problem as having a batch size of k
681 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim)
682 |
683 | # Tensor to store top k previous words at each step; now they're just '<start>'
684 | k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device) # (k, 1)
685 |
686 | # Tensor to store top k sequences; now they're just '<start>'
687 | seqs = k_prev_words # (k, 1)
688 |
689 | # Tensor to store top k sequences' scores; now they're just 0
690 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1)
691 |
692 | # Tensor to store top k sequences' alphas; now they're just 1s
693 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size)
694 |
695 | # Lists to store completed sequences, their alphas and scores
696 | complete_seqs = list()
697 | complete_seqs_alpha = list()
698 | complete_seqs_scores = list()
699 |
700 | # Start decoding
701 | step = 1
702 | if isinstance(self.decode_step_pred, nn.LSTMCell):
703 | h_LM, c_LM, h_cell, c_cell = self.init_hidden_state(encoder_out)
704 | else:
705 | h_LM, h_cell = self.init_hidden_state(encoder_out)
706 |
707 | # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
708 | while True:
709 | embeddings = self.embedding(k_prev_words).squeeze(1) # (s, embed_dim)
710 |
711 | if isinstance(self.decode_step_LM, nn.LSTMCell):
712 | h_LM, c_LM = self.decode_step_LM(embeddings, (h_LM, c_LM))
713 | else:
714 | h_LM = self.decode_step_LM(embeddings, h_LM)
715 | awe, alpha = self.attention(encoder_out, h_LM)
716 |
717 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size)
718 |
719 | gate = self.sigmoid(self.f_beta(h_cell)) # gating scalar, (s, encoder_dim)
720 | awe = gate * awe
721 |
722 | if isinstance(self.decode_step_pred, nn.LSTMCell):
723 | h_cell, c_cell = self.decode_step_pred(awe, (h_cell, c_cell)) # (batch_size_t, decoder_dim)
724 | else:
725 | h_cell = self.decode_step_pred(awe, h_cell) # (batch_size_t, decoder_dim)
726 |
727 | scores = self.fc(h_cell) # (s, vocab_size)
728 | scores = F.log_softmax(scores, dim=1)
729 |
730 | # Add
731 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)
732 |
733 | # For the first step, all k points will have the same scores (since same k previous words, h, c)
734 | if step == 1:
735 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
736 | else:
737 | # Unroll and find top scores, and their unrolled indices
738 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)
739 |
740 | # Convert unrolled indices to actual indices of scores
741 | prev_word_inds = top_k_words // vocab_size # (s)
742 | next_word_inds = top_k_words % vocab_size # (s)
743 |
744 | # Add new words to sequences, alphas
745 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)
746 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
747 | dim=1) # (s, step+1, enc_image_size, enc_image_size)
748 |
749 | # Which sequences are incomplete (didn't reach <end>)?
750 | incomplete_inds = []
751 | complete_inds = []
752 | for ind, next_word in enumerate(next_word_inds):
753 | if next_word == word_map['<end>']:
754 | complete_inds.append(ind)
755 | else:
756 | incomplete_inds.append(ind)
757 |
758 | # Set aside complete sequences
759 | if len(complete_inds) > 0:
760 | complete_seqs.extend(seqs[complete_inds].tolist())
761 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
762 | complete_seqs_scores.extend(top_k_scores[complete_inds])
763 | k -= len(complete_inds) # reduce beam length accordingly
764 |
765 | # Proceed with incomplete sequences
766 | if k == 0:
767 | break
768 | seqs = seqs[incomplete_inds]
769 | seqs_alpha = seqs_alpha[incomplete_inds]
770 | h_LM = h_LM[prev_word_inds[incomplete_inds]]
771 | h_cell = h_cell[prev_word_inds[incomplete_inds]]
772 | if isinstance(self.decode_step_pred, nn.LSTMCell):
773 | c_LM = c_LM[prev_word_inds[incomplete_inds]]
774 | c_cell = c_cell[prev_word_inds[incomplete_inds]]
775 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
776 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
777 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
778 |
779 | # Break if things have been going on too long
780 | if step > max_steps:
781 | break
782 | step += 1
783 | if complete_seqs_scores:
784 | i = complete_seqs_scores.index(max(complete_seqs_scores))
785 | seq = complete_seqs[i]
786 | alphas = complete_seqs_alpha[i]
787 | return seq, alphas
788 | else:
789 | return None
790 |
791 | def forward(self, encoder_out, encoded_captions, caption_lengths, h_LM, h_pred, c_LM=None, c_pred=None, begin_tokens=None):
792 | """
793 | Forward propagation.
794 |
795 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
796 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
797 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
798 | :return: scores for vocabulary, decode lengths, attention weights, and the updated LM / prediction hidden and cell states
799 | """
800 | batch_size = encoder_out.size(0)
801 | encoder_dim = encoder_out.size(-1)
802 | vocab_size = self.vocab_size
803 |
804 | # Flatten image
805 | num_pixels = encoder_out.size(1)
806 |
807 | if begin_tokens is None:
808 | # Embedding
809 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim)
810 | # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
811 | # So, decoding lengths are actual lengths - 1
812 | decode_lengths = (caption_lengths - 1).tolist()
813 | else: # For TBPTT, use the end token of the previous sub-sequence as begin token instead of <start>
814 | embeddings = torch.cat([self.embedding(begin_tokens), self.embedding(encoded_captions)], dim=1)
815 | decode_lengths = caption_lengths.tolist()
816 |
817 | # Create tensors to hold word prediction scores and alphas
818 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)
819 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device)
820 |
821 | # At each time-step, decode by
822 | # attention-weighing the encoder's output based on the decoder's previous hidden state output
823 | # then generate a new word in the decoder with the previous word and the attention weighted encoding
824 | for t in range(max(decode_lengths)):
825 | batch_size_t = sum([l > t for l in decode_lengths])
826 |
827 | # Language LSTM
828 | if isinstance(self.decode_step_LM, nn.LSTMCell):
829 | h_LM, c_LM = self.decode_step_LM(
830 | embeddings[:batch_size_t, t, :],
831 | (h_LM[:batch_size_t], c_LM[:batch_size_t])) # (batch_size_t, decoder_dim)
832 | else:
833 | h_LM = self.decode_step_LM(
834 | embeddings[:batch_size_t, t, :],
835 | h_LM[:batch_size_t]) # (batch_size_t, decoder_dim)
836 |
837 | # Attention
838 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
839 | h_LM)
840 |
841 | # Decoder LSTM
842 | gate = self.sigmoid(self.f_beta(h_pred[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim)
843 | attention_weighted_encoding = gate * attention_weighted_encoding
844 | if isinstance(self.decode_step_pred, nn.LSTMCell):
845 | h_pred, c_pred = self.decode_step_pred(
846 | attention_weighted_encoding,
847 | (h_pred[:batch_size_t], c_pred[:batch_size_t])) # (batch_size_t, decoder_dim)
848 | else:
849 | h_pred = self.decode_step_pred(
850 | attention_weighted_encoding,
851 | h_pred[:batch_size_t]) # (batch_size_t, decoder_dim)
852 | predictions[:batch_size_t, t, :] = self.fc(self.dropout(h_pred)) # (batch_size_t, vocab_size)
853 | alphas[:batch_size_t, t, :] = alpha
854 |
855 | return predictions, decode_lengths, alphas, h_LM, h_pred, c_LM, c_pred
856 |
857 | def train_epoch(self, train_loader, encoder, criterion, encoder_optimizer, decoder_optimizer, epoch, args, step=None):
858 | """
859 | Performs one epoch's training.
860 |
861 | :param train_loader: DataLoader for training data
862 | :param encoder: encoder model
863 | :param decoder: decoder model
864 | :param criterion: loss layer
865 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning)
866 | :param decoder_optimizer: optimizer to update decoder's weights
867 | :param epoch: epoch number
868 | """
869 |
870 | self.train() # train mode (dropout and batchnorm is used)
871 | encoder.train()
872 |
873 | batch_time = AverageMeter() # forward prop. + back prop. time
874 | data_time = AverageMeter() # data loading time
875 | losses = AverageMeter() # loss (per word decoded)
876 | top1accs = AverageMeter() # top1 accuracy
877 |
878 | start = time.time()
879 |
880 | # Batches
881 | train_loader.shuffle()
882 | for i, (imgs, caps_sorted, caplens) in enumerate(train_loader):
883 | if step is not None:
884 | if i <= step:
885 | continue
886 | data_time.update(time.time() - start)
887 |
888 | # Move to GPU, if available
889 | imgs = imgs.to(device)
890 | caps_sorted = caps_sorted.to(device)
891 | caplens = caplens.to(device)
892 |
893 | # Forward prop.
894 | imgs = encoder(imgs)
895 | # Flatten image
896 | batch_size = imgs.size(0)
897 | encoder_dim = imgs.size(-1)
898 | imgs = imgs.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim)
899 | caplens = caplens.squeeze(1)
900 |
901 | # Sort input data by decreasing lengths
902 | # caplens, sort_ind = caplens.squeeze(1).sort(dim=0, descending=True)
903 | # imgs = imgs[sort_ind]
904 | # caps_sorted = caps[sort_ind]
905 |
906 | # Initialize LSTM state
907 | if isinstance(self.decode_step_pred, nn.LSTMCell):
908 | h_LM, c_LM, h_pred, c_pred = self.init_hidden_state(imgs) # (batch_size, decoder_dim)
909 | else:
910 | h_LM, h_pred = self.init_hidden_state(imgs) # (batch_size, decoder_dim)
911 | c_LM = c_pred = None
912 |
913 | max_cap_length = max(caplens.tolist())
914 | # TBPTT
915 | j = 0
916 | while j < max_cap_length:
917 | if j == 0:
918 | # bptt tokens after <start>
919 | sub_seq_len = min(args.bptt + 1, max_cap_length - j)
920 | else:
921 | sub_seq_len = min(args.bptt, max_cap_length - j)
922 | # Do not leave too short tails (less than 10 tokens)
923 | short_tail = (caplens - (j + sub_seq_len) < 10) & (caplens - (j + sub_seq_len) > 0)
924 | if short_tail.any():
925 | sub_seq_len += max((caplens - (j + sub_seq_len))[short_tail].tolist())
926 |
927 | sub_seq_caplens = caplens - j
928 | sub_seq_caplens[sub_seq_caplens > sub_seq_len] = sub_seq_len
929 | batch_size_t = (sub_seq_caplens > 0).sum().item()
930 | sub_seq_caplens = sub_seq_caplens[:batch_size_t]
931 | sub_seq_cap = caps_sorted[:batch_size_t, j:j + sub_seq_len]
932 |
933 | h_LM = repackage_hidden(h_LM)
934 | h_pred = repackage_hidden(h_pred)
935 | if isinstance(self.decode_step_pred, nn.LSTMCell):
936 | c_LM = repackage_hidden(c_LM)
937 | c_pred = repackage_hidden(c_pred)
938 |
939 | decoder_optimizer.zero_grad()
940 | if encoder_optimizer is not None:
941 | encoder_optimizer.zero_grad()
942 | if j == 0:
943 | scores, decode_lengths, alphas, h_LM, h_pred, c_LM, c_pred = self(
944 | imgs[:batch_size_t],
945 | sub_seq_cap,
946 | sub_seq_caplens,
947 | h_LM, h_pred, c_LM, c_pred)
948 | # Since we decoded starting with <start>, the targets are all words after <start>, up to <end>
949 | targets = sub_seq_cap[:, 1:]
950 | else:
951 | scores, decode_lengths, alphas, h_LM, h_pred, c_LM, c_pred = self(
952 | imgs[:batch_size_t],
953 | sub_seq_cap,
954 | sub_seq_caplens,
955 | h_LM, h_pred, c_LM, c_pred,
956 | caps_sorted[:batch_size_t, j - 1].unsqueeze(1))
957 | targets = sub_seq_cap
958 |
959 | # Remove timesteps that we didn't decode at, or are pads
960 | # pack_padded_sequence is an easy trick to do this
961 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0]
962 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0]
963 |
964 | # Calculate loss
965 | loss = criterion(scores, targets)
966 |
967 | # Add doubly stochastic attention regularization
968 | loss += args.alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean()
969 |
970 | # Back prop.
971 | if j + sub_seq_len < max_cap_length:
972 | loss.backward(retain_graph=True)
973 | else:
974 | loss.backward()
975 |
976 | # Clip gradients
977 | if args.grad_clip is not None:
978 | clip_gradient(decoder_optimizer, args.grad_clip)
979 | if encoder_optimizer is not None:
980 | clip_gradient(encoder_optimizer, args.grad_clip)
981 |
982 | # Update weights
983 | decoder_optimizer.step()
984 | if encoder_optimizer is not None:
985 | encoder_optimizer.step()
986 |
987 | # Keep track of metrics
988 | top1 = accuracy(scores, targets, 1)
989 | losses.update(loss.item(), sum(decode_lengths))
990 | top1accs.update(top1, sum(decode_lengths))
991 | j += sub_seq_len
992 | batch_time.update(time.time() - start)
993 |
994 | start = time.time()
995 |
996 | # Print status
997 | if i % args.print_freq == 0:
998 | print('Epoch: [{0}][{1}/{2}]\t'
999 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
1000 | 'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
1001 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
1002 | 'Top-1 Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(epoch, i, len(train_loader),
1003 | batch_time=batch_time,
1004 | data_time=data_time, loss=losses,
1005 | top1=top1accs), file=sys.stderr)
1006 | sys.stderr.flush()
1007 |
1008 | class TagDecoder(DecoderWithAttention):
1009 | '''
1010 | TagDecoder generates structure of the table
1011 | '''
1012 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size,
1013 | td_encode, decoder_cell=nn.LSTMCell, encoder_dim=512,
1014 | dropout=0.5, cnn_layer_stride=None, tag_H_grad=True):
1015 | """
1016 | :param attention_dim: size of attention network
1017 | :param embed_dim: embedding size
1018 | :param decoder_dim: size of decoder's RNN
1019 | :param vocab_size: size of vocabulary
1020 | :param encoder_dim: feature size of encoded images
1021 | :param dropout: dropout
1022 | """
1023 | super(TagDecoder, self).__init__(
1024 | attention_dim,
1025 | embed_dim,
1026 | decoder_dim,
1027 | vocab_size,
1028 | decoder_cell,
1029 | encoder_dim,
1030 | dropout)
1031 | self.td_encode = td_encode
1032 | self.tag_H_grad = tag_H_grad
1033 | if cnn_layer_stride is not None:
1034 | self.input_filter = resnet_block(cnn_layer_stride)
1035 |
1036 | def inference(self, encoder_out, word_map, max_steps=400, beam_size=5, return_attention=False):
1037 | """
1038 | Inference on test images with beam search
1039 | """
1040 | if hasattr(self, 'input_filter'):
1041 | encoder_out = self.input_filter(encoder_out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
1042 | enc_image_size = encoder_out.size(1)
1043 | encoder_dim = encoder_out.size(3)
1044 |
1045 | # Flatten encoding
1046 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim)
1047 | num_pixels = encoder_out.size(1)
1048 |
1049 | k = beam_size
1050 | vocab_size = len(word_map)
1051 |
1052 | # We'll treat the problem as having a batch size of k
1053 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim)
1054 |
1055 | # Tensor to store top k previous words at each step; now they're just '<start>'
1056 | k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device) # (k, 1)
1057 |
1058 | # Tensor to store top k sequences; now they're just '<start>'
1059 | seqs = k_prev_words # (k, 1)
1060 |
1061 | # Tensor to store top k sequences' scores; now they're just 0
1062 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1)
1063 |
1064 | # Tensor to store top k sequences' alphas; now they're just 1s
1065 | if return_attention:
1066 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size)
1067 |
1068 | # Lists to store completed sequences, their alphas and scores
1069 | complete_seqs = list()
1070 | if return_attention:
1071 | complete_seqs_alpha = list()
1072 | complete_seqs_scores = list()
1073 | complete_seqs_tag_H = list()
1074 |
1075 | # Start decoding
1076 | step = 1
1077 | if isinstance(self.decode_step, nn.LSTMCell):
1078 | h, c = self.init_hidden_state(encoder_out)
1079 | else:
1080 | h = self.init_hidden_state(encoder_out)
1081 | tag_H = [[] for i in range(k)]
1082 | # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
1083 | while True:
1084 | embeddings = self.embedding(k_prev_words).squeeze(1) # (s, embed_dim)
1085 |
1086 | if return_attention:
1087 | awe, alpha = self.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels)
1088 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size)
1089 | else:
1090 | awe, _ = self.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels)
1091 |
1092 | gate = self.sigmoid(self.f_beta(h)) # gating scalar, (s, encoder_dim)
1093 | awe = gate * awe
1094 |
1095 | if isinstance(self.decode_step, nn.LSTMCell):
1096 | h, c = self.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim)
1097 | else:
1098 | h = self.decode_step(torch.cat([embeddings, awe], dim=1), h) # (s, decoder_dim)
1099 |
1100 | h = repackage_hidden(h)
1101 | if isinstance(self.decode_step, nn.LSTMCell):
1102 | c = repackage_hidden(c)
1103 |
1104 | for i, w in enumerate(k_prev_words):
1105 | if w[0].item() in (word_map['<td>'], word_map['>']):
1106 | tag_H[i].append(h[i].unsqueeze(0))
1107 |
1108 | scores = self.fc(h) # (s, vocab_size)
1109 | scores = F.log_softmax(scores, dim=1)
1110 |
1111 | # Add
1112 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)
1113 |
1114 | # For the first step, all k points will have the same scores (since same k previous words, h, c)
1115 | if step == 1:
1116 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
1117 | else:
1118 | # Unroll and find top scores, and their unrolled indices
1119 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)
1120 |
1121 | # Convert unrolled indices to actual indices of scores
1122 | prev_word_inds = top_k_words // vocab_size # (s)
1123 | next_word_inds = top_k_words % vocab_size # (s)
1124 |
1125 |
1126 | # Add new words to sequences, alphas
1127 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)
1128 | if return_attention:
1129 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
1130 | dim=1) # (s, step+1, enc_image_size, enc_image_size)
1131 |
1132 | # Which sequences are incomplete (didn't reach <end>)?
1133 | incomplete_inds = []
1134 | complete_inds = []
1135 | for ind, next_word in enumerate(next_word_inds):
1136 | if next_word == word_map['<end>']:
1137 | complete_inds.append(ind)
1138 | else:
1139 | incomplete_inds.append(ind)
1140 |
1141 | # Set aside complete sequences
1142 | if len(complete_inds) > 0:
1143 | complete_seqs.extend(seqs[complete_inds].tolist())
1144 | if return_attention:
1145 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
1146 | complete_seqs_scores.extend(top_k_scores[complete_inds])
1147 | complete_seqs_tag_H.extend([tag_H[i].copy() for i in prev_word_inds[complete_inds]])
1148 | k -= len(complete_inds) # reduce beam length accordingly
1149 |
1150 | # Break if all sequences are complete
1151 | if k == 0:
1152 | break
1153 | # Break if things have been going on too long
1154 | if step > max_steps:
1155 | # If no complete sequence is generated, finish the incomplete
1156 | # sequences with <end>
1157 | if not complete_seqs_scores:
1158 | complete_seqs = seqs.tolist()
1159 | for i in range(len(complete_seqs)):
1160 | complete_seqs[i].append(word_map['<end>'])
1161 | if return_attention:
1162 | complete_seqs_alpha = seqs_alpha.tolist()
1163 | complete_seqs_scores = top_k_scores.tolist()
1164 | complete_seqs_tag_H = [tag_H[i].copy() for i in prev_word_inds]
1165 | break
1166 |
1167 | # Proceed with incomplete sequences
1168 | seqs = seqs[incomplete_inds]
1169 | if return_attention:
1170 | seqs_alpha = seqs_alpha[incomplete_inds]
1171 | tag_H = [tag_H[i].copy() for i in prev_word_inds[incomplete_inds]]
1172 | h = h[prev_word_inds[incomplete_inds]]
1173 | if isinstance(self.decode_step, nn.LSTMCell):
1174 | c = c[prev_word_inds[incomplete_inds]]
1175 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
1176 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
1177 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
1178 |
1179 | step += 1
1180 | i = complete_seqs_scores.index(max(complete_seqs_scores))
1181 | seq = complete_seqs[i]
1182 | if complete_seqs_tag_H[i]:
1183 | tag_H = torch.cat(complete_seqs_tag_H[i]).to(device)
1184 | else:
1185 | tag_H = torch.zeros(0).to(device)
1186 | if return_attention:
1187 | alphas = complete_seqs_alpha[i]
1188 | return seq, alphas, tag_H
1189 | else:
1190 | return seq, tag_H
1191 |
1192 | def forward(self, encoder_out, encoded_tags_sorted, tag_lengths, num_cells=None, max_tag_len=None):
1193 | # Flatten image
1194 | if hasattr(self, 'input_filter'):
1195 | encoder_out = self.input_filter(encoder_out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
1196 | batch_size = encoder_out.size(0)
1197 | encoder_dim = encoder_out.size(-1)
1198 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim)
1199 | num_pixels = encoder_out.size(1)
1200 |
1201 | # Embedding
1202 | embeddings = self.embedding(encoded_tags_sorted) # (batch_size, max_caption_length, embed_dim)
1203 | # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
1204 | # So, decoding lengths are actual lengths - 1
1205 | decode_lengths = (tag_lengths - 1).tolist()
1206 | max_decode_lengths = decode_lengths[0] if max_tag_len is None else max_tag_len
1207 | # Create tensors to hold word prediction scores and alphas
1208 | predictions = torch.zeros(batch_size, max_decode_lengths, self.vocab_size).to(device)
1209 | alphas = torch.zeros(batch_size, max_decode_lengths, num_pixels).to(device)
1210 |
1211 | if num_cells is not None:
1212 | # Create tensors to hold hidden state of tag decoder for cell decoder
1213 | tag_H = [torch.zeros(n.item(), self.decoder_dim).to(device) for n in num_cells]
1214 | pointer = torch.zeros(batch_size, dtype=torch.long).to(device)
1215 |
1216 | # Initialize LSTM state
1217 | if isinstance(self.decode_step, nn.LSTMCell):
1218 | h, c = self.init_hidden_state(encoder_out)
1219 | else:
1220 | h = self.init_hidden_state(encoder_out)
1221 |
1222 | # Decode table structure
1223 | for t in range(max_decode_lengths):
1224 | batch_size_t = sum([l > t for l in decode_lengths])
1225 | if batch_size_t > 0:
1226 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
1227 | h[:batch_size_t])
1228 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim)
1229 | attention_weighted_encoding = gate * attention_weighted_encoding
1230 | if isinstance(self.decode_step, nn.LSTMCell):
1231 | h, c = self.decode_step(
1232 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
1233 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim)
1234 | else:
1235 | h = self.decode_step(
1236 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
1237 | h[:batch_size_t]) # (batch_size_t, decoder_dim)
1238 | predictions[:batch_size_t, t, :] = self.fc(self.dropout(h)) # (batch_size_t, vocab_size)
1239 | alphas[:batch_size_t, t, :] = alpha
1240 | if num_cells is not None:
1241 | for i in range(batch_size_t):
1242 | if encoded_tags_sorted[i, t] in self.td_encode:
1243 | if self.tag_H_grad:
1244 | tag_H[i][pointer[i]] = h[i]
1245 | else:
1246 | tag_H[i][pointer[i]] = repackage_hidden(h[i])
1247 | pointer[i] += 1
1248 | if num_cells is None:
1249 | return predictions, decode_lengths, alphas
1250 | else:
1251 | return predictions, decode_lengths, alphas, tag_H
1252 |
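# Usage sketch (hypothetical instance names) for the tag_H output above: in the dual-decoder
# setup the structure decoder and the cell decoder share the same image encoding, roughly
#   preds_tag, dec_lens_tag, alphas_tag, tag_H = tag_decoder(enc_out, tags, tag_lens, num_cells)
#   cell_outputs = cell_decoder(enc_out, cells, cell_lens, tag_H)   # CellDecoder_baseline.forward
# where tag_H[i] holds one hidden-state vector per cell of sample i, recorded right after a
# cell-opening token (one listed in td_encode) has been consumed.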
1253 | def train_epoch(self, train_loader, encoder, criterion, encoder_optimizer, decoder_optimizer, epoch, args):
1254 | """
1255 | Performs one epoch's training.
1256 |
1257 | :param train_loader: DataLoader for training data
1258 | :param encoder: encoder model
1259 | :param criterion: loss layer
1260 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning)
1261 | :param decoder_optimizer: optimizer to update decoder's weights
1262 | :param epoch: epoch number
1263 | """
1264 | self.train() # train mode (dropout and batchnorm is used)
1265 | encoder.train()
1266 |
1267 | batch_time = AverageMeter() # forward prop. + back prop. time
1268 | losses = AverageMeter() # loss (per word decoded)
1269 | top1accs = AverageMeter() # top1 accuracy
1270 |
1271 | start = time.time()
1272 | # Batches
1273 | for i, (imgs, tags, tag_lens) in enumerate(train_loader):
1274 | # Move to GPU, if available
1275 | imgs = imgs.to(device)
1276 | tags = tags.to(device)
1277 | tag_lens = tag_lens.to(device)
1278 |
1279 | # Flatten image
1280 | batch_size = imgs.size(0)
1281 | encoder_dim = imgs.size(-1)
1282 | imgs = imgs.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim)
1283 | tag_lens = tag_lens.squeeze(1)
1284 |
1285 | # Sort input data by decreasing lengths
1286 | tag_lens, sort_ind = tag_lens.sort(dim=0, descending=True)
1287 | imgs = imgs[sort_ind]
1288 | tags_sorted = tags[sort_ind]
1289 |
1290 | # Forward prop.
1291 | imgs = encoder(imgs)
1292 | if hasattr(self, 'input_filter'):
1293 | imgs = self.input_filter(imgs.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
1294 |
1295 | scores_tag, decode_lengths_tag, alphas_tag = self(
1296 | imgs, tags_sorted, tag_lens)
1297 |
1298 | # Calculate tag loss
1299 | targets_tag = tags_sorted[:, 1:]
1300 | scores_tag = pack_padded_sequence(scores_tag, decode_lengths_tag, batch_first=True)[0]
1301 | targets_tag = pack_padded_sequence(targets_tag, decode_lengths_tag, batch_first=True)[0]
1302 | loss = criterion(scores_tag, targets_tag)
1303 | # Add doubly stochastic attention regularization
1304 | loss += args.alpha_c * ((1. - alphas_tag.sum(dim=1)) ** 2).mean()
1305 | top1 = accuracy(scores_tag, targets_tag, 1)
1306 | tag_count = sum(decode_lengths_tag)
1307 | losses.update(loss.item(), tag_count)
1308 | top1accs.update(top1, tag_count)
1309 |
1310 | # Back prop.
1311 | decoder_optimizer.zero_grad()
1312 | if encoder_optimizer is not None:
1313 | encoder_optimizer.zero_grad()
1314 | loss.backward()
1315 |
1316 | # Clip gradients
1317 | if args.grad_clip is not None:
1318 | clip_gradient(decoder_optimizer, args.grad_clip)
1319 | if encoder_optimizer is not None:
1320 | clip_gradient(encoder_optimizer, args.grad_clip)
1321 |
1322 | # Update weights
1323 | decoder_optimizer.step()
1324 | if encoder_optimizer is not None:
1325 | encoder_optimizer.step()
1326 |
1327 | batch_time.update(time.time() - start)
1328 | start = time.time()
1329 |
1330 | # Print status
1331 | if i % args.print_freq == 0:
1332 | print('Epoch: [{0}][{1}/{2}]\t'
1333 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
1334 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
1335 | 'Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(epoch, i, len(train_loader),
1336 | batch_time=batch_time,
1337 | loss=losses,
1338 | top1=top1accs), file=sys.stderr)
1339 | sys.stderr.flush()
1340 |
1341 | class BBoxLoss(nn.Module):
1342 | def __init__(self):
1343 | super(BBoxLoss, self).__init__()
1344 | self.CE = nn.CrossEntropyLoss()
1345 |
1346 | def bbox_loss(self, gt, pred):
1347 | center_loss = (gt[:, :2] - pred[:, :2]).square().sum(dim=1)
1348 | size_loss = (gt[:, 2:].sqrt() - pred[:, 2:].sqrt()).square().sum(dim=1)
1349 | return center_loss + size_loss
1350 |
1351 | def forward(self, gt, pred):
1352 | empty_loss = self.CE(pred[:, :2], gt[:, 0].long()) # Empty cell classification loss
1353 | bbox_loss = (gt[:, 0] * self.bbox_loss(gt[:, 1:], pred[:, 2:])).mean() # Only compute for non-empty cells
1354 | return empty_loss + bbox_loss
1355 |
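# Worked example (illustrative numbers) for BBoxLoss.bbox_loss above, for one cell with ground
# truth (x_c, y_c, w, h) = (0.5, 0.5, 0.20, 0.10) and prediction (0.5, 0.4, 0.25, 0.10):
#   center term: (0.5 - 0.5)^2 + (0.5 - 0.4)^2                              = 0.0100
#   size term  : (sqrt(0.20) - sqrt(0.25))^2 + (sqrt(0.10) - sqrt(0.10))^2 ~= 0.0028
# The square roots damp the penalty on large boxes relative to small ones (as in YOLO);
# forward() adds a cross-entropy term for empty/non-empty classification and keeps the box term
# only for non-empty cells (gt[:, 0] == 1).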
1356 | class CellBBox(nn.Module):
1357 | """
1358 | Regression network for bbox of table cell.
1359 | """
1360 |
1361 | def __init__(self, tag_decoder_dim):
1362 | """
1363 | :param tag_decoder_dim: size of tag decoder's RNN
1364 | """
1365 | super(CellBBox, self).__init__()
1366 | # linear layers to predict bbox (x_c, y_c, w, h)
1367 | self.bbox = nn.Sequential(
1368 | nn.Linear(tag_decoder_dim, tag_decoder_dim),
1369 | nn.ReLU(),
1370 | nn.Linear(tag_decoder_dim, 4),
1371 | nn.Sigmoid()
1372 | )
1373 | # linear layers to predict if a cell is empty
1374 | self.empty_cls = nn.Sequential(
1375 | nn.Linear(tag_decoder_dim, tag_decoder_dim),
1376 | nn.ReLU(),
1377 | nn.Linear(tag_decoder_dim, 2)
1378 | )
1379 |
1380 | def forward(self, decoder_hidden):
1381 | """
1382 | Forward propagation.
1383 | :param decoder_hidden: tag decoder output, a list of batch_size tensors of dimension (num_cells, tag_decoder_dim)
1384 | """
1385 | batch_size = len(decoder_hidden)
1386 | output = []
1387 | for i in range(batch_size):
1388 | not_empty = self.empty_cls(decoder_hidden[i]) # (num_cells, 2)
1389 | bbox_pred = self.bbox(decoder_hidden[i]) # (num_cells, 4)
1390 | output.append(torch.cat([not_empty, bbox_pred], dim=1)) # (num_cells, 6)
1391 | return output
1392 |
1393 |
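# Usage sketch (hypothetical instance name cell_bbox): applied to the per-sample tag decoder
# states produced by TagDecoder.forward,
#   bbox_out = cell_bbox(tag_H)   # list of (num_cells_i, 6) tensors
# columns 0-1 are the empty/non-empty logits and columns 2-5 the normalized (x_c, y_c, w, h),
# matching the split BBoxLoss.forward expects (pred[:, :2] vs. pred[:, 2:]).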
1394 | class BBoxLoss_Yolo(nn.Module):
1395 | def __init__(self, w_coor=5.0, w_noobj=0.5, image_size=(28, 28)):
1396 | super(BBoxLoss_Yolo, self).__init__()
1397 | self.w_coor = w_coor
1398 | self.w_noobj = w_noobj
1399 | self.image_size = image_size
1400 |
1401 | def IoU(self, pred, gt):
1402 | ''' Calculates IoU between prediction boxes and table cell
1403 | '''
1404 | pred_xmin = pred[:, 1::5] - pred[:, 3::5] / 2
1405 | pred_xmax = pred[:, 1::5] + pred[:, 3::5] / 2
1406 | pred_ymin = pred[:, 2::5] - pred[:, 4::5] / 2
1407 | pred_ymax = pred[:, 2::5] + pred[:, 4::5] / 2
1408 | gt_xmin = (gt[:, 1] - gt[:, 3] / 2).unsqueeze(1)
1409 | gt_xmax = (gt[:, 1] + gt[:, 3] / 2).unsqueeze(1)
1410 | gt_ymin = (gt[:, 2] - gt[:, 4] / 2).unsqueeze(1)
1411 | gt_ymax = (gt[:, 2] + gt[:, 4] / 2).unsqueeze(1)
1412 |
1413 | I_w = torch.max(torch.FloatTensor([0]), torch.min(pred_xmax, gt_xmax) - torch.max(pred_xmin, gt_xmin))
1414 | I_h = torch.max(torch.FloatTensor([0]), torch.min(pred_ymax, gt_ymax) - torch.max(pred_ymin, gt_ymin))
1415 | I = I_w * I_h
1416 | U = pred[:, 3::5] * pred[:, 4::5] + (gt[:, 3] * gt[:, 4]).unsqueeze(1) - I
1417 | IoU = I / (U + 1e-8) # Avoid dividing by 0
1418 | return IoU
1419 |
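# Worked IoU example (illustrative, boxes given as (x_c, y_c, w, h) in normalized coordinates):
#   box A = (0.5, 0.5, 0.4, 0.4) -> corners (0.3, 0.3)-(0.7, 0.7), area 0.16
#   box B = (0.6, 0.5, 0.4, 0.4) -> corners (0.4, 0.3)-(0.8, 0.7), area 0.16
#   intersection = 0.3 * 0.4 = 0.12; union = 0.16 + 0.16 - 0.12 = 0.20; IoU = 0.6
# The 1e-8 in the denominator above only guards against a zero-area union.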
1420 | def find_responsible_box(self, pred, idx, gt):
1421 | ''' Finds which prediction box is responsible for the table cell
1422 | '''
1423 | pred = pred[idx[0], idx[1]]
1424 | IoU = self.IoU(pred, gt)
1425 | num_cells = gt.size(0)
1426 | IoU, responsible_box = torch.max(IoU, dim=1)
1427 | return responsible_box, IoU
1428 |
1429 | def forward(self, gt, pred):
1430 | '''
1431 | :param gt: ground truth (num_cells, 5)
1432 | :param pred: prediction of CellBBoxYolo (num_cells, num_pixels, 5 * num_bboxes_per_pixel)
1433 | '''
1434 | num_cells = gt.size(0)
1435 | image_width, image_height = self.image_size
1436 | non_empty_cell = gt[:, 0] == 1
1437 |
1438 | gt_non_empty, gt_empty = gt[non_empty_cell], gt[~non_empty_cell]
1439 | pred_non_empty, pred_empty = pred[non_empty_cell], pred[~non_empty_cell]
1440 | loss_empty = self.w_noobj * pred_empty[:, :, 0::5].square().sum()
1441 |
1442 | # Encode gt as Yolo format
1443 | # Find center pixel
1444 | x_c, y_c = torch.floor(gt_non_empty[:, 1] * image_width), torch.floor(gt_non_empty[:, 2] * image_height)
1445 | idx = (torch.LongTensor(torch.arange(gt_non_empty.size(0))), (x_c * image_width + y_c).long())
1446 |
1447 | # Compute offset
1448 | gt_non_empty[:, 1], gt_non_empty[:, 2] = gt_non_empty[:, 1] * image_width - x_c, gt_non_empty[:, 2] * image_height - y_c
1449 | gt_non_empty[:, 3], gt_non_empty[:, 4] = gt_non_empty[:, 3] * image_width, gt_non_empty[:, 4] * image_height
1450 |
1451 | responsible_box, IoU = self.find_responsible_box(pred_non_empty, idx, gt_non_empty)
1452 | responsible_box = responsible_box * 5
1453 | gt_non_empty[:, 0] = IoU
1454 | gt_non_empty[:, 3:5] = gt_non_empty[:, 3:5].sqrt()
1455 |
1456 | responsible_box = torch.cat((
1457 | pred_non_empty[idx[0], idx[1], responsible_box],
1458 | pred_non_empty[idx[0], idx[1], responsible_box + 1],
1459 | pred_non_empty[idx[0], idx[1], responsible_box + 2],
1460 | pred_non_empty[idx[0], idx[1], responsible_box + 3].sqrt(),
1461 | pred_non_empty[idx[0], idx[1], responsible_box + 4].sqrt()
1462 | ), dim=1)
1463 |
1464 |
1465 | loss_coor = self.w_coor * (responsible_box[:, 1:5] - gt_non_empty[:, 1:5]).square().sum()
1466 | loss_noobj = (responsible_box[:, 0] - gt_non_empty[:, 0]).square().sum() + \
1467 | self.w_noobj * 0 + \
1468 | loss_empty
1469 |
1470 | return loss_coor + loss_noobj
1471 |
1472 |
1473 | class CellBBoxYolo(nn.Module):
1474 | """
1475 | NOT READY
1476 | Table cell detection network (based on the idea of Yolo).
1477 | """
1478 |
1479 | def __init__(self, encoder_dim, tag_decoder_dim, feature_dim, num_bboxes_per_pixel=2):
1480 | """
1481 | :param encoder_dim: feature size of encoded images
1482 | :param tag_decoder_dim: size of tag decoder's RNN
1483 | :param feature_dim: size of the features
1484 | """
1485 | super(CellBBoxYolo, self).__init__()
1486 | self.encoder_att = nn.Linear(encoder_dim, feature_dim) # linear layer to transform encoded image
1487 | self.tag_decoder_att = nn.Linear(tag_decoder_dim, feature_dim) # linear layer to transform tag decoder output
1488 | self.bbox = nn.Linear(feature_dim, 5 * num_bboxes_per_pixel) # linear layer to predict bboxes [c, x_c, y_c, w, h] * num_bboxes_per_pixel
1489 | self.relu = nn.ReLU()
1490 | self.sigmoid = nn.Sigmoid() # sigmoid to scale bbox between 0 and 1
1491 |
1492 | def forward(self, encoder_out, decoder_hidden):
1493 | """
1494 | Forward propagation.
1495 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
1496 | :param decoder_hidden: tag decoder output, a tensor of dimension [(num_cells, tag_decoder_dim)] * batch_size
1497 | :return: [(num_cells, 5 * num_bboxes_per_pixel)] * batch_size
1498 | """
1499 | batch_size = encoder_out.size(0)
1500 | output = []
1501 | for i in range(batch_size):
1502 | att1 = self.encoder_att(encoder_out[i].unsqueeze(0)) # (1, num_pixels, feature_dim)
1503 | att2 = self.tag_decoder_att(decoder_hidden[i]).unsqueeze(1) # (num_cells, 1, feature_dim)
1504 | att = self.relu(att1 + att2) # (num_cells, num_pixels, feature_dim)
1505 | bboxes = self.sigmoid(self.bbox(att)) # (num_cells, num_pixels, 5 * num_bboxes_per_pixel)
1506 | output.append(bboxes)
1507 | return output
1508 |
1509 |
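# Shape sketch (illustrative) for CellBBoxYolo.forward above, for one sample with N cells and
# P = enc_image_size * enc_image_size pixels:
#   encoder_out[i]    (P, encoder_dim)     -> att1 (1, P, feature_dim)
#   decoder_hidden[i] (N, tag_decoder_dim) -> att2 (N, 1, feature_dim)
#   att1 + att2 broadcasts to (N, P, feature_dim) -> bbox head -> (N, P, 5 * num_bboxes_per_pixel)
# i.e. every cell predicts num_bboxes_per_pixel candidate boxes per image location, YOLO-style.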
1510 | class CellDecoder_baseline(nn.Module):
1511 | '''
1512 | CellDecoder generates cell content
1513 | '''
1514 | def __init__(self, attention_dim, embed_dim, tag_decoder_dim, decoder_dim,
1515 | vocab_size, decoder_cell=nn.LSTMCell, encoder_dim=512,
1516 | dropout=0.5, cnn_layer_stride=None):
1517 | """
1518 | :param attention_dim: size of attention network
1519 | :param embed_dim: embedding size
1520 | :param tag_decoder_dim: size of tag decoder's RNN
1521 | :param decoder_dim: size of decoder's RNN
1522 | :param vocab_size: size of vocabulary
1523 | :param encoder_dim: feature size of encoded images
1524 | :param dropout: dropout
1526 | """
1527 | super(CellDecoder_baseline, self).__init__()
1528 |
1529 | self.encoder_dim = encoder_dim
1530 | self.attention_dim = attention_dim
1531 | self.embed_dim = embed_dim
1532 | self.decoder_dim = decoder_dim
1533 | self.vocab_size = vocab_size
1534 | self.dropout = dropout
1535 |
1536 | self.attention = CellAttention(encoder_dim, tag_decoder_dim, decoder_dim, attention_dim) # attention network
1537 |
1538 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer
1539 |
1540 | self.decode_step = decoder_cell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoder LSTMCell
1541 |
1542 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell
1543 | if isinstance(self.decode_step, nn.LSTMCell):
1544 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell
1545 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate
1546 | self.sigmoid = nn.Sigmoid()
1547 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary
1548 | self.dropout = nn.Dropout(p=self.dropout)
1549 |
1550 | if cnn_layer_stride is not None:
1551 | self.input_filter = resnet_block(cnn_layer_stride)
1552 |
1553 | self.init_weights() # initialize some layers with the uniform distribution
1554 |
1555 | def init_weights(self):
1556 | """
1557 | Initializes some parameters with values from the uniform distribution, for easier convergence.
1558 | """
1559 | self.embedding.weight.data.uniform_(-0.1, 0.1)
1560 | self.fc.bias.data.fill_(0)
1561 | self.fc.weight.data.uniform_(-0.1, 0.1)
1562 |
1563 | def init_hidden_state(self, encoder_out, batch_size):
1564 | """
1565 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
1566 |
1567 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
1568 | :return: hidden state, cell state
1569 | """
1570 | mean_encoder_out = encoder_out.mean(dim=1)
1571 | h = self.init_h(mean_encoder_out).expand(batch_size, -1)
1572 | if isinstance(self.decode_step, nn.LSTMCell):
1573 | c = self.init_c(mean_encoder_out).expand(batch_size, -1)
1574 | return h, c
1575 | else:
1576 | return h
1577 |
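# Descriptive note: unlike DecoderWithAttention.init_hidden_state, this variant takes an
# explicit batch_size because one image's mean-pooled encoding is shared by many decoding rows:
# the k beams of a single cell during inference, or all cells of a sample in forward().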
1578 | def inference(self, encoder_out, tag_H, word_map, max_steps=400, beam_size=5, return_attention=False):
1579 | """
1580 | Inference on test images with beam search
1581 | """
1582 |
1583 | if hasattr(self, 'input_filter'):
1584 | encoder_out = self.input_filter(encoder_out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
1585 | enc_image_size = encoder_out.size(1)
1586 | encoder_dim = encoder_out.size(3)
1587 |
1588 | # Flatten encoding
1589 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim)
1590 |
1591 | num_cells = tag_H.size(0)
1592 | cell_seqs = []
1593 | if return_attention:
1594 | cell_alphas = []
1595 | vocab_size = len(word_map)
1596 |
1597 | for c_id in range(num_cells):
1598 | k = beam_size
1599 | # Tensor to store top k previous words at each step; now they're just '<start>'
1600 | k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device) # (k, 1)
1601 |
1602 | # Tensor to store top k sequences; now they're just '<start>'
1603 | seqs = k_prev_words # (k, 1)
1604 |
1605 | # Tensor to store top k sequences' scores; now they're just 0
1606 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1)
1607 |
1608 | if return_attention:
1609 | # Tensor to store top k sequences' alphas; now they're just 1s
1610 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size)
1611 |
1612 | # Lists to store completed sequences, their alphas and scores
1613 | complete_seqs = list()
1614 | if return_attention:
1615 | complete_seqs_alpha = list()
1616 | complete_seqs_scores = list()
1617 |
1618 | # Start decoding
1619 | step = 1
1620 | if isinstance(self.decode_step, nn.LSTMCell):
1621 | h, c = self.init_hidden_state(encoder_out, k)
1622 | else:
1623 | h = self.init_hidden_state(encoder_out, k)
1624 |
1625 | cell_tag_H = tag_H[c_id].expand(k, -1)
1626 | # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
1627 | while True:
1628 | embeddings = self.embedding(k_prev_words).squeeze(1) # (s, embed_dim)
1629 | if return_attention:
1630 | awe, alpha = self.attention(encoder_out, cell_tag_H, h) # (s, encoder_dim), (s, num_pixels)
1631 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size)
1632 | else:
1633 | awe, _ = self.attention(encoder_out, cell_tag_H, h) # (s, encoder_dim), (s, num_pixels)
1634 |
1635 | gate = self.sigmoid(self.f_beta(h)) # gating scalar, (s, encoder_dim)
1636 | awe = gate * awe
1637 |
1638 | if isinstance(self.decode_step, nn.LSTMCell):
1639 | h, c = self.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim)
1640 | else:
1641 | h = self.decode_step(torch.cat([embeddings, awe], dim=1), h) # (s, decoder_dim)
1642 |
1643 | h = repackage_hidden(h)
1644 | if isinstance(self.decode_step, nn.LSTMCell):
1645 | c = repackage_hidden(c)
1646 |
1647 | scores = self.fc(h) # (s, vocab_size)
1648 | scores = F.log_softmax(scores, dim=1)
1649 |
1650 | # Add
1651 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)
1652 |
1653 | # For the first step, all k points will have the same scores (since same k previous words, h, c)
1654 | if step == 1:
1655 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
1656 | else:
1657 | # Unroll and find top scores, and their unrolled indices
1658 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)
1659 |
1660 | # Convert unrolled indices to actual indices of scores
1661 | prev_word_inds = top_k_words // vocab_size # (s)
1662 | next_word_inds = top_k_words % vocab_size # (s)
1663 |
1664 | # Add new words to sequences, alphas
1665 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)
1666 | if return_attention:
1667 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
1668 | dim=1) # (s, step+1, enc_image_size, enc_image_size)
1669 |
1670 | # Which sequences are incomplete (didn't reach <end>)?
1671 | incomplete_inds = []
1672 | complete_inds = []
1673 | for ind, next_word in enumerate(next_word_inds):
1674 | if next_word == word_map['<end>']:
1675 | complete_inds.append(ind)
1676 | else:
1677 | incomplete_inds.append(ind)
1678 |
1679 | # Set aside complete sequences
1680 | if len(complete_inds) > 0:
1681 | complete_seqs.extend(seqs[complete_inds].tolist())
1682 | if return_attention:
1683 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
1684 | complete_seqs_scores.extend(top_k_scores[complete_inds])
1685 | k -= len(complete_inds) # reduce beam length accordingly
1686 |
1687 | # Break if all sequences are complete
1688 | if k == 0:
1689 | break
1690 | # Break if things have been going on too long
1691 | if step > max_steps:
1692 | # If no complete sequence is generated, finish the incomplete
1693 | # sequences with <end>
1694 | if not complete_seqs_scores:
1695 | complete_seqs = seqs.tolist()
1696 | for i in range(len(complete_seqs)):
1697 | complete_seqs[i].append(word_map['<end>'])
1698 | if return_attention:
1699 | complete_seqs_alpha = seqs_alpha.tolist()
1700 | complete_seqs_scores = top_k_scores.tolist()
1701 | break
1702 |
1703 | # Proceed with incomplete sequences
1704 | seqs = seqs[incomplete_inds]
1705 | if return_attention:
1706 | seqs_alpha = seqs_alpha[incomplete_inds]
1707 | cell_tag_H = cell_tag_H[prev_word_inds[incomplete_inds]]
1708 | h = h[prev_word_inds[incomplete_inds]]
1709 | if isinstance(self.decode_step, nn.LSTMCell):
1710 | c = c[prev_word_inds[incomplete_inds]]
1711 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
1712 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
1713 |
1714 | step += 1
1715 | i = complete_seqs_scores.index(max(complete_seqs_scores))
1716 | cell_seqs.append(complete_seqs[i])
1717 | if return_attention:
1718 | cell_alphas.append(complete_seqs_alpha[i])
1719 | if return_attention:
1720 | return cell_seqs, cell_alphas
1721 | else:
1722 | return cell_seqs
1723 |
1724 | def forward(self, encoder_out, encoded_cells_sorted, cell_lengths, tag_H):
1725 | """
1726 | Forward propagation.
1727 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
1728 | :param encoded_cells_sorted: encoded cells, a list of batch_size tensors of dimension (num_cells, max_cell_length)
1729 | :param tag_H: hidden state from TagDecoder, a list of batch_size tensors of dimension (num_cells, TagDecoder's decoder_dim)
1730 | :param cell_lengths: cell lengths, a list of batch_size tensors of dimension (num_cells, 1)
1731 | :return: scores over the vocabulary, decode lengths, and attention weights (one list entry per sample)
1732 | """
1733 | if hasattr(self, 'input_filter'):
1734 | encoder_out = self.input_filter(encoder_out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
1735 |
1736 | # Flatten image
1737 | batch_size = encoder_out.size(0)
1738 | encoder_dim = encoder_out.size(-1)
1739 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim)
1740 | num_pixels = encoder_out.size(1)
1741 |
1742 | # Decode cell content
1743 | predictions_cell = []
1744 | alphas_cell = []
1745 | decode_lengths_cell = []
1746 | for i in range(batch_size):
1747 | num_cells = cell_lengths[i].size(0)
1748 | embeddings = self.embedding(encoded_cells_sorted[i])
1749 | decode_lengths = (cell_lengths[i] - 1).tolist()
1750 | max_decode_lengths = decode_lengths[0]
1751 | predictions = torch.zeros(num_cells, max_decode_lengths, self.vocab_size).to(device)
1752 | alphas = torch.zeros(num_cells, max_decode_lengths, num_pixels).to(device)
1753 | if isinstance(self.decode_step, nn.LSTMCell):
1754 | h, c = self.init_hidden_state(encoder_out[i].unsqueeze(0), num_cells)
1755 | else:
1756 | h = self.init_hidden_state(encoder_out[i].unsqueeze(0), num_cells)
1757 | for t in range(max_decode_lengths):
1758 | batch_size_t = sum([l > t for l in decode_lengths])
1759 | attention_weighted_encoding, alpha = self.attention(encoder_out[i].unsqueeze(0),
1760 | tag_H[i][:batch_size_t],
1761 | h[:batch_size_t])
1762 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim)
1763 | attention_weighted_encoding = gate * attention_weighted_encoding
1764 | if isinstance(self.decode_step, nn.LSTMCell):
1765 | h, c = self.decode_step(
1766 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
1767 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim)
1768 | else:
1769 | h = self.decode_step(
1770 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
1771 | h[:batch_size_t]) # (batch_size_t, decoder_dim)
1772 | predictions[:batch_size_t, t, :] = self.fc(self.dropout(h)) # (batch_size_t, vocab_size)
1773 | alphas[:batch_size_t, t, :] = alpha
1774 | predictions_cell.append(predictions)
1775 | alphas_cell.append(alphas)
1776 | decode_lengths_cell.append(decode_lengths)
1777 | return predictions_cell, decode_lengths_cell, alphas_cell
1778 |
1779 | class CellDecoder(nn.Module):
1780 | '''
1781 | CellDecoder generates cell content (variant with a separate language-model RNN, selected by cell_decoder_type=2 in DualDecoder)
1782 | '''
1783 | def __init__(self, attention_dim, embed_dim, tag_decoder_dim, language_dim,
1784 | decoder_dim, vocab_size, decoder_cell=nn.LSTMCell,
1785 | encoder_dim=512, dropout=0.5, cnn_layer_stride=None):
1786 | """
1787 | :param attention_dim: size of attention network
1788 | :param embed_dim: embedding size
1789 | :param tag_decoder_dim: size of tag decoder's RNN
1790 | :param language_dim: size of language model's RNN
1791 | :param decoder_dim: size of decoder's RNN
1792 | :param vocab_size: size of vocabulary
1793 | :param encoder_dim: feature size of encoded images
1794 | :param dropout: dropout
1795 | :param decoder_cell: RNN cell class (nn.LSTMCell or nn.GRUCell)
1796 | """
1797 | super(CellDecoder, self).__init__()
1798 |
1799 | self.encoder_dim = encoder_dim
1800 | self.attention_dim = attention_dim
1801 | self.embed_dim = embed_dim
1802 | self.language_dim = language_dim
1803 | self.decoder_dim = decoder_dim
1804 | self.vocab_size = vocab_size
1805 | self.dropout = dropout
1806 |
1807 | self.attention = CellAttention(encoder_dim, tag_decoder_dim, language_dim, attention_dim) # attention network
1808 |
1809 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer
1810 |
1811 | self.decode_step_LM = decoder_cell(embed_dim, language_dim, bias=True) # language model LSTMCell
1812 |
1813 | self.decode_step_pred = decoder_cell(encoder_dim, decoder_dim, bias=True) # decoding LSTMCell
1814 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell
1815 | if isinstance(self.decode_step_pred, nn.LSTMCell):
1816 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell
1817 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate
1818 | self.sigmoid = nn.Sigmoid()
1819 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary
1820 | self.dropout = nn.Dropout(p=self.dropout)
1821 |
1822 | if cnn_layer_stride is not None:
1823 | self.input_filter = resnet_block(cnn_layer_stride)
1824 |
1825 | self.init_weights() # initialize some layers with the uniform distribution
1826 |
1827 | def init_weights(self):
1828 | """
1829 | Initializes some parameters with values from the uniform distribution, for easier convergence.
1830 | """
1831 | self.embedding.weight.data.uniform_(-0.1, 0.1)
1832 | self.fc.bias.data.fill_(0)
1833 | self.fc.weight.data.uniform_(-0.1, 0.1)
1834 |
1835 | def init_hidden_state(self, encoder_out, batch_size):
1836 | """
1837 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images.
1838 |
1839 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim)
1840 | :return: hidden state, cell state
1841 | """
1842 | mean_encoder_out = encoder_out.mean(dim=1)
1843 | h_LM = torch.zeros(batch_size, self.language_dim).to(device)
1844 | h_pred = self.init_h(mean_encoder_out).expand(batch_size, -1)
1845 | if isinstance(self.decode_step_pred, nn.LSTMCell):
1846 | c_LM = torch.zeros(batch_size, self.language_dim).to(device)
1847 | c_pred = self.init_c(mean_encoder_out).expand(batch_size, -1)
1848 | return h_LM, c_LM, h_pred, c_pred
1849 | else:
1850 | return h_LM, h_pred
1851 |
1852 | def inference(self, encoder_out, tag_H, word_map, max_steps=400, beam_size=5, return_attention=False):
1853 | """
1854 | Inference on test images with beam search
1855 | """
1856 | if hasattr(self, 'input_filter'):
1857 | encoder_out = self.input_filter(encoder_out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
1858 |
1859 | enc_image_size = encoder_out.size(1)
1860 | encoder_dim = encoder_out.size(3)
1861 |
1862 | # Flatten encoding
1863 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim)
1864 |
1865 | num_cells = tag_H.size(0)
1866 | cell_seqs = []
1867 | if return_attention:
1868 | cell_alphas = []
1869 | vocab_size = len(word_map)
1870 |
1871 | for c in range(num_cells):
1872 | k = beam_size
1873 | # Tensor to store top k previous words at each step; now they're just <start>
1874 | k_prev_words = torch.LongTensor([[word_map['<start>']]] * k).to(device) # (k, 1)
1875 |
1876 | # Tensor to store top k sequences; now they're just <start>
1877 | seqs = k_prev_words # (k, 1)
1878 |
1879 | # Tensor to store top k sequences' scores; now they're just 0
1880 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1)
1881 |
1882 | if return_attention:
1883 | # Tensor to store top k sequences' alphas; now they're just 1s
1884 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size)
1885 |
1886 | # Lists to store completed sequences, their alphas and scores
1887 | complete_seqs = list()
1888 | if return_attention:
1889 | complete_seqs_alpha = list()
1890 | complete_seqs_scores = list()
1891 |
1892 | # Start decoding
1893 | step = 1
1894 | if isinstance(self.decode_step_pred, nn.LSTMCell):
1895 | h_LM, c_LM, h_cell, c_cell = self.init_hidden_state(encoder_out, k)
1896 | else:
1897 | h_LM, h_cell = self.init_hidden_state(encoder_out, k)
1898 |
1899 | cell_tag_H = tag_H[c].expand(k, -1)
1900 | # s is a number less than or equal to k, because sequences are removed from this process once they hit <end>
1901 | while True:
1902 | embeddings = self.embedding(k_prev_words).squeeze(1) # (s, embed_dim)
1903 |
1904 | if isinstance(self.decode_step_LM, nn.LSTMCell):
1905 | h_LM, c_LM = self.decode_step_LM(embeddings, (h_LM, c_LM))
1906 | else:
1907 | h_LM = self.decode_step_LM(embeddings, h_LM)
1908 |
1909 | if return_attention:
1910 | awe, alpha = self.attention(
1911 | encoder_out,
1912 | cell_tag_H,
1913 | h_LM)
1914 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size)
1915 | else:
1916 | awe, _ = self.attention(
1917 | encoder_out,
1918 | cell_tag_H,
1919 | h_LM)
1920 | gate = self.sigmoid(self.f_beta(h_cell)) # gating scalar, (s, encoder_dim)
1921 | awe = gate * awe
1922 |
1923 | if isinstance(self.decode_step_pred, nn.LSTMCell):
1924 | h_cell, c_cell = self.decode_step_pred(awe, (h_cell, c_cell)) # (batch_size_t, decoder_dim)
1925 | else:
1926 | h_cell = self.decode_step_pred(awe, h_cell) # (batch_size_t, decoder_dim)
1927 |
1928 | h_LM = repackage_hidden(h_LM)
1929 | h_cell = repackage_hidden(h_cell)
1930 | if isinstance(self.decode_step_pred, nn.LSTMCell):
1931 | c_LM = repackage_hidden(c_LM)
1932 | c_cell = repackage_hidden(c_cell)
1933 |
1934 | scores = self.fc(h_cell) # (s, vocab_size)
1935 | scores = F.log_softmax(scores, dim=1)
1936 |
1937 | # Add
1938 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)
1939 |
1940 | # For the first step, all k points will have the same scores (since same k previous words, h, c)
1941 | if step == 1:
1942 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
1943 | else:
1944 | # Unroll and find top scores, and their unrolled indices
1945 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)
1946 |
1947 | # Convert unrolled indices to actual indices of scores
1948 | prev_word_inds = top_k_words // vocab_size # (s)
1949 | next_word_inds = top_k_words % vocab_size # (s)
1950 |
1951 | # Add new words to sequences, alphas
1952 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)
1953 | if return_attention:
1954 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
1955 | dim=1) # (s, step+1, enc_image_size, enc_image_size)
1956 |
1957 | # Which sequences are incomplete (didn't reach <end>)?
1958 | incomplete_inds = []
1959 | complete_inds = []
1960 | for ind, next_word in enumerate(next_word_inds):
1961 | if next_word == word_map['<end>']:
1962 | complete_inds.append(ind)
1963 | else:
1964 | incomplete_inds.append(ind)
1965 |
1966 | # Set aside complete sequences
1967 | if len(complete_inds) > 0:
1968 | complete_seqs.extend(seqs[complete_inds].tolist())
1969 | if return_attention:
1970 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
1971 | complete_seqs_scores.extend(top_k_scores[complete_inds])
1972 | k -= len(complete_inds) # reduce beam length accordingly
1973 |
1974 | # Break if all sequences are complete
1975 | if k == 0:
1976 | break
1977 | # Break if things have been going on too long
1978 | if step > max_steps:
1979 | # If no complete sequence is generated, finish the incomplete
1980 | # sequences with <end>
1981 | if not complete_seqs_scores:
1982 | complete_seqs = seqs.tolist()
1983 | for i in range(len(complete_seqs)):
1984 | complete_seqs[i].append(word_map['<end>'])
1985 | if return_attention:
1986 | complete_seqs_alpha = seqs_alpha.tolist()
1987 | complete_seqs_scores = top_k_scores.tolist()
1988 | break
1989 |
1990 | # Proceed with incomplete sequences
1991 | seqs = seqs[incomplete_inds]
1992 | if return_attention:
1993 | seqs_alpha = seqs_alpha[incomplete_inds]
1994 | cell_tag_H = cell_tag_H[prev_word_inds[incomplete_inds]]
1995 | h_LM = h_LM[prev_word_inds[incomplete_inds]]
1996 | h_cell = h_cell[prev_word_inds[incomplete_inds]]
1997 | if isinstance(self.decode_step_pred, nn.LSTMCell):
1998 | c_LM = c_LM[prev_word_inds[incomplete_inds]]
1999 | c_cell = c_cell[prev_word_inds[incomplete_inds]]
2000 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
2001 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
2002 |
2003 | step += 1
2004 | i = complete_seqs_scores.index(max(complete_seqs_scores))
2005 | cell_seqs.append(complete_seqs[i])
2006 | if return_attention:
2007 | cell_alphas.append(complete_seqs_alpha[i])
2008 | if return_attention:
2009 | return cell_seqs, cell_alphas
2010 | else:
2011 | return cell_seqs
2012 |
2013 | def forward(self, encoder_out, encoded_cells_sorted, cell_lengths, tag_H):
2014 | """
2015 | Forward propagation.
2016 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
2017 | :param encoded_cells_sorted: encoded cells, a list of batch_size tensors of dimension (num_cells, max_cell_length)
2018 | :param tag_H: hidden state from TagDecoder, a list of batch_size tensors of dimension (num_cells, TagDecoder's decoder_dim)
2019 | :param cell_lengths: cell lengths, a list of batch_size tensors of dimension (num_cells, 1)
2020 | :return: scores over the vocabulary, decode lengths, and attention weights (one list entry per sample)
2021 | """
2022 | if hasattr(self, 'input_filter'):
2023 | encoder_out = self.input_filter(encoder_out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
2024 |
2025 | # Flatten image
2026 | batch_size = encoder_out.size(0)
2027 | encoder_dim = encoder_out.size(-1)
2028 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim)
2029 | num_pixels = encoder_out.size(1)
2030 |
2031 | # Decode cell content
2032 | predictions_cell = []
2033 | alphas_cell = []
2034 | decode_lengths_cell = []
2035 | for i in range(batch_size):
2036 | num_cells = cell_lengths[i].size(0)
2037 | embeddings = self.embedding(encoded_cells_sorted[i])
2038 | decode_lengths = (cell_lengths[i] - 1).tolist()
2039 | max_decode_lengths = decode_lengths[0]
2040 | predictions = torch.zeros(num_cells, max_decode_lengths, self.vocab_size).to(device)
2041 | alphas = torch.zeros(num_cells, max_decode_lengths, num_pixels).to(device)
2042 | if isinstance(self.decode_step_pred, nn.LSTMCell):
2043 | h_LM, c_LM, h_cell, c_cell = self.init_hidden_state(encoder_out[i].unsqueeze(0), num_cells)
2044 | else:
2045 | h_LM, h_cell = self.init_hidden_state(encoder_out[i].unsqueeze(0), num_cells)
2046 | for t in range(max_decode_lengths):
2047 | batch_size_t = sum([l > t for l in decode_lengths])
2048 | # Language LSTM
2049 | if isinstance(self.decode_step_LM, nn.LSTMCell):
2050 | h_LM, c_LM = self.decode_step_LM(
2051 | embeddings[:batch_size_t, t, :],
2052 | (h_LM[:batch_size_t], c_LM[:batch_size_t])) # (batch_size_t, decoder_dim)
2053 | else:
2054 | h_LM = self.decode_step_LM(
2055 | embeddings[:batch_size_t, t, :],
2056 | h_LM[:batch_size_t]) # (batch_size_t, decoder_dim)
2057 |
2058 | # Attention
2059 | attention_weighted_encoding, alpha = self.attention(
2060 | encoder_out[i].unsqueeze(0), tag_H[i][:batch_size_t],
2061 | h_LM)
2062 | # Decoder LSTM
2063 | gate = self.sigmoid(self.f_beta(h_cell[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim)
2064 | attention_weighted_encoding = gate * attention_weighted_encoding
2065 | if isinstance(self.decode_step_pred, nn.LSTMCell):
2066 | h_cell, c_cell = self.decode_step_pred(
2067 | attention_weighted_encoding,
2068 | (h_cell[:batch_size_t], c_cell[:batch_size_t])) # (batch_size_t, decoder_dim)
2069 | else:
2070 | h_cell = self.decode_step_pred(
2071 | attention_weighted_encoding,
2072 | h_cell[:batch_size_t]) # (batch_size_t, decoder_dim)
2073 | predictions[:batch_size_t, t, :] = self.fc(self.dropout(h_cell)) # (batch_size_t, vocab_size)
2074 | alphas[:batch_size_t, t, :] = alpha
2075 | predictions_cell.append(predictions)
2076 | alphas_cell.append(alphas)
2077 | decode_lengths_cell.append(decode_lengths)
2078 |
2079 | return predictions_cell, decode_lengths_cell, alphas_cell
2080 |
2081 | class DualDecoder(nn.Module):
2082 | """
2083 | Dual decoder model:
2084 | first decoder generates structure of the table
2085 | second decoder generates cell content
2086 | """
2087 | def __init__(self, tag_attention_dim, cell_attention_dim, tag_embed_dim, cell_embed_dim,
2088 | tag_decoder_dim, language_dim, cell_decoder_dim,
2089 | tag_vocab_size, cell_vocab_size, td_encode,
2090 | decoder_cell=nn.LSTMCell, encoder_dim=512, dropout=0.5,
2091 | cell_decoder_type=1,
2092 | cnn_layer_stride=None, tag_H_grad=True, predict_content=True, predict_bbox=False):
2093 | """
2094 | :param tag_attention_dim: size of attention network for tags
2095 | :param cell_attention_dim: size of attention network for cells
2096 | :param tag_embed_dim: embedding size of tags
2097 | :param cell_embed_dim: embedding size of cell content
2098 | :param tag_decoder_dim: size of tag decoder's RNN
2099 | :param language_dim: size of language model's RNN
2100 | :param cell_decoder_dim: size of cell decoder's RNN
2101 | :param tag_vocab_size: size of tag vocabulary
2102 | :param cell_vocab_size: size of cell vocabulary
2103 | :param td_encode: encodings for ('<td>', '>')
2104 | :param encoder_dim: feature size of encoded images
2105 | :param dropout: dropout
2106 | :param cnn_layer_stride: optional stride config dict (keys 'tag' and 'cell') adding a CNN input filter before each decoder
2107 | """
2108 | super(DualDecoder, self).__init__()
2109 |
2110 | self.tag_attention_dim = tag_attention_dim
2111 | self.cell_attention_dim = cell_attention_dim
2112 | self.tag_embed_dim = tag_embed_dim
2113 | self.cell_embed_dim = cell_embed_dim
2114 | self.tag_decoder_dim = tag_decoder_dim
2115 | self.language_dim = language_dim
2116 | self.cell_decoder_dim = cell_decoder_dim
2117 | self.tag_vocab_size = tag_vocab_size
2118 | self.cell_vocab_size = cell_vocab_size
2119 | self.decoder_cell = decoder_cell
2120 | self.encoder_dim = encoder_dim
2121 | self.dropout = dropout
2122 | self.td_encode = td_encode
2123 | self.tag_H_grad = tag_H_grad
2124 | self.predict_content = predict_content
2125 | self.predict_bbox = predict_bbox
2126 | self.relu_tag = nn.ReLU()
2127 | self.relu_cell = nn.ReLU()
2128 |
2129 | self.tag_decoder = TagDecoder(
2130 | tag_attention_dim,
2131 | tag_embed_dim,
2132 | tag_decoder_dim,
2133 | tag_vocab_size,
2134 | td_encode,
2135 | decoder_cell,
2136 | encoder_dim,
2137 | dropout,
2138 | cnn_layer_stride['tag'] if isinstance(cnn_layer_stride, dict) else None,
2139 | self.tag_H_grad)
2140 | if cell_decoder_type == 1:
2141 | self.cell_decoder = CellDecoder_baseline(
2142 | cell_attention_dim,
2143 | cell_embed_dim,
2144 | tag_decoder_dim,
2145 | cell_decoder_dim,
2146 | cell_vocab_size,
2147 | decoder_cell,
2148 | encoder_dim,
2149 | dropout,
2150 | cnn_layer_stride['cell'] if isinstance(cnn_layer_stride, dict) else None)
2151 | elif cell_decoder_type == 2:
2152 | self.cell_decoder = CellDecoder(
2153 | cell_attention_dim,
2154 | cell_embed_dim,
2155 | tag_decoder_dim,
2156 | language_dim,
2157 | cell_decoder_dim,
2158 | cell_vocab_size,
2159 | decoder_cell,
2160 | encoder_dim,
2161 | dropout,
2162 | cnn_layer_stride['cell'] if isinstance(cnn_layer_stride, dict) else None)
2163 | self.bbox_loss = BBoxLoss()
2164 | self.cell_bbox_regressor = CellBBox(tag_decoder_dim)
2165 |
2166 | if torch.cuda.device_count() > 1:
2167 | self.tag_decoder = MyDataParallel(self.tag_decoder)
2168 | self.cell_decoder = MyDataParallel(self.cell_decoder)
2169 | self.cell_bbox_regressor = MyDataParallel(self.cell_bbox_regressor)
2170 |
2171 | def load_pretrained_tag_decoder(self, tag_decoder):
2172 | self.tag_decoder = tag_decoder
2173 |
2174 | def fine_tune_tag_decoder(self, fine_tune=False):
2175 | for p in self.tag_decoder.parameters():
2176 | p.requires_grad = fine_tune
2177 |
2178 | def inference(self, encoder_out, word_map,
2179 | max_steps={'tag': 400, 'cell': 200},
2180 | beam_size={'tag': 5, 'cell': 5},
2181 | return_attention=False):
2182 | """
2183 | Inference on test images with beam search
2184 | """
2185 | res = self.tag_decoder.inference(
2186 | encoder_out,
2187 | word_map['word_map_tag'],
2188 | max_steps['tag'],
2189 | beam_size['tag'],
2190 | return_attention=return_attention
2191 | )
2192 | if res is not None:
2193 | output, tag_H = res[:-1], res[-1]
2194 | if self.predict_content:
2195 | cell_res = self.cell_decoder.inference(
2196 | encoder_out,
2197 | tag_H,
2198 | word_map['word_map_cell'],
2199 | max_steps['cell'],
2200 | beam_size['cell'],
2201 | return_attention=return_attention
2202 | )
2203 | if return_attention:
2204 | cell_seqs, cell_alphas = cell_res
2205 | output += (cell_seqs, cell_alphas)
2206 | else:
2207 | cell_seqs = cell_res
2208 | output += (cell_seqs,)
2209 | if self.predict_bbox:
2210 | cell_bbox = self.cell_bbox_regressor(
2211 | encoder_out,
2212 | tag_H
2213 | )
2214 | output += (cell_bbox,)
2215 | return output
2216 | else:
2217 | return None
2218 |
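# A rough usage sketch for inference (the names `encoder`, `decoder` and `img` are
# illustrative and not defined in this file): with a trained Encoder, a DualDecoder
# instance `decoder`, and the word map loaded from WORDMAP_*.json,
#
#     encoder_out = encoder(img.unsqueeze(0).to(device))  # img: preprocessed (3, H, W) table image
#     res = decoder.inference(encoder_out, word_map,
#                             max_steps={'tag': 400, 'cell': 200},
#                             beam_size={'tag': 3, 'cell': 3})
#
# res carries the beam-searched structural token sequence and, when predict_content is
# True, the per-cell token sequences from the cell decoder (plus attention maps if
# return_attention=True and cell bboxes if predict_bbox is True); it is None if the tag
# decoder fails to produce a sequence.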
2219 | def forward(self, encoder_out, encoded_tags_sorted, tag_lengths, cells=None, cell_lens=None, num_cells=None):
2220 | """
2221 | Forward propagation.
2222 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim)
2223 | :param encoded_tags_sorted: encoded tags, a tensor of dimension (batch_size, max_tag_length)
2224 | :param tag_lengths: tag lengths, a tensor of dimension (batch_size, 1)
2225 | :param cells: encoded cells, a list of batch_size tensors of dimension (num_cells, max_cell_length)
2226 | :param cell_lens: cell lengths, a list of batch_size tensors of dimension (num_cells, 1)
2227 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights
2228 | """
2229 | batch_size = encoder_out.size(0)
2230 | N_GPUS = torch.cuda.device_count()
2231 | if N_GPUS > 1 and N_GPUS != batch_size:
2232 | # When multiple GPUs are available, rearrange the samples
2233 | # in the batch so that the partition is more balanced. This
2234 | # increases training speed and reduces the chance of
2235 | # GPU memory overflow.
2236 | balance_inds = np.arange(np.ceil(batch_size / N_GPUS) * N_GPUS, dtype=int).reshape(-1, N_GPUS).flatten('F')[:batch_size]
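# e.g. with batch_size=6 and N_GPUS=2 (hypothetical numbers), balance_inds = [0, 2, 4, 1, 3, 5],
# so each GPU receives a mix of long and short (length-sorted) samples; restore_inds below
# is the inverse permutation applied to the outputs.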
2237 | encoder_out = encoder_out[balance_inds]
2238 | encoded_tags_sorted = encoded_tags_sorted[balance_inds]
2239 | tag_lengths = tag_lengths[balance_inds]
2240 | if num_cells is not None:
2241 | num_cells = num_cells[balance_inds]
2242 | if self.predict_content:
2243 | cells = [cells[ind] for ind in balance_inds]
2244 | cell_lens = [cell_lens[ind] for ind in balance_inds]
2245 |
2246 | output = self.tag_decoder(
2247 | encoder_out,
2248 | encoded_tags_sorted,
2249 | tag_lengths,
2250 | num_cells=num_cells if self.predict_content or self.predict_bbox else None,
2251 | max_tag_len=(tag_lengths[0] - 1).item()
2252 | )
2253 |
2254 | if self.predict_content or self.predict_bbox:
2255 | tag_H = output[-1]
2256 | if self.predict_bbox:
2257 | predictions_cell_bboxes = self.cell_bbox_regressor(
2258 | encoder_out,
2259 | tag_H
2260 | )
2261 |
2262 | if self.predict_content:
2263 | # Sort cells of each sample by decreasing length
2264 | for j in range(len(cells)):
2265 | cell_lens[j], s_ind = cell_lens[j].sort(dim=0, descending=True)
2266 | cells[j] = cells[j][s_ind]
2267 | tag_H[j] = tag_H[j][s_ind]
2268 |
2269 | predictions_cell, decode_lengths_cell, alphas_cell = self.cell_decoder(
2270 | encoder_out,
2271 | cells,
2272 | cell_lens,
2273 | tag_H
2274 | )
2275 |
2276 | output = output[:3]
2277 | if self.predict_content:
2278 | output += (predictions_cell, decode_lengths_cell, alphas_cell, cells)
2279 | if self.predict_bbox:
2280 | output += (predictions_cell_bboxes,)
2281 |
2282 | if N_GPUS > 1 and N_GPUS != batch_size:
2283 | # Restore the correct order of samples in the batch to compute
2284 | # the correct loss
2285 | restore_inds = np.arange(np.ceil(batch_size / N_GPUS) * N_GPUS, dtype=int).reshape(N_GPUS, -1).flatten('F')[:batch_size]
2286 | output = tuple([item[ind] for ind in restore_inds] if isinstance(item, list) else item[restore_inds] for item in output)
2287 | return output
2288 |
2289 | def train_epoch(self, train_loader, encoder, criterion,
2290 | encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer,
2291 | epoch, args):
2292 | """
2293 | Performs one epoch's training.
2294 |
2295 | :param train_loader: DataLoader for training data
2296 | :param encoder: encoder model
2297 | :param criterion: loss layer
2298 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning)
2299 | :param tag_decoder_optimizer: optimizer to update tag decoder's weights
2300 | :param cell_decoder_optimizer: optimizer to update cell decoder's weights
2301 | :param epoch: epoch number
2302 | """
2303 | self.train() # train mode (dropout and batchnorm is used)
2304 | encoder.train()
2305 |
2306 | batch_time = AverageMeter() # forward prop. + back prop. time
2307 | losses_tag = AverageMeter() # loss (per word decoded)
2308 | losses_total = AverageMeter() # total weighted loss (per batch)
2309 | top1accs_tag = AverageMeter() # top1 accuracy
2310 | if self.predict_content:
2311 | losses_cell = AverageMeter() # loss (per word decoded)
2312 | top1accs_cell = AverageMeter() # top1 accuracy
2313 | if self.predict_bbox:
2314 | losses_cell_bbox = AverageMeter() # cell bbox loss
2315 |
2316 | start = time.time()
2317 | # Batches
2318 | train_loader.shuffle()
2319 | for i, batch in enumerate(train_loader):
2320 | try:
2321 | imgs, tags, tag_lens, num_cells = batch[:4]
2322 | # Move to GPU, if available
2323 | imgs = imgs.to(device)
2324 | tags = tags.to(device)
2325 | tag_lens = tag_lens.to(device)
2326 | num_cells = num_cells.to(device)
2327 | if self.predict_content:
2328 | cells, cell_lens = batch[4:6]
2329 | cells = [c.to(device) for c in cells]
2330 | cell_lens = [c.to(device) for c in cell_lens]
2331 | else:
2332 | cells = None
2333 | cell_lens = None
2334 |
2335 | if self.predict_bbox:
2336 | cell_bboxes = batch[-1]
2337 | cell_bboxes = [c.to(device) for c in cell_bboxes]
2338 |
2339 | # Forward prop.
2340 | imgs = encoder(imgs)
2341 |
2342 | # Flatten image
2343 | batch_size = imgs.size(0)
2344 | # encoder_dim = imgs.size(-1)
2345 | # imgs = imgs.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim)
2346 | tag_lens = tag_lens.squeeze(1)
2347 |
2348 | # Sort input data by decreasing tag lengths
2349 | tag_lens, sort_ind = tag_lens.sort(dim=0, descending=True)
2350 | imgs = imgs[sort_ind]
2351 | tags_sorted = tags[sort_ind]
2352 | num_cells = num_cells[sort_ind]
2353 | if self.predict_content:
2354 | cells = [cells[ind] for ind in sort_ind]
2355 | cell_lens = [cell_lens[ind] for ind in sort_ind]
2356 | if self.predict_bbox:
2357 | cell_bboxes = [cell_bboxes[ind] for ind in sort_ind]
2358 |
2359 | output = self(imgs, tags_sorted, tag_lens, cells, cell_lens, num_cells)
2360 |
2361 | scores_tag, decode_lengths_tag, alphas_tag = output[:3]
2362 | if self.predict_content:
2363 | scores_cell, decode_lengths_cell, alphas_cell, cells = output[3:7]
2364 | if self.predict_bbox:
2365 | predictions_cell_bboxes = output[-1]
2366 |
2367 | # Gather results to the same GPU
2368 | if torch.cuda.device_count() > 1:
2369 | if self.predict_content:
2370 | for s, cell in zip(range(len(scores_cell)), cells):
2371 | if scores_cell[s].get_device() != cell.get_device():
2372 | scores_cell[s] = scores_cell[s].to(device)
2373 | alphas_cell[s] = alphas_cell[s].to(device)
2374 | if self.predict_bbox:
2375 | for s, cell_bbox in zip(range(len(predictions_cell_bboxes)), cell_bboxes):
2376 | if predictions_cell_bboxes[s].get_device() != cell_bbox.get_device():
2377 | predictions_cell_bboxes[s] = predictions_cell_bboxes[s].to(device)
2378 |
2379 | # Calculate tag loss
2380 | targets_tag = tags_sorted[:, 1:]
2381 | scores_tag = pack_padded_sequence(scores_tag, decode_lengths_tag, batch_first=True)[0]
2382 | targets_tag = pack_padded_sequence(targets_tag, decode_lengths_tag, batch_first=True)[0]
2383 | loss_tag = criterion['tag'](scores_tag, targets_tag)
2384 | # Add doubly stochastic attention regularization
2385 | # loss_tag += args.alpha_c * ((1. - alphas_tag.sum(dim=1)) ** 2).mean()
2386 | loss_tag += args.alpha_tag * (self.relu_tag(1. - alphas_tag.sum(dim=1)) ** 2).mean()
2387 | loss = args.tag_loss_weight * loss_tag
2388 | top1_tag = accuracy(scores_tag, targets_tag, 1)
2389 | tag_count = sum(decode_lengths_tag)
2390 | losses_tag.update(loss_tag.item(), tag_count)
2391 | top1accs_tag.update(top1_tag, tag_count)
2392 |
2393 | # Calculate cell loss
2394 | if self.predict_content and args.cell_loss_weight > 0:
2395 | loss_cell = 0.
2396 | reg_alphas_cell = 0
2397 | for scores, gt, decode_lengths, alpha in zip(scores_cell, cells, decode_lengths_cell, alphas_cell):
2398 | targets = gt[:, 1:]
2399 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0]
2400 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0]
2401 | __loss_cell = criterion['cell'](scores, targets)
2402 | # __loss_cell += args.alpha_c * ((1. - alpha.sum(dim=1)) ** 2).mean()
2403 | reg_alphas_cell += args.alpha_cell * (self.relu_cell(1. - alpha.sum(dim=(0, 1))) ** 2).mean()
2404 | top1_cell = accuracy(scores, targets, 1)
2405 | cell_count = sum(decode_lengths)
2406 | losses_cell.update(__loss_cell.item(), cell_count)
2407 | top1accs_cell.update(top1_cell, cell_count)
2408 | loss_cell += __loss_cell
2409 | loss_cell /= batch_size
2410 | loss_cell += reg_alphas_cell / batch_size
2411 | loss += args.cell_loss_weight * loss_cell
2412 | # Calculate cell bbox loss
2413 | if self.predict_bbox and args.cell_bbox_loss_weight > 0:
2414 | loss_cell_bbox = 0.
2415 | for pred, gt in zip(predictions_cell_bboxes, cell_bboxes):
2416 | __loss_cell_bbox = self.bbox_loss(gt, pred)
2417 | losses_cell_bbox.update(__loss_cell_bbox.item(), pred.size(0))
2418 | loss_cell_bbox += __loss_cell_bbox
2419 | loss_cell_bbox /= batch_size
2420 | loss += args.cell_bbox_loss_weight * loss_cell_bbox
2421 |
2422 | losses_total.update(loss.item(), 1)
2423 |
2424 | # Back prop.
2425 | if encoder_optimizer is not None:
2426 | encoder_optimizer.zero_grad()
2427 | if tag_decoder_optimizer is not None:
2428 | tag_decoder_optimizer.zero_grad()
2429 | if self.predict_content:
2430 | cell_decoder_optimizer.zero_grad()
2431 | if self.predict_bbox:
2432 | cell_bbox_regressor_optimizer.zero_grad()
2433 | loss.backward()
2434 |
2435 | # Clip gradients
2436 | if args.grad_clip is not None:
2437 | if encoder_optimizer is not None:
2438 | clip_gradient(encoder_optimizer, args.grad_clip)
2439 | if tag_decoder_optimizer is not None:
2440 | clip_gradient(tag_decoder_optimizer, args.grad_clip)
2441 | if self.predict_content:
2442 | clip_gradient(cell_decoder_optimizer, args.grad_clip)
2443 | if self.predict_bbox:
2444 | clip_gradient(cell_bbox_regressor_optimizer, args.grad_clip)
2445 |
2446 | # Update weights
2447 | if encoder_optimizer is not None:
2448 | encoder_optimizer.step()
2449 | if tag_decoder_optimizer is not None:
2450 | tag_decoder_optimizer.step()
2451 | if self.predict_content:
2452 | cell_decoder_optimizer.step()
2453 | if self.predict_bbox:
2454 | cell_bbox_regressor_optimizer.step()
2455 |
2456 | batch_time.update(time.time() - start)
2457 | start = time.time()
2458 |
2459 | # Print status
2460 | if i % args.print_freq == 0:
2461 | verbose = 'Epoch: [{0}][{1}/{2}]\t'.format(epoch, i, len(train_loader)) + \
2462 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(batch_time=batch_time) + \
2463 | 'Loss_total {loss_total.val:.4f} ({loss_total.avg:.4f})\t'.format(loss_total=losses_total) + \
2464 | 'Loss_tag {loss_tag.val:.4f} ({loss_tag.avg:.4f})\t'.format(loss_tag=losses_tag) + \
2465 | 'Acc_tag {top1_tag.val:.3f} ({top1_tag.avg:.3f})\t'.format(top1_tag=top1accs_tag)
2466 | if self.predict_content:
2467 | verbose += 'Loss_cell {loss_cell.val:.4f} ({loss_cell.avg:.4f})\t'.format(loss_cell=losses_cell) + \
2468 | 'Acc_cell {top1_cell.val:.3f} ({top1_cell.avg:.3f})\t'.format(top1_cell=top1accs_cell)
2469 | if self.predict_bbox:
2470 | verbose += 'Loss_cell_bbox {loss_cell_bbox.val:.4f} ({loss_cell_bbox.avg:.4f})\t'.format(loss_cell_bbox=losses_cell_bbox)
2471 |
2472 | print(verbose, file=sys.stderr)
2473 | sys.stderr.flush()
2474 |
2475 | batch_time.reset()
2476 | losses_total.reset()
2477 | losses_tag.reset()
2478 | top1accs_tag.reset()
2479 | if self.predict_content:
2480 | losses_cell.reset()
2481 | top1accs_cell.reset()
2482 | if self.predict_bbox:
2483 | losses_cell_bbox.reset()
2484 | except Exception as e:
2485 | raise
2486 |
--------------------------------------------------------------------------------
/parallel.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from concurrent.futures import ProcessPoolExecutor, as_completed
3 |
4 | def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0):
5 | """
6 | A parallel version of the map function with a progress bar.
7 |
8 | Args:
9 | array (array-like): An array to iterate over.
10 | function (function): A python function to apply to the elements of array
11 | n_jobs (int, default=16): The number of cores to use
12 | use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of
13 | keyword arguments to function
14 | front_num (int, default=0): The number of iterations to run serially before kicking off the parallel job.
15 | Useful for catching bugs
16 | Returns:
17 | [function(array[0]), function(array[1]), ...]
18 | """
19 | # We run the first few iterations serially to catch bugs
20 | if front_num > 0:
21 | front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]]
22 | else:
23 | front = []
24 | # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging.
25 | if n_jobs == 1:
26 | return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])]
27 | # Assemble the workers
28 | with ProcessPoolExecutor(max_workers=n_jobs) as pool:
29 | # Pass the elements of array into function
30 | if use_kwargs:
31 | futures = [pool.submit(function, **a) for a in array[front_num:]]
32 | else:
33 | futures = [pool.submit(function, a) for a in array[front_num:]]
34 | kwargs = {
35 | 'total': len(futures),
36 | 'unit': 'it',
37 | 'unit_scale': True,
38 | 'leave': True
39 | }
40 | # Print out the progress as tasks complete
41 | for f in tqdm(as_completed(futures), **kwargs):
42 | pass
43 | out = []
44 | # Get the results from the futures.
45 | for i, future in tqdm(enumerate(futures)):
46 | try:
47 | out.append(future.result())
48 | except Exception as e:
49 | out.append(e)
50 | return front + out
51 |
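# A minimal usage sketch (the function `slow_square` below is illustrative, not part of the repo):
#
#     from parallel import parallel_process
#
#     def slow_square(x):
#         return x * x
#
#     results = parallel_process(list(range(1000)), slow_square, n_jobs=8)
#
# With use_kwargs=True each element must instead be a dict of keyword arguments, e.g.
# parallel_process([{'x': 1}, {'x': 2}], slow_square, n_jobs=2, use_kwargs=True). Note that
# exceptions raised by workers are appended to the result list rather than re-raised.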
--------------------------------------------------------------------------------
/prepare_data.py:
--------------------------------------------------------------------------------
1 | '''
2 | Prepares data files for training and validating EDD model. The following files
3 | are generated in <out_dir>:
4 | - TRAIN_IMAGES_<POSTFIX>.h5 # Training images
5 | - TRAIN_TAGS_<POSTFIX>.json # Training structural tokens
6 | - TRAIN_TAGLENS_<POSTFIX>.json # Length of training structural tokens
7 | - TRAIN_CELLS_<POSTFIX>.json # Training cell tokens
8 | - TRAIN_CELLLENS_<POSTFIX>.json # Length of training cell tokens
9 | - TRAIN_CELLBBOXES_<POSTFIX>.json # Training cell bboxes
10 | - VAL.json # Validation ground truth
11 | - WORDMAP_<POSTFIX>.json # Vocab
12 |
13 | <POSTFIX> is formatted according to input args (keep_AR, max_tag_len, ...)
14 | '''
15 | import json
16 | import jsonlines
17 | from tqdm import tqdm
18 | import argparse
19 | from collections import Counter
20 | import os
21 | from PIL import Image
22 | import h5py
23 | import numpy as np
24 | from utils import image_rescale
25 | from html import escape
26 | from lxml import html
27 |
28 | def is_valid(img):
29 | if len(img['html']['structure']['tokens']) > args.max_tag_len:
30 | return False
31 | for cell in img['html']['cells']:
32 | if len(cell['tokens']) > args.max_cell_len:
33 | return False
34 | with Image.open(os.path.join(args.image_dir, img['split'], img['filename'])) as im:
35 | if im.width > args.max_image_size or im.height > args.max_image_size:
36 | return False
37 | return True
38 |
39 | def scale(bbox, orig_size):
40 | ''' Normalizes bbox to 0 - 1
41 | '''
42 | if bbox[0] == 0:
43 | return bbox
44 | else:
45 | x = float((bbox[3] + bbox[1]) / 2) / orig_size[0] # x center
46 | y = float((bbox[4] + bbox[2]) / 2) / orig_size[1] # y center
47 | width = float(bbox[3] - bbox[1]) / orig_size[0]
48 | height = float(bbox[4] - bbox[2]) / orig_size[1]
49 | return [1, x, y, width, height]
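# e.g. (hypothetical values) scale([1, 10, 20, 110, 70], orig_size=(200, 100)) returns
# [1, 0.3, 0.45, 0.5, 0.5]: presence flag, x-center, y-center, width, height, all relative to the image size.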
50 |
51 | def format_html(img):
52 | ''' Formats HTML code from tokenized annotation of img
53 | '''
54 | tag_len = len(img['html']['structure']['tokens'])
55 | cell_len_max = max([len(c['tokens']) for c in img['html']['cells']])
56 | HTML = img['html']['structure']['tokens'].copy()
57 | to_insert = [i for i, tag in enumerate(HTML) if tag in ('<td>', '>')]
58 | for i, cell in zip(to_insert[::-1], img['html']['cells'][::-1]):
59 | if cell:
60 | cell = ''.join([escape(token) if len(token) == 1 else token for token in cell['tokens']])
61 | HTML.insert(i + 1, cell)
62 | HTML = '<html><body><table>%s</table></body></html>' % ''.join(HTML)
63 | root = html.fromstring(HTML)
64 | for td, cell in zip(root.iter('td'), img['html']['cells']):
65 | if 'bbox' in cell:
66 | bbox = cell['bbox']
67 | td.attrib['x'] = str(bbox[0])
68 | td.attrib['y'] = str(bbox[1])
69 | td.attrib['width'] = str(bbox[2] - bbox[0])
70 | td.attrib['height'] = str(bbox[3] - bbox[1])
71 | HTML = html.tostring(root, encoding='utf-8').decode()
72 | return HTML, tag_len, cell_len_max
73 |
74 |
75 | parser = argparse.ArgumentParser(description='Prepares data files for training EDD')
76 | parser.add_argument('--annotation', type=str, help='path to annotation file')
77 | parser.add_argument('--image_dir', type=str, help='path to image folder')
78 | parser.add_argument('--out_dir', type=str, help='path to folder to save data files')
79 | parser.add_argument('--min_word_freq', default=5, type=int, help='minimum frequency for a token to be included in vocab')
80 | parser.add_argument('--max_tag_len', default=300, type=int, help='maximum number of structural tokens for a sample to be included')
81 | parser.add_argument('--max_cell_len', default=100, type=int, help='maximum number of tokens in a cell for a sample to be included')
82 | parser.add_argument('--max_image_size', default=512, type=int, help='maximum image width/height for a sample to be included')
83 | parser.add_argument('--image_size', default=448, type=int, help='target image rescaling size')
84 | parser.add_argument('--keep_AR', default=False, action='store_true', help='keep aspect ratio and pad with zeros when rescaling images')
85 |
86 | args = parser.parse_args()
87 |
88 | # Read image paths and captions for each image
89 | dataset = 'PubTabNet'
90 | train_image_paths = []
91 | train_image_tags = []
92 | train_image_cells = []
93 | train_image_cell_bboxes = []
94 | val_gt = dict()
95 | word_freq_tag = Counter()
96 |
97 | word_freq_cell = Counter()
98 | with jsonlines.open(args.annotation, 'r') as reader:
99 | for img in tqdm(reader):
100 | if img['split'] == 'train':
101 | if is_valid(img):
102 | tags = []
103 | cells = []
104 | cell_bboxes = []
105 | word_freq_tag.update(img['html']['structure']['tokens'])
106 | tags.append(img['html']['structure']['tokens'])
107 | for cell in img['html']['cells']:
108 | word_freq_cell.update(cell['tokens'])
109 | cells.append(cell['tokens'])
110 | if 'bbox' in cell:
111 | cell_bboxes.append([1] + cell['bbox'])
112 | else:
113 | cell_bboxes.append([0, 0, 0, 0, 0])
114 |
115 | path = os.path.join(args.image_dir, img['split'], img['filename'])
116 |
117 | train_image_paths.append(path)
118 | train_image_tags.append(tags)
119 | train_image_cells.append(cells)
120 | train_image_cell_bboxes.append(cell_bboxes)
121 | elif img['split'] == 'val':
122 | HTML, tag_len, cell_len_max = format_html(img)
123 | with Image.open(os.path.join(args.image_dir, img['split'], img['filename'])) as im:
124 | val_gt[img['filename']] = {
125 | 'html': HTML,
126 | 'tag_len': tag_len,
127 | 'cell_len_max': cell_len_max,
128 | 'width': im.width,
129 | 'height': im.height,
130 | 'type': 'complex' if '>' in img['html']['structure']['tokens'] else 'simple'
131 | }
132 |
133 |
134 | if not os.path.exists(args.out_dir):
135 | os.makedirs(args.out_dir)
136 |
137 | # Save ground truth html of validation set
138 | with open(os.path.join(args.out_dir, 'VAL.json'), 'w') as j:
139 | json.dump(val_gt, j)
140 |
141 | # Sanity check
142 | assert len(train_image_paths) == len(train_image_tags)
143 |
144 | # Create a base/root name for all output files
145 | base_filename = dataset + '_' + \
146 | str(args.keep_AR) + '_keep_AR_' + \
147 | str(args.max_tag_len) + '_max_tag_len_' + \
148 | str(args.max_cell_len) + '_max_cell_len_' + \
149 | str(args.max_image_size) + '_max_image_size'
150 |
151 | words_tag = [w for w in word_freq_tag.keys() if word_freq_tag[w] >= args.min_word_freq]
152 | words_cell = [w for w in word_freq_cell.keys() if word_freq_cell[w] >= args.min_word_freq]
153 |
154 | word_map_tag = {k: v + 1 for v, k in enumerate(words_tag)}
155 | word_map_tag['<unk>'] = len(word_map_tag) + 1
156 | word_map_tag['<start>'] = len(word_map_tag) + 1
157 | word_map_tag['<end>'] = len(word_map_tag) + 1
158 | word_map_tag['<pad>'] = 0
159 |
160 | word_map_cell = {k: v + 1 for v, k in enumerate(words_cell)}
161 | word_map_cell['<unk>'] = len(word_map_cell) + 1
162 | word_map_cell['<start>'] = len(word_map_cell) + 1
163 | word_map_cell['<end>'] = len(word_map_cell) + 1
164 | word_map_cell['<pad>'] = 0
165 |
166 | # Save word map to a JSON
167 | with open(os.path.join(args.out_dir, 'WORDMAP_' + base_filename + '.json'), 'w') as j:
168 | json.dump({"word_map_tag": word_map_tag, "word_map_cell": word_map_cell}, j)
169 |
170 | with h5py.File(os.path.join(args.out_dir, 'TRAIN_IMAGES_' + base_filename + '.hdf5'), 'a') as h:
171 | dataset_name = 'images'
172 |
173 | # Check if the dataset already exists and delete it if it does
174 | if dataset_name in h:
175 | del h[dataset_name]
176 |
177 | # Create dataset inside HDF5 file to store images
178 | images = h.create_dataset(dataset_name, (len(train_image_paths), 3, args.image_size, args.image_size), dtype='uint8')
179 |
180 | enc_tags = []
181 | tag_lens = []
182 | enc_cells = []
183 | cell_lens = []
184 | cell_bboxes = []
185 |
186 | for i, path in enumerate(tqdm(train_image_paths)):
187 | # Read images
188 | img, orig_size = image_rescale(train_image_paths[i], args.image_size, args.keep_AR, return_size=True)
189 | assert img.shape == (3, args.image_size, args.image_size)
190 | assert np.max(img) <= 255
191 |
192 | # Save image to HDF5 file
193 | images[i] = img
194 |
195 | for tag in train_image_tags[i]:
196 | # Encode captions
197 | enc_tag = [word_map_tag['<start>']] + [word_map_tag.get(word, word_map_tag['<unk>']) for word in tag] + \
198 | [word_map_tag['<end>']] + [word_map_tag['<pad>']] * (args.max_tag_len - len(tag))
199 | # Find caption lengths
200 | tag_len = len(tag) + 2
201 |
202 | enc_tags.append(enc_tag)
203 | tag_lens.append(tag_len)
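# e.g. a structural-token sequence of length 3 is encoded as
# [<start>, t1, t2, t3, <end>] followed by (max_tag_len - 3) <pad> tokens, with tag_len = 5;
# cell tokens below are padded the same way up to max_cell_len.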
204 |
205 | __enc_cell = []
206 | __cell_len = []
207 | for cell in train_image_cells[i]:
208 | # Encode captions
209 | enc_cell = [word_map_cell['<start>']] + [word_map_cell.get(word, word_map_cell['<unk>']) for word in cell] + \
210 | [word_map_cell['<end>']] + [word_map_cell['<pad>']] * (args.max_cell_len - len(cell))
211 | # Find caption lengths
212 | cell_len = len(cell) + 2
213 |
214 | __enc_cell.append(enc_cell)
215 | __cell_len.append(cell_len)
216 | enc_cells.append(__enc_cell)
217 | cell_lens.append(__cell_len)
218 |
219 | __cell_bbox = []
220 | for bbox in train_image_cell_bboxes[i]:
221 | __cell_bbox.append(scale(bbox, orig_size))
222 | cell_bboxes.append(__cell_bbox)
223 |
224 | # Save encoded captions and their lengths to JSON files
225 | with open(os.path.join(args.out_dir, 'TRAIN_TAGS_' + base_filename + '.json'), 'w') as j:
226 | json.dump(enc_tags, j)
227 |
228 | with open(os.path.join(args.out_dir, 'TRAIN_TAGLENS_' + base_filename + '.json'), 'w') as j:
229 | json.dump(tag_lens, j)
230 |
231 | with open(os.path.join(args.out_dir, 'TRAIN_CELLS_' + base_filename + '.json'), 'w') as j:
232 | json.dump(enc_cells, j)
233 |
234 | with open(os.path.join(args.out_dir, 'TRAIN_CELLLENS_' + base_filename + '.json'), 'w') as j:
235 | json.dump(cell_lens, j)
236 |
237 | with open(os.path.join(args.out_dir, 'TRAIN_CELLBBOXES_' + base_filename + '.json'), 'w') as j:
238 | json.dump(cell_bboxes, j)
239 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | Distance>=0.1.3
2 | apted>=1.0.3
3 | torch>=1.0
4 | torchvision>=0.4.2
5 |
6 | jsonlines
7 | tqdm
8 | Pillow
9 | h5py
10 | lxml
11 |
--------------------------------------------------------------------------------
/train_dual_decoder.py:
--------------------------------------------------------------------------------
1 | '''
2 | Trains encoder-dual-decoder model
3 | '''
4 | import torch.backends.cudnn as cudnn
5 | import torch.optim
6 | import torch.utils.data
7 | import torchvision.transforms as transforms
8 | from torch import nn
9 | from models import Encoder, DualDecoder
10 | from datasets import *
11 | from utils import *
12 | import argparse
13 | import sys
14 | from glob import glob
15 | import time
16 |
17 | def create_model():
18 | encoder = Encoder(args.encoded_image_size,
19 | use_RNN=args.use_RNN,
20 | rnn_size=args.encoder_RNN_size,
21 | last_layer_stride=args.cnn_stride if isinstance(args.cnn_stride, int) else None)
22 | encoder.fine_tune(args.fine_tune_encoder)
23 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
24 | lr=args.encoder_lr) if args.fine_tune_encoder else None
25 |
26 | decoder = DualDecoder(tag_attention_dim=args.tag_attention_dim,
27 | cell_attention_dim=args.cell_attention_dim,
28 | tag_embed_dim=args.tag_embed_dim,
29 | cell_embed_dim=args.cell_embed_dim,
30 | tag_decoder_dim=args.tag_decoder_dim,
31 | language_dim=args.language_dim,
32 | cell_decoder_dim=args.cell_decoder_dim,
33 | tag_vocab_size=len(word_map['word_map_tag']),
34 | cell_vocab_size=len(word_map['word_map_cell']),
35 | td_encode=(word_map['word_map_tag']['<td>'], word_map['word_map_tag']['>']),
36 | decoder_cell=nn.LSTMCell if args.decoder_cell == 'LSTM' else nn.GRUCell,
37 | encoder_dim=512,
38 | dropout=args.dropout,
39 | cell_decoder_type=args.cell_decoder_type,
40 | cnn_layer_stride=args.cnn_stride if isinstance(args.cnn_stride, dict) else None,
41 | tag_H_grad=not args.detach,
42 | predict_content=args.predict_content,
43 | predict_bbox=args.predict_bbox)
44 | decoder.fine_tune_tag_decoder(args.fine_tune_tag_decoder)
45 | tag_decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.tag_decoder.parameters()),
46 | lr=args.tag_decoder_lr) if args.fine_tune_tag_decoder else None
47 |
48 | cell_decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.cell_decoder.parameters()),
49 | lr=args.cell_decoder_lr) if args.predict_content else None
50 | cell_bbox_regressor_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.cell_bbox_regressor.parameters()),
51 | lr=args.cell_bbox_regressor_lr) if args.predict_bbox else None
52 | return encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer
53 |
54 | def load_checkpoint(checkpoint):
55 | # Wait until model file exists
56 | if not os.path.isfile(checkpoint):
57 | while not os.path.isfile(checkpoint):
58 | print('Model not found, retry in 10 minutes', file=sys.stderr)
59 | sys.stderr.flush()
60 | time.sleep(600)
61 | # Make sure model file is saved completely
62 | time.sleep(10)
63 |
64 | checkpoint = torch.load(checkpoint)
65 | start_epoch = checkpoint['epoch'] + 1
66 |
67 | encoder = checkpoint['encoder']
68 | encoder_optimizer = checkpoint['encoder_optimizer']
69 | encoder.fine_tune(args.fine_tune_encoder)
70 | if args.fine_tune_encoder:
71 | if encoder_optimizer is None:
72 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
73 | lr=args.encoder_lr)
74 | elif encoder_optimizer.param_groups[0]['lr'] != args.encoder_lr:
75 | change_learning_rate(encoder_optimizer, args.encoder_lr)
76 | print('Encoder LR changed to %f' % args.encoder_lr, file=sys.stderr)
77 | sys.stderr.flush()
78 |
79 | decoder = checkpoint['decoder']
80 | decoder.tag_H_grad = not args.detach
81 | decoder.tag_decoder.tag_H_grad = not args.detach
82 | decoder.predict_content = args.predict_content
83 | decoder.predict_bbox = args.predict_bbox
84 |
85 | tag_decoder_optimizer = checkpoint['tag_decoder_optimizer']
86 | decoder.fine_tune_tag_decoder(args.fine_tune_tag_decoder)
87 | if args.fine_tune_tag_decoder:
88 | if tag_decoder_optimizer is None:
89 | tag_decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.tag_decoder.parameters()),
90 | lr=args.tag_decoder_lr)
91 | elif tag_decoder_optimizer.param_groups[0]['lr'] != args.tag_decoder_lr:
92 | change_learning_rate(tag_decoder_optimizer, args.tag_decoder_lr)
93 | print('Tag Decoder LR changed to %f' % args.tag_decoder_lr, file=sys.stderr)
94 | sys.stderr.flush()
95 |
96 | cell_decoder_optimizer = checkpoint['cell_decoder_optimizer']
97 | if args.predict_content:
98 | if cell_decoder_optimizer is None:
99 | cell_decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.cell_decoder.parameters()),
100 | lr=args.cell_decoder_lr)
101 | elif cell_decoder_optimizer.param_groups[0]['lr'] != args.cell_decoder_lr:
102 | change_learning_rate(cell_decoder_optimizer, args.cell_decoder_lr)
103 | print('Cell Decoder LR changed to %f' % args.cell_decoder_lr, file=sys.stderr)
104 | sys.stderr.flush()
105 |
106 | cell_bbox_regressor_optimizer = checkpoint['cell_bbox_regressor_optimizer']
107 | if args.predict_bbox:
108 | if cell_bbox_regressor_optimizer is None:
109 | cell_bbox_regressor_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.cell_bbox_regressor.parameters()),
110 | lr=args.cell_bbox_regressor_lr)
111 | elif cell_bbox_regressor_optimizer.param_groups[0]['lr'] != args.cell_bbox_regressor_lr:
112 | change_learning_rate(cell_bbox_regressor_optimizer, args.cell_bbox_regressor_lr)
113 | print('Cell bbox regressor LR changed to %f' % args.cell_bbox_regressor_lr, file=sys.stderr)
114 | sys.stderr.flush()
115 |
116 | return start_epoch, encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer
117 |
118 |
119 | if __name__ == '__main__':
120 | parser = argparse.ArgumentParser(description='Train encoder-dual-decoder table2html model')
121 | parser.add_argument('--cnn_stride', default=2, type=json.loads, help='stride for last CNN layer in encoder')
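# --cnn_stride accepts either a plain integer (e.g. 2), used as the stride of the encoder's last
# CNN layer, or a JSON dict with 'tag' and 'cell' keys (exact value format depends on resnet_block
# in models.py), which instead adds a strided CNN input filter in front of each decoder.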
122 | parser.add_argument('--tag_embed_dim', default=16, type=int, help='embedding dimension')
123 | parser.add_argument('--cell_embed_dim', default=16, type=int, help='embedding dimension')
124 | parser.add_argument('--encoded_image_size', default=14, type=int, help='encoded image size')
125 | parser.add_argument('--tag_attention_dim', default=512, type=int, help='tag attention dimension')
126 | parser.add_argument('--cell_attention_dim', default=512, type=int, help='cell attention dimension')
127 | parser.add_argument('--language_dim', default=512, type=int, help='language model dimension')
128 | parser.add_argument('--tag_decoder_dim', default=512, type=int, help='tag decoder dimension')
129 | parser.add_argument('--cell_decoder_dim', default=512, type=int, help='cell decoder dimension')
130 | parser.add_argument('--dropout', default=0.5, type=float, help='dropout')
131 | parser.add_argument('--epochs', default=10, type=int, help='epochs to train')
132 | parser.add_argument('--batch_size', default=16, type=int, help='batch size')
133 | parser.add_argument('--encoder_lr', default=0.001, type=float, help='encoder learning rate')
134 | parser.add_argument('--tag_decoder_lr', default=0.001, type=float, help='tag decoder learning rate')
135 | parser.add_argument('--cell_decoder_lr', default=0.001, type=float, help='cell decoder learning rate')
136 | parser.add_argument('--cell_bbox_regressor_lr', default=0.001, type=float, help='cell bbox regressor learning rate')
137 | parser.add_argument('--grad_clip', default=5., type=float, help='clip gradients at an absolute value')
138 | parser.add_argument('--alpha_tag', default=0., type=float, help='regularization parameter in tag decoder for doubly stochastic attention')
139 | parser.add_argument('--alpha_cell', default=0., type=float, help='regularization parameter in cell decoder for doubly stochastic attention')
140 | parser.add_argument('--tag_loss_weight', default=0.5, type=float, help='weight of tag loss')
141 | parser.add_argument('--cell_loss_weight', default=0.5, type=float, help='weight of cell content loss')
142 | parser.add_argument('--cell_bbox_loss_weight', default=0.0, type=float, help='weight of cell bbox loss')
143 | parser.add_argument('--print_freq', default=100, type=int, help='verbose frequency')
144 | parser.add_argument('--fine_tune_encoder', dest='fine_tune_encoder', action='store_true', help='fine-tune encoder')
145 | parser.add_argument('--fine_tune_tag_decoder', dest='fine_tune_tag_decoder', action='store_true', help='fine-tune tag decoder')
146 | parser.add_argument('--cell_decoder_type', default=1, type=int, help='Type of cell decoder (1: baseline, 2: with LM)')
147 | parser.add_argument('--decoder_cell', default='LSTM', type=str, help='RNN Cell (LSTM or GRU)')
148 | parser.add_argument('--use_RNN', dest='use_RNN', action='store_true', help='transform image features with LSTM')
149 | parser.add_argument('--detach', dest='detach', default=False, action='store_true', help='detach the hidden state between structure and cell decoders')
150 | parser.add_argument('--encoder_RNN_size', default=512, type=int, help='LSTM size for the encoder')
151 | parser.add_argument('--checkpoint', default=None, type=str, help='path to checkpoint')
152 | parser.add_argument('--data_folder', default='data/pubmed_dual', type=str, help='path to folder with data files saved by create_input_files.py')
153 | parser.add_argument('--data_name', default='pubmed_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size', type=str, help='base name shared by data files')
154 | parser.add_argument('--out_dir', type=str, help='path to save checkpoints')
155 | parser.add_argument('--resume', dest='resume', action='store_true', help='Resume from latest checkpoint if exists')
156 | parser.add_argument('--predict_content', dest='predict_content', default=False, action='store_true', help='Predict cell content')
157 | parser.add_argument('--predict_bbox', dest='predict_bbox', default=False, action='store_true', help='Predict cell bbox')
158 |
159 | args = parser.parse_args()
160 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors
161 |     cudnn.benchmark = True  # set to True only if inputs to the model have fixed size; otherwise there is a lot of computational overhead
162 |
163 | # Read word map
164 | word_map_file = os.path.join(args.data_folder, 'WORDMAP_' + args.data_name + '.json')
165 | with open(word_map_file, 'r') as j:
166 | word_map = json.load(j)
167 |
168 | # Initialize / load checkpoint
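    # Resume policy: with --resume, the newest checkpoint_<epoch>.pth.tar under
    # {out_dir}/{data_name} is loaded; an explicitly given --checkpoint is only
    # used when it is at least as recent as that latest checkpoint. Without
    # --resume, --checkpoint (if given) is loaded directly, otherwise a fresh
    # model is created.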
169 | if args.resume:
170 | existing_ckps = glob(os.path.join(args.out_dir, args.data_name, 'checkpoint_*.pth.tar'))
171 | existing_ckps = [ckp for ckp in existing_ckps if len(os.path.basename(ckp).split('_')) == 2]
172 | if existing_ckps:
173 | existing_ckps = sorted(existing_ckps, key=lambda x: int(os.path.basename(x).split('.')[0].split('_')[1]))
174 | latest_ckp = existing_ckps[-1]
175 | if args.checkpoint is not None:
176 | latest_epoch = int(os.path.basename(latest_ckp).split('.')[0].split('_')[1])
177 | checkpoint_epoch = int(os.path.basename(args.checkpoint).split('.')[0].split('_')[1])
178 | if latest_epoch > checkpoint_epoch:
179 | print('Resume from latest checkpoint: %s' % latest_ckp, file=sys.stderr)
180 | start_epoch, encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = load_checkpoint(latest_ckp)
181 | else:
182 | print('Start from checkpoint: %s' % args.checkpoint, file=sys.stderr)
183 | start_epoch, encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = load_checkpoint(args.checkpoint)
184 | else:
185 | print('Resume from latest checkpoint: %s' % latest_ckp, file=sys.stderr)
186 | start_epoch, encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = load_checkpoint(latest_ckp)
187 | elif args.checkpoint is not None:
188 | start_epoch, encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = load_checkpoint(args.checkpoint)
189 | else:
190 | encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = create_model()
191 | start_epoch = 0
192 | else:
193 | if args.checkpoint is not None:
194 | start_epoch, encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = load_checkpoint(args.checkpoint)
195 | else:
196 | encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = create_model()
197 | start_epoch = 0
198 |
199 | # Move to GPU, if available
200 | if torch.cuda.device_count() > 1:
201 | print('Using %d GPUs' % torch.cuda.device_count(), file=sys.stderr)
202 | if not hasattr(encoder, 'module'):
203 | print('Parallelize encoder', file=sys.stderr)
204 | encoder = MyDataParallel(encoder)
205 | if not hasattr(decoder.tag_decoder, 'module'):
206 | print('Parallelize tag decoder', file=sys.stderr)
207 | decoder.tag_decoder = MyDataParallel(decoder.tag_decoder)
208 | if not hasattr(decoder.cell_decoder, 'module'):
209 | print('Parallelize cell decoder', file=sys.stderr)
210 | decoder.cell_decoder = MyDataParallel(decoder.cell_decoder)
211 | decoder = decoder.to(device)
212 | encoder = encoder.to(device)
213 |
214 | # Loss function
215 | criterion = {'tag': nn.CrossEntropyLoss().to(device),
216 | 'cell': nn.CrossEntropyLoss().to(device)}
217 |
218 | # mean and std of PubMed Central table images
219 | normalize = transforms.Normalize(mean=[0.94247851, 0.94254675, 0.94292611],
220 | std=[0.17910956, 0.17940403, 0.17931663])
221 | mode = 'tag'
222 | if args.predict_content:
223 | mode += '+cell'
224 | if args.predict_bbox:
225 | mode += '+bbox'
226 | train_loader = TagCellDataset(args.data_folder, args.data_name, 'TRAIN',
227 | batch_size=args.batch_size, mode=mode,
228 | transform=transforms.Compose([normalize]))
229 |
230 | # Epochs
231 | for epoch in range(start_epoch, args.epochs):
232 | # One epoch's training
233 | decoder.train_epoch(train_loader=train_loader,
234 | encoder=encoder,
235 | criterion=criterion,
236 | encoder_optimizer=encoder_optimizer,
237 | tag_decoder_optimizer=tag_decoder_optimizer,
238 | cell_decoder_optimizer=cell_decoder_optimizer,
239 |                             cell_bbox_regressor_optimizer=cell_bbox_regressor_optimizer,
240 | epoch=epoch,
241 | args=args)
242 |
243 | # Save checkpoint
244 | save_checkpoint_dual(args.out_dir, args.data_name, epoch, encoder, decoder, encoder_optimizer,
245 | tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer)
246 |
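    # Example invocation (illustrative only; directories are placeholders and
    # the flags mirror the argparse definitions above):
    #   python train_dual_decoder.py \
    #       --data_folder {TRAIN_DATA_DIR} \
    #       --data_name {POSTFIX} \
    #       --out_dir {CHECKPOINT_DIR} \
    #       --epochs 10 --batch_size 8 \
    #       --predict_content --resume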
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch import nn
5 | from torch.nn.parallel._functions import Scatter, Gather
6 | from PIL import Image, ImageOps
7 | from math import ceil
8 |
9 | def scatter(inputs, target_gpus, dim=0):
10 | r"""
11 | Slices tensors into approximately equal chunks and
12 | distributes them across given GPUs. Duplicates
13 | references to objects that are not tensors.
14 | """
15 | def scatter_map(obj):
16 | if isinstance(obj, torch.Tensor):
17 | return Scatter.apply(target_gpus, None, dim, obj)
18 | if isinstance(obj, tuple) and len(obj) > 0:
19 | return list(zip(*map(scatter_map, obj)))
20 | if isinstance(obj, list) and len(obj) > 0:
21 | per_gpu = ceil(len(obj) / len(target_gpus))
22 | partition = [obj[k * per_gpu: min(len(obj), (k + 1) * per_gpu)] for k, _ in enumerate(target_gpus)]
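            # e.g. a list of 10 tensors scattered across 4 GPUs is split into
            # chunks of 3/3/3/1 (per_gpu = ceil(10 / 4) = 3); each element is
            # moved whole to its GPU rather than sliced along `dim`, so the
            # entries of the list may have different shapes.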
23 | for i, target in zip(range(len(partition)), target_gpus):
24 | for j in range(len(partition[i])):
25 | partition[i][j] = partition[i][j].to(torch.device('cuda:%d' % target))
26 | return partition
27 | # return list(map(list, zip(*map(scatter_map, obj))))
28 | if isinstance(obj, dict) and len(obj) > 0:
29 | return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
30 | return [obj for targets in target_gpus]
31 |
32 | # After scatter_map is called, a scatter_map cell will exist. This cell
33 | # has a reference to the actual function scatter_map, which has references
34 | # to a closure that has a reference to the scatter_map cell (because the
35 | # fn is recursive). To avoid this reference cycle, we set the function to
36 | # None, clearing the cell
37 | try:
38 | return scatter_map(inputs)
39 | finally:
40 | scatter_map = None
41 |
42 | def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
43 | r"""Scatter with support for kwargs dictionary"""
44 | inputs = scatter(inputs, target_gpus, dim) if inputs else []
45 | kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
46 | if len(inputs) < len(kwargs):
47 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
48 | elif len(kwargs) < len(inputs):
49 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
50 | inputs = tuple(inputs)
51 | kwargs = tuple(kwargs)
52 | return inputs, kwargs
53 |
54 | def gather(outputs, target_device, dim=0):
55 | r"""
56 | Gathers tensors from different GPUs on a specified device
57 | (-1 means the CPU).
58 | """
59 | def gather_map(outputs):
60 | out = outputs[0]
61 | if isinstance(out, torch.Tensor):
62 | return Gather.apply(target_device, dim, *outputs)
63 | if out is None:
64 | return None
65 | if isinstance(out, dict):
66 | if not all((len(out) == len(d) for d in outputs)):
67 | raise ValueError('All dicts must have the same number of keys')
68 | return type(out)(((k, gather_map([d[k] for d in outputs]))
69 | for k in out))
70 | if isinstance(out, list):
71 | return [item for output in outputs for item in output]
72 | return type(out)(map(gather_map, zip(*outputs)))
73 |
74 | # Recursive function calls like this create reference cycles.
75 | # Setting the function to None clears the refcycle.
76 | try:
77 | return gather_map(outputs)
78 | finally:
79 | gather_map = None
80 |
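# MyDataParallel differs from nn.DataParallel in two ways: attribute lookups
# that fail on the wrapper fall through to the wrapped module (so a wrapped
# decoder still exposes its own methods and attributes), and scatter/gather
# are replaced with the list-aware versions defined above.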
81 | class MyDataParallel(nn.DataParallel):
82 | def __init__(self, model):
83 | super(MyDataParallel, self).__init__(model)
84 |
85 | def __getattr__(self, name):
86 | try:
87 | return super(MyDataParallel, self).__getattr__(name)
88 | except AttributeError:
89 | return getattr(self.module, name)
90 |
91 | def scatter(self, inputs, kwargs, device_ids):
92 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
93 |
94 | def gather(self, outputs, output_device):
95 | return gather(outputs, output_device, dim=self.dim)
96 |
97 | def init_embedding(embeddings):
98 | """
99 | Fills embedding tensor with values from the uniform distribution.
100 |
101 | :param embeddings: embedding tensor
102 | """
103 | bias = np.sqrt(3.0 / embeddings.size(1))
104 | torch.nn.init.uniform_(embeddings, -bias, bias)
105 |
106 |
107 | def load_embeddings(emb_file, word_map):
108 | """
109 | Creates an embedding tensor for the specified word map, for loading into the model.
110 |
111 | :param emb_file: file containing embeddings (stored in GloVe format)
112 | :param word_map: word map
113 | :return: embeddings in the same order as the words in the word map, dimension of embeddings
114 | """
115 |
116 | # Find embedding dimension
117 | with open(emb_file, 'r') as f:
118 | emb_dim = len(f.readline().split(' ')) - 1
119 |
120 | vocab = set(word_map.keys())
121 |
122 | # Create tensor to hold embeddings, initialize
123 | embeddings = torch.FloatTensor(len(vocab), emb_dim)
124 | init_embedding(embeddings)
125 |
126 | # Read embedding file
127 | print("\nLoading embeddings...")
128 | for line in open(emb_file, 'r'):
129 | line = line.split(' ')
130 |
131 | emb_word = line[0]
132 | embedding = list(map(lambda t: float(t), filter(lambda n: n and not n.isspace(), line[1:])))
133 |
134 | # Ignore word if not in train_vocab
135 | if emb_word not in vocab:
136 | continue
137 |
138 | embeddings[word_map[emb_word]] = torch.FloatTensor(embedding)
139 |
140 | return embeddings, emb_dim
141 |
142 |
143 | def clip_gradient(optimizer, grad_clip):
144 | """
145 | Clips gradients computed during backpropagation to avoid explosion of gradients.
146 |
147 | :param optimizer: optimizer with the gradients to be clipped
148 | :param grad_clip: clip value
149 | """
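    # Note: this clamps every gradient element to [-grad_clip, grad_clip]
    # (value clipping), which is different from clipping the global gradient
    # norm as torch.nn.utils.clip_grad_norm_ does.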
150 | for group in optimizer.param_groups:
151 | for param in group['params']:
152 | if param.grad is not None:
153 | param.grad.data.clamp_(-grad_clip, grad_clip)
154 |
155 |
156 | def save_checkpoint(out_dir, data_name, epoch, encoder, decoder, encoder_optimizer, decoder_optimizer):
157 | """
158 | Saves model checkpoint.
159 | :param out_dir: output dir
160 | :param data_name: base name of processed dataset
161 | :param epoch: epoch number
162 | :param encoder: encoder model
163 | :param decoder: decoder model
164 | :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning
165 | :param decoder_optimizer: optimizer to update decoder's weights
166 | """
167 | state = {'epoch': epoch,
168 | 'encoder': encoder,
169 | 'decoder': decoder,
170 | 'encoder_optimizer': encoder_optimizer,
171 | 'decoder_optimizer': decoder_optimizer}
172 | filename = 'checkpoint_' + str(epoch) + '.pth.tar'
173 | try:
174 | if not os.path.exists(os.path.join(out_dir, data_name)):
175 | os.makedirs(os.path.join(out_dir, data_name))
176 | torch.save(state, os.path.join(out_dir, data_name, filename))
177 | except Exception:
178 | torch.save(state, os.path.join(os.environ['RESULT_DIR'], filename))
179 |
180 | def save_checkpoint_dual(out_dir, data_name, epoch,
181 | encoder, decoder, encoder_optimizer,
182 | tag_decoder_optimizer, cell_decoder_optimizer,
183 | cell_bbox_regressor_optimizer):
184 | """
185 | Saves EDD model checkpoint.
186 | :param out_dir: output dir
187 | :param data_name: base name of processed dataset
188 | :param epoch: epoch number
189 | :param encoder: encoder model
190 | :param decoder: decoder model
191 | :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning
192 | :param tag_decoder_optimizer: optimizer to update tag decoder's weights
193 | :param cell_decoder_optimizer: optimizer to update cell decoder's weights
194 | :param cell_bbox_regressor_optimizer: optimizer to update cell bbox regressor's weights
195 | """
196 | state = {'epoch': epoch,
197 | 'encoder': encoder,
198 | 'decoder': decoder,
199 | 'encoder_optimizer': encoder_optimizer,
200 | 'tag_decoder_optimizer': tag_decoder_optimizer,
201 | 'cell_decoder_optimizer': cell_decoder_optimizer,
202 | 'cell_bbox_regressor_optimizer': cell_bbox_regressor_optimizer}
203 | filename = 'checkpoint_' + str(epoch) + '.pth.tar'
204 | if not os.path.exists(os.path.join(out_dir, data_name)):
205 | os.makedirs(os.path.join(out_dir, data_name))
206 | torch.save(state, os.path.join(out_dir, data_name, filename))
207 |
208 | class AverageMeter(object):
209 | """
210 | Keeps track of most recent, average, sum, and count of a metric.
211 | """
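    # Typical use in a training loop (illustrative):
    #   losses = AverageMeter()
    #   losses.update(loss.item(), batch_size)   # after each batch
    #   print(losses.val, losses.avg)            # latest value and running mean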
212 |
213 | def __init__(self):
214 | self.reset()
215 |
216 | def reset(self):
217 | self.val = 0
218 | self.avg = 0
219 | self.sum = 0
220 | self.count = 0
221 |
222 | def update(self, val, n=1):
223 | self.val = val
224 | self.sum += val * n
225 | self.count += n
226 | self.avg = self.sum / self.count
227 |
228 |
229 | def adjust_learning_rate(optimizer, shrink_factor):
230 | """
231 | Shrinks learning rate by a specified factor.
232 |
233 | :param optimizer: optimizer whose learning rate must be shrunk.
234 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with.
235 | """
236 |
237 | print("\nDECAYING learning rate.")
238 | for param_group in optimizer.param_groups:
239 | param_group['lr'] = param_group['lr'] * shrink_factor
240 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],))
241 |
242 | def change_learning_rate(optimizer, new_lr):
243 | """
244 | Change learning rate.
245 |
246 |     :param optimizer: optimizer whose learning rate must be changed.
247 | :param new_lr: new learning rate.
248 | """
249 | for param_group in optimizer.param_groups:
250 | param_group['lr'] = new_lr
251 |
252 | def image_resize(imagepath, image_size, keep_AR=True):
253 | with Image.open(imagepath) as im:
254 |         old_size = im.size  # old_size is in (width, height) format
255 | if keep_AR:
256 | ratio = float(image_size) / max(old_size)
257 | new_size = tuple([int(x * ratio) for x in old_size])
258 | im = im.resize(new_size, Image.Resampling.LANCZOS)
259 | delta_w = image_size - new_size[0]
260 | delta_h = image_size - new_size[1]
261 | padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
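            # Worked example: image_size=448 and a 600x300 image give
            # ratio ~ 0.747, new_size = (448, 224) and padding = (0, 112, 0, 112);
            # ImageOps.expand fills the added border with black (fill=0) by default.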
262 | new_im = ImageOps.expand(im, padding)
263 | else:
264 | new_im = im.resize((image_size, image_size), Image.Resampling.LANCZOS)
265 | return new_im, old_size
266 |
267 | def image_rescale(imagepath, image_size, keep_AR=True, transpose=True, return_size=False):
268 | new_im, old_size = image_resize(imagepath, image_size, keep_AR)
269 | img = np.array(new_im)
270 | if img.shape[2] > 3:
271 | img = img[:, :, :3]
272 | if transpose:
273 | img = img.transpose(2, 0, 1)
274 | if return_size:
275 | return img, old_size
276 | else:
277 | return img
278 |
279 | def accuracy(scores, targets, k):
280 | """
281 | Computes top-k accuracy, from predicted and true labels.
282 |
283 | :param scores: scores from the model
284 | :param targets: true labels
285 | :param k: k in top-k accuracy
286 | :return: top-k accuracy
287 | """
288 |
289 | batch_size = targets.size(0)
290 | _, ind = scores.topk(k, 1, True, True)
291 | correct = ind.eq(targets.view(-1, 1).expand_as(ind))
292 | correct_total = correct.view(-1).float().sum() # 0D tensor
293 | return correct_total.item() * (100.0 / batch_size)
294 |
295 |
296 | if __name__ == '__main__':
297 | pass
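    # Illustrative smoke test (not part of the original pipeline): exercises
    # AverageMeter and accuracy() with small dummy tensors.
    meter = AverageMeter()
    meter.update(2.0, 4)
    meter.update(4.0, 4)
    assert abs(meter.avg - 3.0) < 1e-6

    scores = torch.tensor([[0.1, 0.7, 0.2],
                           [0.6, 0.1, 0.3]])
    targets = torch.tensor([1, 2])
    print('top-1 accuracy: %.1f%%' % accuracy(scores, targets, 1))  # 50.0
    print('top-2 accuracy: %.1f%%' % accuracy(scores, targets, 2))  # 100.0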
298 |
--------------------------------------------------------------------------------