├── .gitignore
├── DejaVuSansMono-Bold.ttf
├── LICENSE.txt
├── README.md
├── convert.py
├── datasets.py
├── eval.py
├── img
│   ├── DualDecoderArch.png
│   └── tokenization.png
├── metric.py
├── models.py
├── parallel.py
├── prepare_data.py
├── requirements.txt
├── train_dual_decoder.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | 
--------------------------------------------------------------------------------
/DejaVuSansMono-Bold.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ibm-aur-nlp/EDD/e6bb2bd509d7c89cdf1fb9608f9db9a044413bed/DejaVuSansMono-Bold.ttf
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Copyright 2021 IBM
2 | 
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 | 
7 |     http://www.apache.org/licenses/LICENSE-2.0
8 | 
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image-based table recognition: data, model, evaluation
2 | 
3 | ## Task
4 | 
5 | Converting table images into HTML code
6 | 
7 | ## Dataset
8 | 
9 | [PubTabNet](https://github.com/ibm-aur-nlp/PubTabNet) contains over 500k table
10 | images annotated with the corresponding HTML representation.
11 | 
12 | ## Model
13 | 
14 | Encoder-Dual-Decoder (EDD)
15 | 
16 | ![Encoder-Dual-Decoder (EDD)](img/DualDecoderArch.png "Encoder-Dual-Decoder (EDD)")
17 | 
18 | ## Evaluation
19 | 
20 | **T**ree-**E**dit-**D**istance-based **S**imilarity (TEDS)
21 | 
22 | `TEDS(T_1, T_2) = 1 - EditDistance(T_1, T_2) / max(|T_1|, |T_2|)`, where `EditDistance(T_1, T_2)` is the tree edit distance between `T_1` and `T_2`, and `|T|` is the number of nodes in `T`.
23 | 
24 | ## Installation
25 | 
26 | Please use a Python 3 (>= 3.6) environment.
27 | 
28 | `pip install -r requirements.txt`
29 | 
30 | ## Training and testing on PubTabNet
31 | 
32 | ### Prepare data
33 | 
34 | Download PubTabNet and extract the files into the following file structure:
35 | ```
36 | {DATA_DIR}
37 | |
38 | -- train
39 |    |
40 |    -- PMCXXXXXXX.png
41 |    -- ...
42 | -- val
43 |    |
44 |    -- PMCXXXXXXX.png
45 |    -- ...
46 | -- test
47 |    |
48 |    -- PMCXXXXXXX.png
49 |    -- ...
50 | -- PubTabNet_2.0.0.jsonl
51 | ```
52 | 
53 | Prepare the data for training:
54 | ```
55 | python prepare_data.py \
56 |     --annotation {DATA_DIR}/PubTabNet_2.0.0.jsonl \
57 |     --image_dir {DATA_DIR} \
58 |     --out_dir {TRAIN_DATA_DIR}
59 | ```
60 | 
61 | The following files will be generated in {TRAIN_DATA_DIR}:
62 | ```
63 | - TRAIN_IMAGES_{POSTFIX}.h5       # Training images
64 | - TRAIN_TAGS_{POSTFIX}.json       # Training structural tokens
65 | - TRAIN_TAGLENS_{POSTFIX}.json    # Length of training structural tokens
66 | - TRAIN_CELLS_{POSTFIX}.json      # Training cell tokens
67 | - TRAIN_CELLLENS_{POSTFIX}.json   # Length of training cell tokens
68 | - TRAIN_CELLBBOXES_{POSTFIX}.json # Training cell bboxes
69 | - VAL.json                        # Validation ground truth
70 | - WORDMAP_{POSTFIX}.json          # Vocab
71 | ```
72 | where `{POSTFIX}` is `PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size`.
73 | ### Train tag decoder
74 | 
75 | Use a larger learning rate (0.001) for the first 10 epochs:
76 | ```
77 | python train_dual_decoder.py \
78 |     --out_dir {CHECKPOINT_DIR} \
79 |     --data_folder {TRAIN_DATA_DIR} \
80 |     --data_name PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size \
81 |     --epochs 10 \
82 |     --batch_size 10 \
83 |     --fine_tune_encoder \
84 |     --encoder_lr 0.001 \
85 |     --fine_tune_tag_decoder \
86 |     --tag_decoder_lr 0.001 \
87 |     --tag_loss_weight 1.0 \
88 |     --cell_decoder_lr 0.001 \
89 |     --cell_loss_weight 0.0 \
90 |     --tag_embed_dim 16 \
91 |     --cell_embed_dim 80 \
92 |     --encoded_image_size 28 \
93 |     --decoder_cell LSTM \
94 |     --tag_attention_dim 256 \
95 |     --cell_attention_dim 256 \
96 |     --tag_decoder_dim 256 \
97 |     --cell_decoder_dim 512 \
98 |     --cell_decoder_type 1 \
99 |     --cnn_stride '{"tag":1, "cell":1}' \
100 |     --resume
101 | ```
102 | 
103 | Use a smaller learning rate (0.0001) for another 3 epochs:
104 | ```
105 | python train_dual_decoder.py \
106 |     --out_dir {CHECKPOINT_DIR} \
107 |     --data_folder {TRAIN_DATA_DIR} \
108 |     --data_name PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size \
109 |     --epochs 13 \
110 |     --batch_size 10 \
111 |     --fine_tune_encoder \
112 |     --encoder_lr 0.0001 \
113 |     --fine_tune_tag_decoder \
114 |     --tag_decoder_lr 0.0001 \
115 |     --tag_loss_weight 1.0 \
116 |     --cell_decoder_lr 0.001 \
117 |     --cell_loss_weight 0.0 \
118 |     --tag_embed_dim 16 \
119 |     --cell_embed_dim 80 \
120 |     --encoded_image_size 28 \
121 |     --decoder_cell LSTM \
122 |     --tag_attention_dim 256 \
123 |     --cell_attention_dim 256 \
124 |     --tag_decoder_dim 256 \
125 |     --cell_decoder_dim 512 \
126 |     --cell_decoder_type 1 \
127 |     --cnn_stride '{"tag":1, "cell":1}' \
128 |     --resume
129 | ```
130 | 
131 | ### Train dual decoders
132 | 
133 | **NOTE**:
134 | - Sometimes a random batch is too large and exceeds GPU memory. When this happens, simply re-execute the training command; training will resume from the latest checkpoint.
135 | - Training dual decoders requires 2 V100 GPUs.
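Several flags (`--cnn_stride` here, and `--beam_size`/`--max_steps` at inference time) take a JSON object with separate values for the tag and cell branches; they are parsed with `json.loads` (see `eval.py` and `convert.py`). A minimal sketch of how such a flag behaves (illustrative argument handling only, not part of the training script):
```
import argparse
import json

parser = argparse.ArgumentParser()
# json.loads turns the quoted JSON string from the shell into a Python dict
parser.add_argument('--cnn_stride', type=json.loads, default={"tag": 1, "cell": 1})

args = parser.parse_args(['--cnn_stride', '{"tag":1, "cell":2}'])
print(args.cnn_stride['tag'], args.cnn_stride['cell'])  # -> 1 2
```
Keep the JSON in single quotes on the command line so the shell passes it through as a single argument.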
136 | 
137 | Use a larger learning rate (0.001) for the first 10 epochs:
138 | ```
139 | python train_dual_decoder.py \
140 |     --checkpoint {CHECKPOINT_DIR}/PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size/checkpoint_12.pth.tar \
141 |     --out_dir {CHECKPOINT_DIR}/cell_decoder \
142 |     --data_folder {TRAIN_DATA_DIR} \
143 |     --data_name PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size \
144 |     --epochs 23 \
145 |     --batch_size 8 \
146 |     --fine_tune_encoder \
147 |     --encoder_lr 0.001 \
148 |     --fine_tune_tag_decoder \
149 |     --tag_decoder_lr 0.001 \
150 |     --tag_loss_weight 0.5 \
151 |     --cell_decoder_lr 0.001 \
152 |     --cell_loss_weight 0.5 \
153 |     --tag_embed_dim 16 \
154 |     --cell_embed_dim 80 \
155 |     --encoded_image_size 28 \
156 |     --decoder_cell LSTM \
157 |     --tag_attention_dim 256 \
158 |     --cell_attention_dim 256 \
159 |     --tag_decoder_dim 256 \
160 |     --cell_decoder_dim 512 \
161 |     --cell_decoder_type 1 \
162 |     --cnn_stride '{"tag":1, "cell":1}' \
163 |     --resume \
164 |     --predict_content
165 | ```
166 | 
167 | Use a smaller learning rate (0.0001) for another 2 epochs:
168 | ```
169 | python train_dual_decoder.py \
170 |     --out_dir {CHECKPOINT_DIR}/cell_decoder \
171 |     --data_folder {TRAIN_DATA_DIR} \
172 |     --data_name PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size \
173 |     --epochs 25 \
174 |     --batch_size 8 \
175 |     --fine_tune_encoder \
176 |     --encoder_lr 0.0001 \
177 |     --fine_tune_tag_decoder \
178 |     --tag_decoder_lr 0.0001 \
179 |     --tag_loss_weight 0.5 \
180 |     --cell_decoder_lr 0.0001 \
181 |     --cell_loss_weight 0.5 \
182 |     --tag_embed_dim 16 \
183 |     --cell_embed_dim 80 \
184 |     --encoded_image_size 28 \
185 |     --decoder_cell LSTM \
186 |     --tag_attention_dim 256 \
187 |     --cell_attention_dim 256 \
188 |     --tag_decoder_dim 256 \
189 |     --cell_decoder_dim 512 \
190 |     --cell_decoder_type 1 \
191 |     --cnn_stride '{"tag":1, "cell":1}' \
192 |     --resume \
193 |     --predict_content
194 | ```
195 | 
196 | 
197 | ### Inference
198 | 
199 | Get validation performance:
200 | ```
201 | python eval.py \
202 |     --image_folder {DATA_DIR}/val \
203 |     --result_json {RESULT_DIR}/RESULT_FILE.json \
204 |     --gt {TRAIN_DATA_DIR}/VAL.json \
205 |     --model {CHECKPOINT_DIR}/cell_decoder/PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size/checkpoint_24.pth.tar \
206 |     --word_map {TRAIN_DATA_DIR}/WORDMAP_PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size.json \
207 |     --image_size 448 \
208 |     --dual_decoder \
209 |     --beam_size '{"tag":3, "cell":3}' \
210 |     --max_steps '{"tag":1800, "cell":600}'
211 | ```
212 | This will save the TEDS score of every validation sample in `{RESULT_DIR}/RESULT_FILE.json` in the following format:
213 | ```
214 | {
215 |     'PMCXXXXXXX.png': float,
216 | }
217 | ```
218 | 
219 | Get test performance:
220 | ```
221 | python eval.py \
222 |     --image_folder {DATA_DIR}/test \
223 |     --result_json {RESULT_DIR}/RESULT_FILE.json \
224 |     --model {CHECKPOINT_DIR}/cell_decoder/PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size/checkpoint_24.pth.tar \
225 |     --word_map {TRAIN_DATA_DIR}/WORDMAP_PubTabNet_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size.json \
226 |     --image_size 448 \
227 |     --dual_decoder \
228 |     --beam_size '{"tag":3, "cell":3}' \
229 |     --max_steps '{"tag":1800, "cell":600}'
230 | ```
231 | This will save the inference result (HTML code) of every test sample in `{RESULT_DIR}/RESULT_FILE.json` in the following format:
232 | ```
233 | {
234 |     'PMCXXXXXXX.png': str,
235 | }
236 | ```
237 | The JSON file can be compared against the ground truth using the code [here](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src). The ground truth of the test set has been kept secret.
238 | 
239 | ## Cite us
240 | 
241 | ```
242 | @article{zhong2019image,
243 |   title={Image-based table recognition: data, model, and evaluation},
244 |   author={Zhong, Xu and ShafieiBavani, Elaheh and Yepes, Antonio Jimeno},
245 |   journal={arXiv preprint arXiv:1911.10683},
246 |   year={2019}
247 | }
248 | ```
249 | 
--------------------------------------------------------------------------------
/convert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import json
4 | import torchvision.transforms as transforms
5 | import skimage.transform
6 | import argparse
7 | from PIL import Image, ImageDraw, ImageFont
8 | from utils import image_rescale, image_resize
9 | from metric import format_html
10 | import os
11 | from glob import glob
12 | from tqdm import tqdm
13 | import shutil
14 | import textwrap
15 | 
16 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 | 
18 | def caption_image_beam_search(encoder, decoder, image_path, word_map,
19 |                               image_size=448, max_steps=400, beam_size=3,
20 |                               vis_att=False):
21 |     """
22 |     Reads an image and captions it with beam search.
23 | 
24 |     :param encoder: encoder model
25 |     :param decoder: decoder model
26 |     :param image_path: path to image
27 |     :param word_map: word map
28 |     :param beam_size: number of sequences to consider at each decode-step
29 |     :return: caption, weights for visualization
30 |     """
31 |     # Read image and process
32 |     img = image_rescale(image_path, image_size, False)
33 |     img = img / 255.
34 |     img = torch.FloatTensor(img)
35 |     normalize = transforms.Normalize(mean=[0.94247851, 0.94254675, 0.94292611],
36 |                                      std=[0.17910956, 0.17940403, 0.17931663])
37 |     transform = transforms.Compose([normalize])
38 |     image = transform(img).to(device)  # (3, image_size, image_size)
39 | 
40 |     # Encode
41 |     image = image.unsqueeze(0)  # (1, 3, image_size, image_size)
42 |     encoder_out = encoder(image)  # (1, enc_image_size, enc_image_size, encoder_dim)
43 | 
44 |     return decoder.inference(encoder_out, word_map, max_steps, beam_size, return_attention=vis_att)
45 | 
46 | def visualize_result(image_path, res, rev_word_map, smooth=True, image_size=448):
47 |     """
48 |     Visualizes caption with weights at every word.
49 | 
50 |     Adapted from paper authors' repo: https://github.com/kelvinxu/arctic-captions/blob/master/alpha_visualization.ipynb
51 | 
52 |     :param image_path: path to image that has been captioned
53 |     :param res: result of inference model
54 |     :param rev_word_map: reverse word mapping, i.e. ix2word
55 |     :param smooth: smooth weights?
56 |     """
57 | 
58 |     def vis_attention(c, image, alpha, smooth, image_size, original_size, x_offset=0, cap=None):
59 |         alpha = np.array(alpha)
60 |         if smooth:
61 |             alpha = skimage.transform.pyramid_expand(alpha, upscale=image_size / alpha.shape[0], sigma=4)
62 |         else:
63 |             alpha = alpha.repeat(int(image_size / alpha.shape[0]), axis=0).repeat(int(image_size / alpha.shape[0]), axis=1)
64 |         if cap is None:
65 |             alpha = (alpha - np.min(alpha)) / (np.max(alpha) - np.min(alpha))
66 |         else:
67 |             alpha *= 1 / cap
68 |             alpha[alpha > 1.] = 1.
69 |         alpha *= 255.
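        # alpha now holds attention weights scaled to 0-255; the following lines
        # turn it into a grayscale PIL image and blend it over the table image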
70 | alpha = alpha.astype('uint8') 71 | alpha = Image.fromarray(alpha) 72 | image = image.convert("RGBA") 73 | alpha = alpha.convert("RGBA") 74 | new_img = Image.blend(image, alpha, 0.6) 75 | new_img = new_img.resize(original_size, Image.LANCZOS) 76 | if c: 77 | font = ImageFont.truetype("DejaVuSansMono-Bold.ttf", 24) 78 | # font = ImageFont.truetype(os.environ["DATA_DIR"] + "/Table2HTML/dejavu/DejaVuSansMono-Bold.ttf", 24) 79 | lines = textwrap.wrap(c, width=25) 80 | w, h = font.getsize(lines[0]) 81 | H = h * len(lines) 82 | y_text = original_size[1] / 2 - H / 2 83 | draw = ImageDraw.Draw(new_img) 84 | for line in lines: 85 | w, h = font.getsize(line) 86 | draw.text(((original_size[0] - w) / 2 + x_offset, y_text), line, (255, 255, 255), font=font) 87 | y_text += h 88 | return new_img 89 | 90 | if len(res) == 2: 91 | tags, cells = res 92 | elif len(res) == 4: 93 | tags, tag_alphas, cells, cell_alphas = res 94 | with open(image_path.replace('.png', '.html'), 'w') as fp: 95 | fp.write(format_html(tags, rev_word_map['tag'], cells, rev_word_map['cell'])) 96 | 97 | if len(res) == 4: 98 | image, original_size = image_resize(image_path, image_size, False) 99 | folder = image_path[:-4] 100 | if os.path.exists(folder): 101 | shutil.rmtree(folder) 102 | os.makedirs(folder) 103 | os.makedirs(os.path.join(folder, 'structure')) 104 | os.makedirs(os.path.join(folder, 'cells')) 105 | 106 | for ind, (c, alpha) in enumerate(zip(tags[1:], tag_alphas[1:]), 1): 107 | if ind <= 50 or len(tags[1:]) - ind <= 50: 108 | new_img = vis_attention(rev_word_map['tag'][c], image, alpha, smooth, image_size, original_size, cap=None) 109 | new_img.save(os.path.join(folder, 'structure', '%03d.png' % (ind)), "PNG") 110 | 111 | for j, (cell, alphas) in enumerate(zip(cells, cell_alphas)): 112 | if cell is not None: 113 | # for ind, (c, alpha) in enumerate(zip(cell[1:], alphas[1:]), 1): 114 | # # if ind <= 5 or len(cell[1:]) - ind <= 5: 115 | # new_img = vis_attention(rev_word_map['cell'][c], image, alpha, smooth, image_size, original_size) 116 | # new_img.save(os.path.join(folder, 'cells', '%03d_%03d.png' % (j, ind)), "PNG") 117 | new_img = vis_attention(''.join([rev_word_map['cell'][c] for c in cell[1:-1]]), 118 | image, 119 | np.mean(alphas[1:-1], axis=0) if len(alphas[1:-1]) else np.mean(alphas[1:], axis=0), 120 | smooth, image_size, original_size, 121 | x_offset=50 if j % 3 == 0 and j > 0 else 0, 122 | cap=None) 123 | new_img.save(os.path.join(folder, 'cells', '%03d.png' % (j)), "PNG") 124 | 125 | 126 | if __name__ == '__main__': 127 | parser = argparse.ArgumentParser(description='Inference on given images') 128 | 129 | parser.add_argument('--input', '-i', help='path to image') 130 | parser.add_argument('--model', '-m', help='path to model') 131 | parser.add_argument('--word_map', '-wm', help='path to word map JSON') 132 | parser.add_argument('--image_size', '-is', default=448, type=int, help='target size of image rescaling') 133 | parser.add_argument('--beam_size', '-b', default={"tag": 3, "cell": 3}, type=json.loads, help='beam size for beam search') 134 | parser.add_argument('--max_steps', '-ms', default=400, type=json.loads, help='max output steps of decoder') 135 | parser.add_argument('--dont_smooth', dest='smooth', action='store_false', help='do not smooth alpha overlay') 136 | parser.add_argument('--vis_attention', dest='vis_attention', action='store_true', help='visualize attention') 137 | 138 | args = parser.parse_args() 139 | 140 | # Load model 141 | checkpoint = torch.load(args.model) 142 | decoder = 
checkpoint['decoder'] 143 | decoder = decoder.to(device) 144 | decoder.eval() 145 | encoder = checkpoint['encoder'] 146 | encoder = encoder.to(device) 147 | encoder.eval() 148 | 149 | # Load word map (word2ix) 150 | with open(args.word_map, 'r') as j: 151 | word_map = json.load(j) 152 | rev_word_map = {'tag': {v: k for k, v in word_map['word_map_tag'].items()}, 153 | 'cell': {v: k for k, v in word_map['word_map_cell'].items()}} 154 | 155 | if os.path.isfile(args.input): 156 | # Encode, decode with attention and beam search 157 | res = caption_image_beam_search(encoder, decoder, args.input, word_map, args.image_size, args.max_steps, args.beam_size, args.vis_attention) 158 | if res is None: 159 | print('No complete sequence is generated') 160 | else: 161 | # Visualize caption and attention of best sequence 162 | visualize_result(args.input, res, rev_word_map, args.smooth, args.image_size) 163 | elif os.path.exists(args.input): 164 | images = glob(os.path.join(args.input, '*.png')) + glob(os.path.join(args.input, '*.jpg')) 165 | for image in tqdm(images): 166 | # Encode, decode with attention and beam search 167 | try: 168 | res = caption_image_beam_search(encoder, decoder, image, word_map, args.image_size, args.max_steps, args.beam_size, args.vis_attention) 169 | except Exception as e: 170 | print(e) 171 | res = None 172 | if res is not None: 173 | # Visualize caption and attention of best sequence 174 | visualize_result(image, res, rev_word_map, args.smooth, args.image_size) 175 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import h5py 3 | import json 4 | import os 5 | import numpy as np 6 | import random 7 | 8 | class TableDatasetEvenLength(object): 9 | """ 10 | Data loader for training baseline encoder-decoder model (WYGIWYS, Dent et al. 2017) 11 | """ 12 | 13 | def __init__(self, data_folder, data_name, batch_size, transform=None): 14 | # Open hdf5 file where images are stored 15 | f = os.path.join(data_folder, 'TRAIN_IMAGES_' + data_name + '.hdf5') 16 | self.h = h5py.File(f, 'r') 17 | 18 | self.imgs = self.h['images'] 19 | 20 | # Load encoded tables (completely into memory) 21 | with open(os.path.join(data_folder, 'TRAIN_TABLES_' + data_name + '.json'), 'r') as j: 22 | self.tables = json.load(j) 23 | 24 | # Load table lengths (completely into memory) 25 | with open(os.path.join(data_folder, 'TRAIN_TABLELENS_' + data_name + '.json'), 'r') as j: 26 | self.tablelens = json.load(j) 27 | 28 | # PyTorch transformation pipeline for the image (normalizing, etc.) 
29 | self.transform = transform 30 | self.batch_size = batch_size 31 | self.batch_id = 0 32 | 33 | def shuffle(self): 34 | self.batch_id = 0 35 | self.batches = [[]] 36 | len_dict = dict() 37 | # Split samples into groups by table lengths 38 | for i, l in enumerate(self.tablelens): 39 | if l in len_dict: 40 | len_dict[l].append(i) 41 | else: 42 | len_dict[l] = [i] 43 | # Fill with long samples first, so that the samples do not need to be sorted before training 44 | lens = sorted(list(len_dict.keys()), key=lambda x: -x) 45 | # Shuffle each group 46 | for l in lens: 47 | random.shuffle(len_dict[l]) 48 | # Generate batches 49 | for l in lens: 50 | k = 0 51 | # Fill previous incomplete batch 52 | if len(self.batches[-1]) < self.batch_size: 53 | deficit = min(len(len_dict[l]), self.batch_size - len(self.batches[-1])) 54 | self.batches[-1] += len_dict[l][k:k + deficit] 55 | k = deficit 56 | # Generate complete batches 57 | while len(len_dict[l]) - k >= self.batch_size: 58 | self.batches.append(len_dict[l][k:k + self.batch_size]) 59 | k += self.batch_size 60 | # Create an incomplete batch with left overs 61 | if k < len(len_dict[l]): 62 | self.batches.append(len_dict[l][k:]) 63 | # Shuffle the order of batches 64 | random.shuffle(self.batches) 65 | 66 | def __iter__(self): 67 | return self 68 | 69 | def __next__(self): 70 | if self.batch_id < len(self.batches): 71 | samples = self.batches[self.batch_id] 72 | image_size = self.imgs[samples[0]].shape 73 | imgs = torch.zeros(len(samples), image_size[0], image_size[1], image_size[2], dtype=torch.float) 74 | table_size = len(self.tables[samples[0]]) 75 | tables = torch.zeros(len(samples), table_size, dtype=torch.long) 76 | tablelens = torch.zeros(len(samples), 1, dtype=torch.long) 77 | for i, sample in enumerate(samples): 78 | img = torch.FloatTensor(self.imgs[sample] / 255.) 
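                # Pixel values are scaled to [0, 1] before the optional torchvision
                # transform (e.g. channel-wise normalization) is applied below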
79 | if self.transform is not None: 80 | imgs[i] = self.transform(img) 81 | else: 82 | imgs[i] = img 83 | tables[i] = torch.LongTensor(self.tables[sample]) 84 | tablelens[i] = torch.LongTensor([self.tablelens[sample]]) 85 | self.batch_id += 1 86 | return imgs, tables, tablelens 87 | else: 88 | raise StopIteration() 89 | 90 | def __len__(self): 91 | return len(self.batches) 92 | 93 | class TagCellDataset(object): 94 | """ 95 | Data loader for training encoder-dual-decoder model 96 | """ 97 | 98 | def __init__(self, data_folder, data_name, split, batch_size, mode='all', transform=None): 99 | """ 100 | :param data_folder: folder where data files are stored 101 | :param data_name: base name of processed datasets 102 | :param split: split, one of 'TRAIN', 'VAL', or 'TEST' 103 | :param batch_size: batch size 104 | :param mode: 'tag', 'tag+cell', 'tag+bbox', or 'tag+cell+bbox' 105 | :param transform: image transform pipeline 106 | """ 107 | 108 | assert split in {'TRAIN', 'VAL', 'TEST'} 109 | assert mode in {'tag', 'tag+cell', 'tag+bbox', 'tag+cell+bbox'} 110 | 111 | self.split = split 112 | self.mode = mode 113 | self.batch_size = batch_size 114 | 115 | # Open hdf5 file where images are stored 116 | f = os.path.join(data_folder, self.split + '_IMAGES_' + data_name + '.hdf5') 117 | self.h = h5py.File(f, 'r') 118 | self.imgs = self.h['images'] 119 | 120 | # Load encoded tags (completely into memory) 121 | with open(os.path.join(data_folder, self.split + '_TAGS_' + data_name + '.json'), 'r') as j: 122 | self.tags = json.load(j) 123 | 124 | # Load tag lengths (completely into memory) 125 | with open(os.path.join(data_folder, self.split + '_TAGLENS_' + data_name + '.json'), 'r') as j: 126 | self.taglens = json.load(j) 127 | 128 | # Load cell lengths (completely into memory) 129 | with open(os.path.join(data_folder, self.split + '_CELLLENS_' + data_name + '.json'), 'r') as j: 130 | self.celllens = json.load(j) 131 | 132 | if 'cell' in self.mode: 133 | # Load encoded cell tokens (completely into memory) 134 | with open(os.path.join(data_folder, self.split + '_CELLS_' + data_name + '.json'), 'r') as j: 135 | self.cells = json.load(j) 136 | 137 | if 'bbox' in self.mode: 138 | # Load encoded tags (completely into memory) 139 | with open(os.path.join(data_folder, self.split + '_CELLBBOXES_' + data_name + '.json'), 'r') as j: 140 | self.cellbboxes = json.load(j) 141 | 142 | # PyTorch transformation pipeline for the image (normalizing, etc.) 
143 | self.transform = transform 144 | 145 | # Total number of datapoints 146 | self.dataset_size = len(self.tags) 147 | self.ind = np.array(range(self.dataset_size)) 148 | self.pointer = 0 149 | 150 | def shuffle(self): 151 | self.ind = np.random.permutation(self.dataset_size) 152 | self.pointer = 0 153 | 154 | def __iter__(self): 155 | return self 156 | 157 | def __getitem__(self, i): 158 | img = torch.FloatTensor(self.imgs[i]) 159 | tags = self.tags[i] 160 | taglens = self.taglens[i] 161 | cells = self.cells[i] 162 | celllens = self.celllens[i] 163 | image_size = self.imgsizes[i] 164 | return img, tags, taglens, cells, celllens, image_size 165 | 166 | def __next__(self): 167 | if self.pointer < self.dataset_size: 168 | if self.dataset_size - self.pointer >= self.batch_size: 169 | step = self.batch_size 170 | samples = self.ind[self.pointer:self.pointer + step] 171 | else: 172 | step = self.dataset_size - self.pointer 173 | lack = self.batch_size - step 174 | samples = np.hstack((self.ind[self.pointer:self.pointer + step], np.array(range(lack)))) 175 | image_size = self.imgs[samples[0]].shape 176 | imgs = torch.zeros(len(samples), image_size[0], image_size[1], image_size[2], dtype=torch.float) 177 | max_tag_len = max([self.taglens[sample] for sample in samples]) 178 | tags = torch.zeros(len(samples), max_tag_len, dtype=torch.long) 179 | taglens = torch.zeros(len(samples), 1, dtype=torch.long) 180 | num_cells = torch.zeros(len(samples), 1, dtype=torch.long) 181 | if 'cell' in self.mode: 182 | cells = [] 183 | celllens = [] 184 | if 'bbox' in self.mode: 185 | cellbboxes = [] 186 | 187 | for i, sample in enumerate(samples): 188 | img = torch.FloatTensor(self.imgs[sample] / 255.) 189 | if self.transform is not None: 190 | imgs[i] = self.transform(img) 191 | else: 192 | imgs[i] = img 193 | tags[i] = torch.LongTensor(self.tags[sample][:max_tag_len]) 194 | taglens[i] = torch.LongTensor([self.taglens[sample]]) 195 | num_cells[i] = len(self.celllens[sample]) 196 | if 'cell' in self.mode: 197 | max_cell_len = max(self.celllens[sample]) 198 | cells.append(torch.LongTensor(self.cells[sample])[:, :max_cell_len]) 199 | celllens.append(torch.LongTensor(self.celllens[sample])) 200 | if 'bbox' in self.mode: 201 | cellbboxes.append(torch.FloatTensor(self.cellbboxes[sample])) 202 | 203 | self.pointer += step 204 | output = (imgs, tags, taglens, num_cells) 205 | if 'cell' in self.mode: 206 | output += (cells, celllens) 207 | if 'bbox' in self.mode: 208 | output += (cellbboxes,) 209 | return output 210 | else: 211 | raise StopIteration() 212 | 213 | def __len__(self): 214 | return int(np.ceil(self.dataset_size / self.batch_size)) 215 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import json 4 | import torchvision.transforms as transforms 5 | import argparse 6 | import os 7 | from tqdm import tqdm 8 | import sys 9 | import time 10 | from utils import image_rescale 11 | from metric import format_html, similarity_eval_html 12 | from lxml import html 13 | import numpy as np 14 | from glob import glob 15 | import traceback 16 | 17 | 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | def convert_table_beam_search(encoder, decoder, image_path, word_map, rev_word_map, 21 | image_size=448, max_steps=400, beam_size=3, 22 | dual_decoder=True): 23 | """ 24 | Reads an image and captions it with beam 
search. 25 | 26 | :param encoder: encoder model 27 | :param decoder: decoder model 28 | :param image_path: path to image 29 | :param word_map: word map 30 | :param max_steps: max numerb of decoding steps 31 | :param beam_size: number of sequences to consider at each decode-step 32 | :param dual_decoder: if the model has dual decoders 33 | :return: HTML code of input table image 34 | """ 35 | # Read image and process 36 | img = image_rescale(image_path, image_size, False) 37 | img = img / 255. 38 | img = torch.FloatTensor(img) 39 | normalize = transforms.Normalize(mean=[0.94247851, 0.94254675, 0.94292611], 40 | std=[0.17910956, 0.17940403, 0.17931663]) 41 | transform = transforms.Compose([normalize]) 42 | image = transform(img).to(device) # (3, image_size, image_size) 43 | 44 | # Encode 45 | image = image.unsqueeze(0) # (1, 3, image_size, image_size) 46 | encoder_out = encoder(image) # (1, enc_image_size, enc_image_size, encoder_dim) 47 | 48 | res = decoder.inference(encoder_out, word_map, max_steps, beam_size, return_attention=False) 49 | if res is not None: 50 | if dual_decoder: 51 | if len(res) == 2: 52 | html_string = format_html(res[0], rev_word_map['tag'], res[1], rev_word_map['cell']) 53 | else: 54 | html_string = format_html(res[0], rev_word_map['tag']) 55 | else: 56 | html_string = format_html(res, rev_word_map) 57 | else: 58 | html_string = '' 59 | return html_string 60 | 61 | 62 | if __name__ == '__main__': 63 | parser = argparse.ArgumentParser(description='Evaluation of table2html conversion models') 64 | 65 | parser.add_argument('--image_folder', type=str, help='path to image folder') 66 | parser.add_argument('--result_json', type=str, help='path to save results (json)') 67 | parser.add_argument('--model', help='path to model') 68 | parser.add_argument('--word_map', help='path to word map JSON') 69 | parser.add_argument('--gt', default=None, type=str, help='path to ground truth') 70 | parser.add_argument('--image_size', default=448, type=int, help='target size of image rescaling') 71 | parser.add_argument('--dual_decoder', default=False, dest='dual_decoder', action='store_true', help='the decoder is a dual decoder') 72 | parser.add_argument('--beam_size', default={"tag": 3, "cell": 3}, type=json.loads, help='beam size for beam search') 73 | parser.add_argument('--max_steps', default={"tag": 1800, "cell": 600}, type=json.loads, help='max output steps of decoder') 74 | 75 | args = parser.parse_args() 76 | 77 | # Wait until model file exists 78 | if not os.path.isfile(args.model): 79 | while not os.path.isfile(args.model): 80 | print('Model not found, retry in 10 minutes', file=sys.stderr) 81 | sys.stderr.flush() 82 | time.sleep(600) 83 | # Make sure model file is saved completely 84 | time.sleep(10) 85 | # Load model 86 | checkpoint = torch.load(args.model) 87 | 88 | decoder = checkpoint['decoder'] 89 | decoder = decoder.to(device) 90 | decoder.eval() 91 | encoder = checkpoint['encoder'] 92 | encoder = encoder.to(device) 93 | encoder.eval() 94 | 95 | # Load word map (word2ix) 96 | with open(args.word_map, 'r') as j: 97 | word_map = json.load(j) 98 | 99 | if args.dual_decoder: 100 | rev_word_map = {'tag': {v: k for k, v in word_map['word_map_tag'].items()}, 101 | 'cell': {v: k for k, v in word_map['word_map_cell'].items()}} 102 | else: 103 | rev_word_map = {v: k for k, v in word_map.items()} # ix2word 104 | 105 | # Load ground truth 106 | if args.gt is not None: 107 | with open(args.gt, 'r') as j: 108 | gt = json.load(j) 109 | 110 | normalize = 
transforms.Normalize(mean=[0.94247851, 0.94254675, 0.94292611],
111 |                                      std=[0.17910956, 0.17940403, 0.17931663])
112 |     transform = transforms.Compose([normalize])
113 | 
114 |     if args.gt is None:
115 |         # The ground truth of the test set is not provided. To evaluate test
116 |         # performance, do not specify a ground truth file; all PNG images in
117 |         # image_folder will be converted. The conversion results are saved in a
118 |         # JSON file, which can be uploaded to our evaluation service (coming
119 |         # soon) for evaluation.
120 |         HTML = dict()
121 |         images = glob(os.path.join(args.image_folder, '*.png'))
122 |         for filename in tqdm(images):
123 |             try:
124 |                 html_pred = convert_table_beam_search(
125 |                     encoder, decoder, filename, word_map, rev_word_map,
126 |                     args.image_size, args.max_steps, args.beam_size,
127 |                     args.dual_decoder)
128 |             except Exception as e:
129 |                 traceback.print_exc()
130 |                 html_pred = ''
131 |             HTML[os.path.basename(filename)] = html_pred
132 |         if not os.path.exists(os.path.dirname(args.result_json)):
133 |             os.makedirs(os.path.dirname(args.result_json))
134 |         with open(args.result_json, 'w') as fp:
135 |             json.dump(HTML, fp)
136 |     else:
137 |         # The ground truth of the validation set is provided. Specify the ground
138 |         # truth file, and the TEDS scores on simple, complex, and all table
139 |         # samples will be computed.
140 |         TEDS = dict()
141 |         for filename, attributes in tqdm(gt.items()):
142 |             try:
143 |                 html_pred = convert_table_beam_search(
144 |                     encoder, decoder,
145 |                     os.path.join(args.image_folder, filename),
146 |                     word_map, rev_word_map,
147 |                     args.image_size, args.max_steps, args.beam_size,
148 |                     args.dual_decoder)
149 |                 if html_pred:
150 |                     TEDS[filename] = similarity_eval_html(html.fromstring(html_pred), html.fromstring(attributes['html']))
151 |                 else:
152 |                     TEDS[filename] = 0.
153 |             except Exception as e:
154 |                 traceback.print_exc()
155 |                 TEDS[filename] = 0.
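        # Failed or empty predictions are scored 0 above; the 'type' field in the
        # ground truth ('simple' or 'complex') drives the per-group averages below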
156 | 157 | simple = [TEDS[filename] for filename, attributes in gt.items() if attributes['type'] == 'simple'] 158 | complex = [TEDS[filename] for filename, attributes in gt.items() if attributes['type'] == 'complex'] 159 | total = [TEDS[filename] for filename, attributes in gt.items()] 160 | 161 | print('TEDS of %d simple tables: %.3f' % (len(simple), np.mean(simple))) 162 | print('TEDS of %d complex tables: %.3f' % (len(complex), np.mean(complex))) 163 | print('TEDS of %d all tables: %.3f' % (len(total), np.mean(total))) 164 | 165 | if not os.path.exists(os.path.dirname(args.result_json)): 166 | os.makedirs(os.path.dirname(args.result_json)) 167 | with open(args.result_json, 'w') as fp: 168 | json.dump(TEDS, fp) 169 | -------------------------------------------------------------------------------- /img/DualDecoderArch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ibm-aur-nlp/EDD/e6bb2bd509d7c89cdf1fb9608f9db9a044413bed/img/DualDecoderArch.png -------------------------------------------------------------------------------- /img/tokenization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ibm-aur-nlp/EDD/e6bb2bd509d7c89cdf1fb9608f9db9a044413bed/img/tokenization.png -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import distance 2 | from apted import APTED, Config 3 | from apted.helpers import Tree 4 | from lxml import html 5 | from collections import deque 6 | from parallel import parallel_process 7 | import numpy as np 8 | import subprocess 9 | import re 10 | import os 11 | import sys 12 | from html import escape 13 | 14 | class TableTree(Tree): 15 | def __init__(self, tag, colspan=None, rowspan=None, content=None, *children): 16 | self.tag = tag 17 | self.colspan = colspan 18 | self.rowspan = rowspan 19 | self.content = content 20 | self.children = list(children) 21 | 22 | def bracket(self): 23 | """Show tree using brackets notation""" 24 | if self.tag == 'td': 25 | result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \ 26 | (self.tag, self.colspan, self.rowspan, self.content) 27 | else: 28 | result = '"tag": %s' % self.tag 29 | for child in self.children: 30 | result += child.bracket() 31 | return "{{{}}}".format(result) 32 | 33 | class CustomConfig(Config): 34 | @staticmethod 35 | def maximum(*sequences): 36 | """Get maximum possible value 37 | """ 38 | return max(map(len, sequences)) 39 | 40 | def normalized_distance(self, *sequences): 41 | """Get distance from 0 to 1 42 | """ 43 | return float(distance.levenshtein(*sequences)) / self.maximum(*sequences) 44 | 45 | def rename(self, node1, node2): 46 | """Compares attributes of trees""" 47 | if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan): 48 | return 1. 49 | if node1.tag == 'td': 50 | if node1.content or node2.content: 51 | return self.normalized_distance(node1.content, node2.content) 52 | return 0. 
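`TableTree` encodes each HTML node in APTED's bracket notation, and `CustomConfig` prices a rename by tag, `colspan`/`rowspan`, and (for `<td>` nodes) the normalized edit distance of the cell text. A minimal usage sketch of these two classes (illustrative only; assumes `metric.py` is importable):
```
from apted import APTED
from metric import TableTree, CustomConfig

# Two single-cell tables whose cell text differs in one character out of three
t1 = TableTree('table', None, None, None, TableTree('td', 1, 1, list('abc')))
t2 = TableTree('table', None, None, None, TableTree('td', 1, 1, list('abd')))

dist = APTED(t1, t2, CustomConfig()).compute_edit_distance()
print(dist)  # 1/3: the only cost is renaming the <td> node
```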
53 | 
54 | def tokenize(node):
55 |     ''' Tokenizes table cells
56 |     '''
57 |     global __tokens__
58 |     __tokens__.append('<%s>' % node.tag)
59 |     if node.text is not None:
60 |         __tokens__ += list(node.text)
61 |     for n in node.getchildren():
62 |         tokenize(n)
63 |     if node.tag != 'unk':
64 |         __tokens__.append('</%s>' % node.tag)
65 |     if node.tag != 'td' and node.tail is not None:
66 |         __tokens__ += list(node.tail)
67 | 
68 | def format_html(tags, rev_word_map_tags, cells=None, rev_word_map_cells=None):
69 |     ''' Formats html code from raw model output
70 |     '''
71 |     HTML = [rev_word_map_tags[ind] for ind in tags[1:-1]]
72 |     if cells is not None:
73 |         to_insert = [i for i, tag in enumerate(HTML) if tag in ('<td>', '>')]
74 |         for i, cell in zip(to_insert[::-1], cells[::-1]):
75 |             if cell is not None:
76 |                 cell = [rev_word_map_cells[ind] for ind in cell[1:-1]]
77 |                 cell = ''.join([escape(token) if len(token) == 1 else token for token in cell])
78 |                 HTML.insert(i + 1, cell)
79 | 
80 |     HTML = '''<html>
81 |               <head>
82 |               <meta charset="UTF-8">
83 |               <style>
84 |               table, th, td {
85 |                 border: 1px solid black;
86 |                 font-size: 10px;
87 |               }
88 |               </style>
89 |               </head>
90 |               <body>
91 |               <table>
92 |               %s
93 |               </table>
94 |               </body>
95 |               </html>''' % ''.join(HTML)
96 |     return HTML
97 | 
98 | def tree_convert_html(node, convert_cell=False, parent=None):
99 |     ''' Converts HTML tree to the format required by apted
100 |     '''
101 |     global __tokens__
102 |     if node.tag == 'td':
103 |         if convert_cell:
104 |             __tokens__ = []
105 |             tokenize(node)
106 |             cell = __tokens__[1:-1].copy()
107 |         else:
108 |             cell = []
109 |         new_node = TableTree(node.tag,
110 |                              int(node.attrib.get('colspan', '1')),
111 |                              int(node.attrib.get('rowspan', '1')),
112 |                              cell, *deque())
113 |     else:
114 |         new_node = TableTree(node.tag, None, None, None, *deque())
115 |     if parent is not None:
116 |         parent.children.append(new_node)
117 |     if node.tag != 'td':
118 |         for n in node.getchildren():
119 |             tree_convert_html(n, convert_cell, new_node)
120 |     if parent is None:
121 |         return new_node
122 | 
123 | def similarity_eval_html(pred, true, structure_only=False):
124 |     ''' Computes the TEDS score between the prediction and the ground truth of
125 |         a given sample
126 |     '''
127 |     if pred.xpath('body/table') and true.xpath('body/table'):
128 |         pred = pred.xpath('body/table')[0]
129 |         true = true.xpath('body/table')[0]
130 |         n_nodes_pred = len(pred.xpath(".//*"))
131 |         n_nodes_true = len(true.xpath(".//*"))
132 |         tree_pred = tree_convert_html(pred, convert_cell=not structure_only)
133 |         tree_true = tree_convert_html(true, convert_cell=not structure_only)
134 |         n_nodes = max(n_nodes_pred, n_nodes_true)
135 |         distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
136 |         return 1.0 - (float(distance) / n_nodes)
137 |     else:
138 |         return 0.0
139 | 
140 | def TEDS_wraper(prediction, ground_truth, filename=None):
141 |     if prediction:
142 |         return similarity_eval_html(
143 |             html.fromstring(prediction),
144 |             html.fromstring(ground_truth)
145 |         )
146 |     else:
147 |         return 0.
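`similarity_eval_html` takes lxml roots of full HTML documents and compares their `body/table` subtrees, which is exactly how `eval.py` calls it. A small self-contained check on two toy tables (illustrative only):
```
from lxml import html
from metric import similarity_eval_html

pred = '<html><body><table><tr><td>12</td><td>34</td></tr></table></body></html>'
true = '<html><body><table><tr><td>12</td><td>35</td></tr></table></body></html>'

# Full TEDS penalizes the one-character difference in the second cell
print(similarity_eval_html(html.fromstring(pred), html.fromstring(true)))
# Structure-only TEDS ignores cell content, so identical layouts score 1.0
print(similarity_eval_html(html.fromstring(pred), html.fromstring(true), structure_only=True))
```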
148 | 149 | def TEDS(gt, pred, n_jobs=8): 150 | ''' Computes TEDS scores for an evaluation set 151 | ''' 152 | assert n_jobs > 0 and isinstance(n_jobs, int), 'n_jobs must be positive integer' 153 | inputs = [{'filename': filename, 'prediction': pred.get(filename, ''), 'ground_truth': attributes['html']} for filename, attributes in gt.items()] 154 | scores = parallel_process(inputs, TEDS_wraper, use_kwargs=True, n_jobs=n_jobs, front_num=1) 155 | scores = {i['filename']: score for i, score in zip(inputs, scores)} 156 | return scores 157 | 158 | def html2xml(html_code, out_path): 159 | if not html_code: 160 | return 161 | root = html.fromstring(html_code) 162 | if root.xpath('body/table'): 163 | table = root.xpath('body/table')[0] 164 | cells = [] 165 | multi_row_cells = [] 166 | row_pt = 0 167 | for row in table.iter('tr'): 168 | row_skip = np.inf 169 | col_pt = 0 170 | for cell in row.getchildren(): 171 | # Skip cells expanded from previous rows 172 | multi_row_cells = sorted(multi_row_cells, key=lambda x: x['start-col']) 173 | for c in multi_row_cells: 174 | if 'end-col' in c: 175 | if c['start-row'] <= row_pt <= c['end-row'] and c['start-col'] <= col_pt <= c['end-col']: 176 | col_pt += c['end-col'] - c['start-col'] + 1 177 | else: 178 | if c['start-row'] <= row_pt <= c['end-row'] and c['start-col'] == col_pt: 179 | col_pt += 1 180 | # Generate new cell 181 | new_cell = {'start-row': row_pt, 182 | 'start-col': col_pt, 183 | 'content': html.tostring(cell, method='text', encoding='utf-8').decode('utf-8')} 184 | # Handle multi-row/col cells 185 | if int(cell.attrib.get('colspan', '1')) > 1: 186 | new_cell['end-col'] = col_pt + int(cell.attrib['colspan']) - 1 187 | if int(cell.attrib.get('rowspan', '1')) > 1: 188 | new_cell['end-row'] = row_pt + int(cell.attrib['rowspan']) - 1 189 | multi_row_cells.append(new_cell) 190 | if new_cell['content']: 191 | cells.append(new_cell) 192 | row_skip = min(row_skip, int(cell.attrib.get('rowspan', '1'))) 193 | col_pt += int(cell.attrib.get('colspan', '1')) 194 | row_pt += row_skip if not np.isinf(row_skip) else 1 195 | multi_row_cells = [cell for cell in multi_row_cells if row_pt <= cell['end-row']] 196 | with open(out_path, 'w') as fp: 197 | fp.write('\n') 198 | fp.write('\n') 199 | fp.write(' \n') 200 | fp.write(' \n') 201 | for i, cell in enumerate(cells): 202 | attributes = ' '.join(['%s=\'%d\'' % (key, value) for key, value in cell.items() if key != 'content']) 203 | fp.write(' \n' % (i, attributes)) 204 | fp.write(' %s\n' % escape(cell['content'])) 205 | fp.write(' \n') 206 | fp.write(' \n') 207 | fp.write('
\n') 208 | fp.write('
') 209 | 210 | def relation_metric(pred, gt, thresholds=None): 211 | if thresholds is None: 212 | thresholds = np.linspace(0.6, 0.95, 8) 213 | precisions = [] 214 | recalls = [] 215 | f1scores = [] 216 | for threshold in thresholds: 217 | try: 218 | result = subprocess.check_output(['java', '-jar', 'dataset-tools-fat-lib.jar', '-str', gt, pred, '-threshold%f' % threshold]) 219 | result = result.split(b'\n')[-2].decode('utf-8') 220 | try: 221 | precision = float(re.search(r'Precision[^=]*= ([0-9.]*)', result).group(1)) 222 | except ValueError: 223 | print(ValueError, file=sys.stderr) 224 | precision = 0.0 225 | try: 226 | recall = float(re.search(r'Recall[^=]*= ([0-9.]*)', result).group(1)) 227 | except ValueError: 228 | print(ValueError, file=sys.stderr) 229 | recall = 0.0 230 | f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0. else 0. 231 | precisions.append(precision) 232 | recalls.append(recall) 233 | f1scores.append(f1) 234 | except Exception as e: 235 | print(os.path.basename(pred), file=sys.stderr) 236 | print(e, file=sys.stderr) 237 | precisions.append(0.) 238 | recalls.append(0.) 239 | f1scores.append(0.) 240 | return np.mean(precisions), np.mean(recalls), np.mean(f1scores) 241 | 242 | 243 | if __name__ == '__main__': 244 | from paramiko import SSHClient 245 | 246 | html_pred = '/Users/peterzhong/Downloads/table2html/Tag+Cell/PMC5059900_003_02.html' 247 | with open(html_pred, 'r') as fp: 248 | pred = html.parse(fp).getroot() 249 | filename = os.path.basename(html_pred).split('.')[0] 250 | 251 | ssh = SSHClient() 252 | ssh.load_system_host_keys() 253 | ssh.connect('dccxl003.pok.ibm.com', username='peterz') 254 | sftp_client = ssh.open_sftp() 255 | with sftp_client.open('/dccstor/ddig/peter/Medline_paper_annotator/data/table_norm/htmls/%s.html' % (filename)) as remote_file: 256 | true = html.parse(remote_file).getroot() 257 | true_table = html.Element("table") 258 | for n in true.xpath('body')[0].getchildren(): 259 | true_table.append(n) 260 | true.xpath('body')[0].append(true_table) 261 | print(similarity_eval_html(pred, true)) 262 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Implementation of encoder-dual-decoder model 3 | ''' 4 | import torch 5 | from torch import nn 6 | import torchvision 7 | from torchvision.models.resnet import BasicBlock, conv1x1 8 | import torch.nn.functional as F 9 | from torch.nn.utils.rnn import pack_padded_sequence 10 | from utils import * 11 | import time 12 | import sys 13 | 14 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 15 | 16 | def resnet_block(stride=1): 17 | layers = [] 18 | downsample = nn.Sequential( 19 | conv1x1(256, 512, stride), 20 | nn.BatchNorm2d(512), 21 | ) 22 | layers.append(BasicBlock(256, 512, stride, downsample)) 23 | layers.append(BasicBlock(512, 512, 1)) 24 | return nn.Sequential(*layers) 25 | 26 | def repackage_hidden(h): 27 | """Wraps hidden states in new Tensors, to detach them from their history.""" 28 | if isinstance(h, torch.Tensor): 29 | return h.detach() 30 | else: 31 | return tuple(repackage_hidden(v) for v in h) 32 | 33 | class Encoder(nn.Module): 34 | """ 35 | Encoder. 
36 | """ 37 | 38 | def __init__(self, encoded_image_size=14, use_RNN=False, rnn_size=512, last_layer_stride=2): 39 | super(Encoder, self).__init__() 40 | self.enc_image_size = encoded_image_size 41 | self.use_RNN = use_RNN 42 | self.rnn_size = rnn_size 43 | 44 | resnet = torchvision.models.resnet18(pretrained=False) # ImageNet ResNet-18 45 | 46 | # Remove linear and pool layers (since we're not doing classification) 47 | # Also remove the last CNN layer for higher resolution feature map 48 | modules = list(resnet.children())[:-3] 49 | if last_layer_stride is not None: 50 | modules.append(resnet_block(stride=last_layer_stride)) 51 | 52 | # Change stride of max pooling layer for higher resolution feature map 53 | # modules[3] = nn.MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False) 54 | 55 | self.resnet = nn.Sequential(*modules) 56 | 57 | # Resize image to fixed size to allow input images of variable size 58 | self.adaptive_pool = nn.AdaptiveAvgPool2d((self.enc_image_size, self.enc_image_size)) 59 | 60 | if self.use_RNN: 61 | self.RNN = nn.LSTM(512, self.rnn_size, bias=True, batch_first=True) # LSTM that transforms the image features 62 | self.init_h = nn.Linear(512, self.rnn_size) # linear layer to find initial hidden state of LSTM 63 | self.init_c = nn.Linear(512, self.rnn_size) # linear layer to find initial cell state of LSTM 64 | self.fine_tune() 65 | 66 | def init_hidden_state(self, encoder_out): 67 | """ 68 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 69 | 70 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 71 | :return: hidden state, cell state 72 | """ 73 | mean_encoder_out = encoder_out.mean(dim=1) 74 | h = self.init_h(mean_encoder_out).unsqueeze(0) # (batch_size*encoded_image_size, rnn_size) 75 | c = self.init_c(mean_encoder_out).unsqueeze(0) 76 | return h, c 77 | 78 | def forward(self, images): 79 | """ 80 | Forward propagation. 81 | 82 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size) 83 | :return: encoded images 84 | """ 85 | batch_size = images.size(0) 86 | out = self.resnet(images) # (batch_size, 512, image_size/32, image_size/32) 87 | out = self.adaptive_pool(out) # (batch_size, 512, encoded_image_size, encoded_image_size) 88 | out = out.permute(0, 2, 3, 1) # (batch_size, encoded_image_size, encoded_image_size, 512) 89 | if self.use_RNN: 90 | out = out.contiguous().view(-1, self.enc_image_size, 512) # (batch_size*encoded_image_size, encoded_image_size, 512) 91 | h = self.init_hidden_state(out) 92 | out, h = self.RNN(out, h) # (batch_size*encoded_image_size, encoded_image_size, 512) 93 | out = out.view(batch_size, self.enc_image_size, self.enc_image_size, self.rnn_size).contiguous() 94 | return out 95 | 96 | def fine_tune(self, fine_tune=True): 97 | """ 98 | Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder. 99 | 100 | :param fine_tune: Allow? 101 | """ 102 | for p in self.resnet.parameters(): 103 | p.requires_grad = fine_tune 104 | 105 | class Attention(nn.Module): 106 | """ 107 | Attention Network. 
108 | """ 109 | 110 | def __init__(self, encoder_dim, decoder_dim, attention_dim): 111 | """ 112 | :param encoder_dim: feature size of encoded images 113 | :param decoder_dim: size of decoder's RNN 114 | :param attention_dim: size of the attention network 115 | """ 116 | super(Attention, self).__init__() 117 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image 118 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) # linear layer to transform decoder's output 119 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed 120 | self.relu = nn.ReLU() 121 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights 122 | 123 | def forward(self, encoder_out, decoder_hidden): 124 | """ 125 | Forward propagation. 126 | 127 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 128 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 129 | :return: attention weighted encoding, weights 130 | """ 131 | att1 = self.encoder_att(encoder_out) # (batch_size, num_pixels, attention_dim) 132 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim) 133 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) # (batch_size, num_pixels) 134 | alpha = self.softmax(att) # (batch_size, num_pixels) 135 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (batch_size, encoder_dim) 136 | 137 | return attention_weighted_encoding, alpha 138 | 139 | class CellAttention(nn.Module): 140 | """ 141 | Attention Network. 142 | """ 143 | 144 | def __init__(self, encoder_dim, tag_decoder_dim, language_dim, attention_dim): 145 | """ 146 | :param encoder_dim: feature size of encoded images 147 | :param tag_decoder_dim: size of tag decoder's RNN 148 | :param language_dim: size of language model's RNN 149 | :param attention_dim: size of the attention network 150 | """ 151 | super(CellAttention, self).__init__() 152 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) # linear layer to transform encoded image 153 | self.tag_decoder_att = nn.Linear(tag_decoder_dim, attention_dim) # linear layer to transform tag decoder output 154 | self.language_att = nn.Linear(language_dim, attention_dim) # linear layer to transform language models output 155 | self.full_att = nn.Linear(attention_dim, 1) # linear layer to calculate values to be softmax-ed 156 | self.relu = nn.ReLU() 157 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights 158 | 159 | def forward(self, encoder_out, decoder_hidden, language_out): 160 | """ 161 | Forward propagation. 
162 | 163 | :param encoder_out: encoded images, a tensor of dimension (1, num_pixels, encoder_dim) 164 | :param decoder_hidden: tag decoder output, a tensor of dimension [(num_cells, tag_decoder_dim)] 165 | :param language_out: language model output, a tensor of dimension (num_cells, language_dim) 166 | :return: attention weighted encoding, weights 167 | """ 168 | att1 = self.encoder_att(encoder_out) # (1, num_pixels, attention_dim) 169 | att2 = self.tag_decoder_att(decoder_hidden) # (num_cells, tag_decoder_dim) 170 | att3 = self.language_att(language_out) # (num_cells, attention_dim) 171 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1) + att3.unsqueeze(1))).squeeze(2) # (num_cells, num_pixels) 172 | alpha = self.softmax(att) # (num_cells, num_pixels) 173 | attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1) # (num_cells, encoder_dim) 174 | 175 | return attention_weighted_encoding, alpha 176 | 177 | class DecoderWithAttention(nn.Module): 178 | """ 179 | Decoder. 180 | """ 181 | 182 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, decoder_cell=nn.LSTMCell, encoder_dim=512, dropout=0.5): 183 | """ 184 | :param attention_dim: size of attention network 185 | :param embed_dim: embedding size 186 | :param decoder_dim: size of decoder's RNN 187 | :param vocab_size: size of vocabulary 188 | :param encoder_dim: feature size of encoded images 189 | :param dropout: dropout 190 | """ 191 | super(DecoderWithAttention, self).__init__() 192 | 193 | assert decoder_cell.__name__ in ('GRUCell', 'LSTMCell'), 'decoder_cell must be either nn.LSTMCell or nn.GRUCell' 194 | self.encoder_dim = encoder_dim 195 | self.attention_dim = attention_dim 196 | self.embed_dim = embed_dim 197 | self.decoder_dim = decoder_dim 198 | self.vocab_size = vocab_size 199 | self.dropout = dropout 200 | 201 | self.attention = Attention(encoder_dim, decoder_dim, attention_dim) # attention network 202 | 203 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer 204 | self.dropout = nn.Dropout(p=self.dropout) 205 | self.decode_step = decoder_cell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoding LSTMCell 206 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell 207 | if isinstance(self.decode_step, nn.LSTMCell): 208 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell 209 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate 210 | self.sigmoid = nn.Sigmoid() 211 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary 212 | self.init_weights() # initialize some layers with the uniform distribution 213 | 214 | def init_weights(self): 215 | """ 216 | Initializes some parameters with values from the uniform distribution, for easier convergence. 217 | """ 218 | self.embedding.weight.data.uniform_(-0.1, 0.1) 219 | self.fc.bias.data.fill_(0) 220 | self.fc.weight.data.uniform_(-0.1, 0.1) 221 | 222 | def load_pretrained_embeddings(self, embeddings): 223 | """ 224 | Loads embedding layer with pre-trained embeddings. 225 | 226 | :param embeddings: pre-trained embeddings 227 | """ 228 | self.embedding.weight = nn.Parameter(embeddings) 229 | 230 | def fine_tune_embeddings(self, fine_tune=True): 231 | """ 232 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings). 233 | 234 | :param fine_tune: Allow? 
235 | """ 236 | for p in self.embedding.parameters(): 237 | p.requires_grad = fine_tune 238 | 239 | def init_hidden_state(self, encoder_out): 240 | """ 241 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 242 | 243 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 244 | :return: hidden state, cell state 245 | """ 246 | mean_encoder_out = encoder_out.mean(dim=1) 247 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim) 248 | if isinstance(self.decode_step, nn.LSTMCell): 249 | c = self.init_c(mean_encoder_out) 250 | return h, c 251 | else: 252 | return h 253 | 254 | def inference(self, encoder_out, word_map, max_steps=400, beam_size=5, return_attention=False): 255 | """ 256 | Inference on test images with beam search 257 | """ 258 | enc_image_size = encoder_out.size(1) 259 | encoder_dim = encoder_out.size(3) 260 | 261 | # Flatten encoding 262 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim) 263 | num_pixels = encoder_out.size(1) 264 | 265 | k = beam_size 266 | vocab_size = len(word_map) 267 | 268 | # We'll treat the problem as having a batch size of k 269 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim) 270 | 271 | # Tensor to store top k previous words at each step; now they're just 272 | k_prev_words = torch.LongTensor([[word_map['']]] * k).to(device) # (k, 1) 273 | 274 | # Tensor to store top k sequences; now they're just 275 | seqs = k_prev_words # (k, 1) 276 | 277 | # Tensor to store top k sequences' scores; now they're just 0 278 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1) 279 | 280 | # Tensor to store top k sequences' alphas; now they're just 1s 281 | if return_attention: 282 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size) 283 | 284 | # Lists to store completed sequences, their alphas and scores 285 | complete_seqs = list() 286 | if return_attention: 287 | complete_seqs_alpha = list() 288 | complete_seqs_scores = list() 289 | 290 | # Start decoding 291 | step = 1 292 | 293 | if isinstance(self.decode_step, nn.LSTMCell): 294 | h, c = self.init_hidden_state(encoder_out) 295 | else: 296 | h = self.init_hidden_state(encoder_out) 297 | 298 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 299 | while True: 300 | embeddings = self.embedding(k_prev_words).squeeze(1) # (s, embed_dim) 301 | if return_attention: 302 | awe, alpha = self.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels) 303 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size) 304 | else: 305 | awe, _ = self.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels) 306 | 307 | gate = self.sigmoid(self.f_beta(h)) # gating scalar, (s, encoder_dim) 308 | awe = gate * awe 309 | 310 | if isinstance(self.decode_step, nn.LSTMCell): 311 | h, c = self.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim) 312 | else: 313 | h = self.decode_step(torch.cat([embeddings, awe], dim=1), h) # (s, decoder_dim) 314 | 315 | h = repackage_hidden(h) 316 | if isinstance(self.decode_step, nn.LSTMCell): 317 | c = repackage_hidden(c) 318 | 319 | scores = self.fc(h) # (s, vocab_size) 320 | scores = F.log_softmax(scores, dim=1) 321 | 322 | # Add 323 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size) 324 | 325 | # For the first step, all k points 
will have the same scores (since same k previous words, h, c)
326 | if step == 1:
327 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
328 | else:
329 | # Unroll and find top scores, and their unrolled indices
330 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)
331 |
332 | # Convert unrolled indices to actual indices of scores
333 | prev_word_inds = top_k_words // vocab_size # (s)
334 | next_word_inds = top_k_words % vocab_size # (s)
335 |
336 | # Add new words to sequences, alphas
337 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)
338 | if return_attention:
339 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
340 | dim=1) # (s, step+1, enc_image_size, enc_image_size)
341 |
342 | # Which sequences are incomplete (didn't reach <end>)?
343 | incomplete_inds = []
344 | complete_inds = []
345 | for ind, next_word in enumerate(next_word_inds):
346 | if next_word == word_map['<end>']:
347 | complete_inds.append(ind)
348 | else:
349 | incomplete_inds.append(ind)
350 |
351 | # Set aside complete sequences
352 | if len(complete_inds) > 0:
353 | complete_seqs.extend(seqs[complete_inds].tolist())
354 | if return_attention:
355 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
356 | complete_seqs_scores.extend(top_k_scores[complete_inds])
357 | k -= len(complete_inds) # reduce beam length accordingly
358 |
359 | # Proceed with incomplete sequences
360 | if k == 0:
361 | break
362 |
363 | # Break if things have been going on too long
364 | if step > max_steps:
365 | # If no complete sequence is generated, finish the incomplete
366 | # sequences with <end>
367 | if not complete_seqs_scores:
368 | complete_seqs = seqs.tolist()
369 | for i in range(len(complete_seqs)):
370 | complete_seqs[i].append(word_map['<end>'])
371 | if return_attention:
372 | complete_seqs_alpha = seqs_alpha.tolist()
373 | complete_seqs_scores = top_k_scores.tolist()
374 | break
375 |
376 | seqs = seqs[incomplete_inds]
377 | if return_attention:
378 | seqs_alpha = seqs_alpha[incomplete_inds]
379 | h = h[prev_word_inds[incomplete_inds]]
380 | if isinstance(self.decode_step, nn.LSTMCell):
381 | c = c[prev_word_inds[incomplete_inds]]
382 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
383 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
384 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
385 |
386 | step += 1
387 | i = complete_seqs_scores.index(max(complete_seqs_scores))
388 | seq = complete_seqs[i]
389 | if return_attention:
390 | alphas = complete_seqs_alpha[i]
391 | return seq, alphas
392 | else:
393 | return seq
394 |
395 | def forward(self, encoder_out, encoded_captions, caption_lengths, h, c=None, begin_tokens=None):
396 | """
397 | Forward propagation.
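The hidden state h (and, for an LSTM cell, c) is passed in and returned so that decoding can continue across TBPTT sub-sequences; when begin_tokens is given, it supplies the last token of the previous sub-sequence, which is used in place of <start> as the first input.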
398 | 399 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 400 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length) 401 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) 402 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices 403 | """ 404 | batch_size = encoder_out.size(0) 405 | encoder_dim = encoder_out.size(-1) 406 | vocab_size = self.vocab_size 407 | 408 | # Flatten image 409 | num_pixels = encoder_out.size(1) 410 | 411 | if begin_tokens is None: 412 | # Embedding 413 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim) 414 | # We won't decode at the position, since we've finished generating as soon as we generate 415 | # So, decoding lengths are actual lengths - 1 416 | decode_lengths = (caption_lengths - 1).tolist() 417 | else: # For TBPTT, use the end token of the previous sub-sequence as begin token instead of 418 | embeddings = torch.cat([self.embedding(begin_tokens), self.embedding(encoded_captions)], dim=1) 419 | decode_lengths = caption_lengths.tolist() 420 | 421 | # Create tensors to hold word predicion scores and alphas 422 | predictions = torch.zeros(batch_size, decode_lengths[0], vocab_size).to(device) 423 | alphas = torch.zeros(batch_size, decode_lengths[0], num_pixels).to(device) 424 | 425 | # At each time-step, decode by 426 | # attention-weighing the encoder's output based on the decoder's previous hidden state output 427 | # then generate a new word in the decoder with the previous word and the attention weighted encoding 428 | for t in range(decode_lengths[0]): 429 | batch_size_t = sum([l > t for l in decode_lengths]) 430 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 431 | h[:batch_size_t]) 432 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim) 433 | attention_weighted_encoding = gate * attention_weighted_encoding 434 | if isinstance(self.decode_step, nn.LSTMCell): 435 | h, c = self.decode_step( 436 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1), 437 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim) 438 | else: 439 | h = self.decode_step( 440 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1), 441 | h[:batch_size_t]) # (batch_size_t, decoder_dim) 442 | predictions[:batch_size_t, t, :] = self.fc(self.dropout(h)) # (batch_size_t, vocab_size) 443 | alphas[:batch_size_t, t, :] = alpha 444 | 445 | return predictions, decode_lengths, alphas, h, c 446 | 447 | def train_epoch(self, train_loader, encoder, criterion, encoder_optimizer, decoder_optimizer, epoch, args, step=None): 448 | """ 449 | Performs one epoch's training. 450 | 451 | :param train_loader: DataLoader for training data 452 | :param encoder: encoder model 453 | :param decoder: decoder model 454 | :param criterion: loss layer 455 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning) 456 | :param decoder_optimizer: optimizer to update decoder's weights 457 | :param epoch: epoch number 458 | """ 459 | 460 | self.train() # train mode (dropout and batchnorm is used) 461 | encoder.train() 462 | 463 | batch_time = AverageMeter() # forward prop. + back prop. 
time 464 | data_time = AverageMeter() # data loading time 465 | losses = AverageMeter() # loss (per word decoded) 466 | top1accs = AverageMeter() # top1 accuracy 467 | 468 | start = time.time() 469 | 470 | # Batches 471 | train_loader.shuffle() 472 | for i, (imgs, caps_sorted, caplens) in enumerate(train_loader): 473 | if step is not None: 474 | if i <= step: 475 | continue 476 | data_time.update(time.time() - start) 477 | 478 | # Move to GPU, if available 479 | imgs = imgs.to(device) 480 | caps_sorted = caps_sorted.to(device) 481 | caplens = caplens.to(device) 482 | 483 | # Forward prop. 484 | imgs = encoder(imgs) 485 | # Flatten image 486 | batch_size = imgs.size(0) 487 | encoder_dim = imgs.size(-1) 488 | imgs = imgs.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 489 | caplens = caplens.squeeze(1) 490 | 491 | # Sort input data by decreasing lengths 492 | # caplens, sort_ind = caplens.squeeze(1).sort(dim=0, descending=True) 493 | # imgs = imgs[sort_ind] 494 | # caps_sorted = caps[sort_ind] 495 | 496 | # Initialize LSTM state 497 | if isinstance(self.decode_step, nn.LSTMCell): 498 | h, c = self.init_hidden_state(imgs) # (batch_size, decoder_dim) 499 | else: 500 | h = self.init_hidden_state(imgs) # (batch_size, decoder_dim) 501 | c = None 502 | 503 | max_cap_length = max(caplens.tolist()) 504 | # TBPTT 505 | j = 0 506 | while j < max_cap_length: 507 | if j == 0: 508 | # bptt tokens after 509 | sub_seq_len = min(args.bptt + 1, max_cap_length - j) 510 | else: 511 | sub_seq_len = min(args.bptt, max_cap_length - j) 512 | # Do not leave too short tails (less than 10 tokens) 513 | short_tail = (caplens - (j + sub_seq_len) < 10) & (caplens - (j + sub_seq_len) > 0) 514 | if short_tail.any(): 515 | sub_seq_len += max((caplens - (j + sub_seq_len))[short_tail].tolist()) 516 | 517 | sub_seq_caplens = caplens - j 518 | sub_seq_caplens[sub_seq_caplens > sub_seq_len] = sub_seq_len 519 | batch_size_t = (sub_seq_caplens > 0).sum().item() 520 | sub_seq_caplens = sub_seq_caplens[:batch_size_t] 521 | sub_seq_cap = caps_sorted[:batch_size_t, j:j + sub_seq_len] 522 | 523 | h = repackage_hidden(h) 524 | if isinstance(self.decode_step, nn.LSTMCell): 525 | c = repackage_hidden(c) 526 | 527 | decoder_optimizer.zero_grad() 528 | if encoder_optimizer is not None: 529 | encoder_optimizer.zero_grad() 530 | if j == 0: 531 | scores, decode_lengths, alphas, h, c = self( 532 | imgs[:batch_size_t], 533 | sub_seq_cap, 534 | sub_seq_caplens, 535 | h, 536 | c) 537 | # Since we decoded starting with , the targets are all words after , up to 538 | targets = sub_seq_cap[:, 1:] 539 | else: 540 | scores, decode_lengths, alphas, h, c = self( 541 | imgs[:batch_size_t], 542 | sub_seq_cap, 543 | sub_seq_caplens, 544 | h, 545 | c, 546 | caps_sorted[:batch_size_t, j - 1].unsqueeze(1)) 547 | targets = sub_seq_cap 548 | 549 | # Remove timesteps that we didn't decode at, or are pads 550 | # pack_padded_sequence is an easy trick to do this 551 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0] 552 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0] 553 | 554 | # Calculate loss 555 | loss = criterion(scores, targets) 556 | 557 | # Add doubly stochastic attention regularization 558 | loss += args.alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean() 559 | 560 | # Back prop. 
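# Note: the encoder features (imgs) are computed once per batch and reused by every
# TBPTT sub-sequence, so the graph is retained for all but the last sub-sequence;
# the recurrent state itself is detached by repackage_hidden() at the start of each
# sub-sequence, so gradients do not flow across the truncation boundary.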
561 | if j + sub_seq_len < max_cap_length: 562 | loss.backward(retain_graph=True) 563 | else: 564 | loss.backward() 565 | 566 | # Clip gradients 567 | if args.grad_clip is not None: 568 | clip_gradient(decoder_optimizer, args.grad_clip) 569 | if encoder_optimizer is not None: 570 | clip_gradient(encoder_optimizer, args.grad_clip) 571 | 572 | # Update weights 573 | decoder_optimizer.step() 574 | if encoder_optimizer is not None: 575 | encoder_optimizer.step() 576 | 577 | # Keep track of metrics 578 | top1 = accuracy(scores, targets, 1) 579 | losses.update(loss.item(), sum(decode_lengths)) 580 | top1accs.update(top1, sum(decode_lengths)) 581 | j += sub_seq_len 582 | batch_time.update(time.time() - start) 583 | 584 | start = time.time() 585 | 586 | # Print status 587 | if i % args.print_freq == 0: 588 | print('Epoch: [{0}][{1}/{2}]\t' 589 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 590 | 'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t' 591 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 592 | 'Top-1 Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(epoch, i, len(train_loader), 593 | batch_time=batch_time, 594 | data_time=data_time, loss=losses, 595 | top1=top1accs), file=sys.stderr) 596 | sys.stderr.flush() 597 | 598 | class DecoderWithAttentionAndLanguageModel(nn.Module): 599 | ''' 600 | Stacked 2-layer LSTM with Attention model. First LSTM is a languange model, second LSTM is a decoder. 601 | See "Recursive Recurrent Nets with Attention Modeling for OCR in the Wild" 602 | ''' 603 | def __init__(self, attention_dim, embed_dim, language_dim, decoder_dim, vocab_size, 604 | decoder_cell=nn.LSTMCell, encoder_dim=512, dropout=0.5): 605 | """ 606 | :param attention_dim: size of attention network 607 | :param embed_dim: embedding size 608 | :param language_dim: size of language model's RNN 609 | :param decoder_dim: size of decoder's RNN 610 | :param vocab_size: size of vocabulary 611 | :param encoder_dim: feature size of encoded images 612 | :param dropout: dropout 613 | """ 614 | super(DecoderWithAttentionAndLanguageModel, self).__init__() 615 | 616 | self.encoder_dim = encoder_dim 617 | self.attention_dim = attention_dim 618 | self.embed_dim = embed_dim 619 | self.language_dim = language_dim 620 | self.decoder_dim = decoder_dim 621 | self.vocab_size = vocab_size 622 | self.dropout = dropout 623 | 624 | self.attention = Attention(encoder_dim, language_dim, attention_dim) # attention network 625 | 626 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer 627 | 628 | self.decode_step_LM = decoder_cell(embed_dim, language_dim, bias=True) # language model LSTMCell 629 | 630 | self.decode_step_pred = decoder_cell(encoder_dim, decoder_dim, bias=True) # decoding LSTMCell 631 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell 632 | if isinstance(self.decode_step_pred, nn.LSTMCell): 633 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell 634 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate 635 | self.sigmoid = nn.Sigmoid() 636 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary 637 | self.dropout = nn.Dropout(p=self.dropout) 638 | self.init_weights() # initialize some layers with the uniform distribution 639 | 640 | def init_weights(self): 641 | """ 642 | Initializes some parameters with values from the uniform distribution, for easier convergence. 
643 | """ 644 | self.embedding.weight.data.uniform_(-0.1, 0.1) 645 | self.fc.bias.data.fill_(0) 646 | self.fc.weight.data.uniform_(-0.1, 0.1) 647 | 648 | def init_hidden_state(self, encoder_out): 649 | """ 650 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 651 | 652 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 653 | :return: hidden state, cell state 654 | """ 655 | batch_size = encoder_out.size(0) 656 | mean_encoder_out = encoder_out.mean(dim=1) 657 | h_LM = torch.zeros(batch_size, self.language_dim).to(device) 658 | h_pred = self.init_h(mean_encoder_out) # (batch_size, decoder_dim) 659 | if isinstance(self.decode_step_pred, nn.LSTMCell): 660 | c_LM = torch.zeros(batch_size, self.language_dim).to(device) 661 | c_pred = self.init_c(mean_encoder_out) 662 | return h_LM, c_LM, h_pred, c_pred 663 | else: 664 | return h_LM, h_pred 665 | 666 | def inference(self, encoder_out, word_map, max_steps=400, beam_size=5): 667 | """ 668 | Inference on test images with beam search 669 | """ 670 | enc_image_size = encoder_out.size(1) 671 | encoder_dim = encoder_out.size(3) 672 | 673 | # Flatten encoding 674 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim) 675 | num_pixels = encoder_out.size(1) 676 | 677 | k = beam_size 678 | vocab_size = len(word_map) 679 | 680 | # We'll treat the problem as having a batch size of k 681 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, num_pixels, encoder_dim) 682 | 683 | # Tensor to store top k previous words at each step; now they're just 684 | k_prev_words = torch.LongTensor([[word_map['']]] * k).to(device) # (k, 1) 685 | 686 | # Tensor to store top k sequences; now they're just 687 | seqs = k_prev_words # (k, 1) 688 | 689 | # Tensor to store top k sequences' scores; now they're just 0 690 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1) 691 | 692 | # Tensor to store top k sequences' alphas; now they're just 1s 693 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size) 694 | 695 | # Lists to store completed sequences, their alphas and scores 696 | complete_seqs = list() 697 | complete_seqs_alpha = list() 698 | complete_seqs_scores = list() 699 | 700 | # Start decoding 701 | step = 1 702 | if isinstance(self.decode_step_pred, nn.LSTMCell): 703 | h_LM, c_LM, h_cell, c_cell = self.init_hidden_state(encoder_out) 704 | else: 705 | h_LM, h_cell = self.init_hidden_state(encoder_out) 706 | 707 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 708 | while True: 709 | embeddings = self.embedding(k_prev_words).squeeze(1) # (s, embed_dim) 710 | 711 | if isinstance(self.decode_step_LM, nn.LSTMCell): 712 | h_LM, c_LM = self.decode_step_LM(embeddings, (h_LM, c_LM)) 713 | else: 714 | h_LM = self.decode_step_LM(embeddings, h_LM) 715 | awe, alpha = self.attention(encoder_out, h_LM) 716 | 717 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size) 718 | 719 | gate = self.sigmoid(self.f_beta(h_cell)) # gating scalar, (s, encoder_dim) 720 | awe = gate * awe 721 | 722 | if isinstance(self.decode_step_pred, nn.LSTMCell): 723 | h_cell, c_cell = self.decode_step_pred(awe, (h_cell, c_cell)) # (batch_size_t, decoder_dim) 724 | else: 725 | h_cell = self.decode_step_pred(awe, h_cell) # (batch_size_t, decoder_dim) 726 | 727 | scores = self.fc(h_cell) # (s, vocab_size) 728 | scores = 
F.log_softmax(scores, dim=1)
729 |
730 | # Add
731 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size)
732 |
733 | # For the first step, all k points will have the same scores (since same k previous words, h, c)
734 | if step == 1:
735 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s)
736 | else:
737 | # Unroll and find top scores, and their unrolled indices
738 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s)
739 |
740 | # Convert unrolled indices to actual indices of scores
741 | prev_word_inds = top_k_words // vocab_size # (s)
742 | next_word_inds = top_k_words % vocab_size # (s)
743 |
744 | # Add new words to sequences, alphas
745 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1)
746 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)],
747 | dim=1) # (s, step+1, enc_image_size, enc_image_size)
748 |
749 | # Which sequences are incomplete (didn't reach <end>)?
750 | incomplete_inds = []
751 | complete_inds = []
752 | for ind, next_word in enumerate(next_word_inds):
753 | if next_word == word_map['<end>']:
754 | complete_inds.append(ind)
755 | else:
756 | incomplete_inds.append(ind)
757 |
758 | # Set aside complete sequences
759 | if len(complete_inds) > 0:
760 | complete_seqs.extend(seqs[complete_inds].tolist())
761 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist())
762 | complete_seqs_scores.extend(top_k_scores[complete_inds])
763 | k -= len(complete_inds) # reduce beam length accordingly
764 |
765 | # Proceed with incomplete sequences
766 | if k == 0:
767 | break
768 | seqs = seqs[incomplete_inds]
769 | seqs_alpha = seqs_alpha[incomplete_inds]
770 | h_LM = h_LM[prev_word_inds[incomplete_inds]]
771 | h_cell = h_cell[prev_word_inds[incomplete_inds]]
772 | if isinstance(self.decode_step_pred, nn.LSTMCell):
773 | c_LM = c_LM[prev_word_inds[incomplete_inds]]
774 | c_cell = c_cell[prev_word_inds[incomplete_inds]]
775 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]]
776 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1)
777 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1)
778 |
779 | # Break if things have been going on too long
780 | if step > max_steps:
781 | break
782 | step += 1
783 | if complete_seqs_scores:
784 | i = complete_seqs_scores.index(max(complete_seqs_scores))
785 | seq = complete_seqs[i]
786 | alphas = complete_seqs_alpha[i]
787 | return seq, alphas
788 | else:
789 | return None
790 |
791 | def forward(self, encoder_out, encoded_captions, caption_lengths, h_LM, h_pred, c_LM=None, c_pred=None, begin_tokens=None):
792 | """
793 | Forward propagation.
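The language-model state (h_LM, c_LM) and prediction state (h_pred, c_pred) are passed in and returned so decoding can continue across TBPTT sub-sequences; when begin_tokens is given, it supplies the last token of the previous sub-sequence, which is used in place of <start> as the first input.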
794 | 795 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 796 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length) 797 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) 798 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices 799 | """ 800 | batch_size = encoder_out.size(0) 801 | encoder_dim = encoder_out.size(-1) 802 | vocab_size = self.vocab_size 803 | 804 | # Flatten image 805 | num_pixels = encoder_out.size(1) 806 | 807 | if begin_tokens is None: 808 | # Embedding 809 | embeddings = self.embedding(encoded_captions) # (batch_size, max_caption_length, embed_dim) 810 | # We won't decode at the position, since we've finished generating as soon as we generate 811 | # So, decoding lengths are actual lengths - 1 812 | decode_lengths = (caption_lengths - 1).tolist() 813 | else: # For TBPTT, use the end token of the previous sub-sequence as begin token instead of 814 | embeddings = torch.cat([self.embedding(begin_tokens), self.embedding(encoded_captions)], dim=1) 815 | decode_lengths = caption_lengths.tolist() 816 | 817 | # Create tensors to hold word predicion scores and alphas 818 | predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device) 819 | alphas = torch.zeros(batch_size, max(decode_lengths), num_pixels).to(device) 820 | 821 | # At each time-step, decode by 822 | # attention-weighing the encoder's output based on the decoder's previous hidden state output 823 | # then generate a new word in the decoder with the previous word and the attention weighted encoding 824 | for t in range(max(decode_lengths)): 825 | batch_size_t = sum([l > t for l in decode_lengths]) 826 | 827 | # Language LSTM 828 | if isinstance(self.decode_step_LM, nn.LSTMCell): 829 | h_LM, c_LM = self.decode_step_LM( 830 | embeddings[:batch_size_t, t, :], 831 | (h_LM[:batch_size_t], c_LM[:batch_size_t])) # (batch_size_t, decoder_dim) 832 | else: 833 | h_LM = self.decode_step_LM( 834 | embeddings[:batch_size_t, t, :], 835 | h_LM[:batch_size_t]) # (batch_size_t, decoder_dim) 836 | 837 | # Attention 838 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 839 | h_LM) 840 | 841 | # Decoder LSTM 842 | gate = self.sigmoid(self.f_beta(h_pred[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim) 843 | attention_weighted_encoding = gate * attention_weighted_encoding 844 | if isinstance(self.decode_step_pred, nn.LSTMCell): 845 | h_pred, c_pred = self.decode_step_pred( 846 | attention_weighted_encoding, 847 | (h_pred[:batch_size_t], c_pred[:batch_size_t])) # (batch_size_t, decoder_dim) 848 | else: 849 | h_pred = self.decode_step_pred( 850 | attention_weighted_encoding, 851 | h_pred[:batch_size_t]) # (batch_size_t, decoder_dim) 852 | predictions[:batch_size_t, t, :] = self.fc(self.dropout(h_pred)) # (batch_size_t, vocab_size) 853 | alphas[:batch_size_t, t, :] = alpha 854 | 855 | return predictions, decode_lengths, alphas, h_LM, h_pred, c_LM, c_pred 856 | 857 | def train_epoch(self, train_loader, encoder, criterion, encoder_optimizer, decoder_optimizer, epoch, args, step=None): 858 | """ 859 | Performs one epoch's training. 
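When step is given (e.g. when resuming from a checkpoint saved mid-epoch), batches with index up to and including step are skipped.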
860 | 861 | :param train_loader: DataLoader for training data 862 | :param encoder: encoder model 863 | :param decoder: decoder model 864 | :param criterion: loss layer 865 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning) 866 | :param decoder_optimizer: optimizer to update decoder's weights 867 | :param epoch: epoch number 868 | """ 869 | 870 | self.train() # train mode (dropout and batchnorm is used) 871 | encoder.train() 872 | 873 | batch_time = AverageMeter() # forward prop. + back prop. time 874 | data_time = AverageMeter() # data loading time 875 | losses = AverageMeter() # loss (per word decoded) 876 | top1accs = AverageMeter() # top1 accuracy 877 | 878 | start = time.time() 879 | 880 | # Batches 881 | train_loader.shuffle() 882 | for i, (imgs, caps_sorted, caplens) in enumerate(train_loader): 883 | if step is not None: 884 | if i <= step: 885 | continue 886 | data_time.update(time.time() - start) 887 | 888 | # Move to GPU, if available 889 | imgs = imgs.to(device) 890 | caps_sorted = caps_sorted.to(device) 891 | caplens = caplens.to(device) 892 | 893 | # Forward prop. 894 | imgs = encoder(imgs) 895 | # Flatten image 896 | batch_size = imgs.size(0) 897 | encoder_dim = imgs.size(-1) 898 | imgs = imgs.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 899 | caplens = caplens.squeeze(1) 900 | 901 | # Sort input data by decreasing lengths 902 | # caplens, sort_ind = caplens.squeeze(1).sort(dim=0, descending=True) 903 | # imgs = imgs[sort_ind] 904 | # caps_sorted = caps[sort_ind] 905 | 906 | # Initialize LSTM state 907 | if isinstance(self.decode_step_pred, nn.LSTMCell): 908 | h_LM, c_LM, h_pred, c_pred = self.init_hidden_state(imgs) # (batch_size, decoder_dim) 909 | else: 910 | h_LM, h_pred = self.init_hidden_state(imgs) # (batch_size, decoder_dim) 911 | c_LM = c_pred = None 912 | 913 | max_cap_length = max(caplens.tolist()) 914 | # TBPTT 915 | j = 0 916 | while j < max_cap_length: 917 | if j == 0: 918 | # bptt tokens after 919 | sub_seq_len = min(args.bptt + 1, max_cap_length - j) 920 | else: 921 | sub_seq_len = min(args.bptt, max_cap_length - j) 922 | # Do not leave too short tails (less than 10 tokens) 923 | short_tail = (caplens - (j + sub_seq_len) < 10) & (caplens - (j + sub_seq_len) > 0) 924 | if short_tail.any(): 925 | sub_seq_len += max((caplens - (j + sub_seq_len))[short_tail].tolist()) 926 | 927 | sub_seq_caplens = caplens - j 928 | sub_seq_caplens[sub_seq_caplens > sub_seq_len] = sub_seq_len 929 | batch_size_t = (sub_seq_caplens > 0).sum().item() 930 | sub_seq_caplens = sub_seq_caplens[:batch_size_t] 931 | sub_seq_cap = caps_sorted[:batch_size_t, j:j + sub_seq_len] 932 | 933 | h_LM = repackage_hidden(h_LM) 934 | h_pred = repackage_hidden(h_pred) 935 | if isinstance(self.decode_step_pred, nn.LSTMCell): 936 | c_LM = repackage_hidden(c_LM) 937 | c_pred = repackage_hidden(c_pred) 938 | 939 | decoder_optimizer.zero_grad() 940 | if encoder_optimizer is not None: 941 | encoder_optimizer.zero_grad() 942 | if j == 0: 943 | scores, decode_lengths, alphas, h_LM, h_pred, c_LM, c_pred = self( 944 | imgs[:batch_size_t], 945 | sub_seq_cap, 946 | sub_seq_caplens, 947 | h_LM, h_pred, c_LM, c_pred) 948 | # Since we decoded starting with , the targets are all words after , up to 949 | targets = sub_seq_cap[:, 1:] 950 | else: 951 | scores, decode_lengths, alphas, h_LM, h_pred, c_LM, c_pred = self( 952 | imgs[:batch_size_t], 953 | sub_seq_cap, 954 | sub_seq_caplens, 955 | h_LM, h_pred, c_LM, c_pred, 956 | caps_sorted[:batch_size_t, j - 
1].unsqueeze(1)) 957 | targets = sub_seq_cap 958 | 959 | # Remove timesteps that we didn't decode at, or are pads 960 | # pack_padded_sequence is an easy trick to do this 961 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0] 962 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0] 963 | 964 | # Calculate loss 965 | loss = criterion(scores, targets) 966 | 967 | # Add doubly stochastic attention regularization 968 | loss += args.alpha_c * ((1. - alphas.sum(dim=1)) ** 2).mean() 969 | 970 | # Back prop. 971 | if j + sub_seq_len < max_cap_length: 972 | loss.backward(retain_graph=True) 973 | else: 974 | loss.backward() 975 | 976 | # Clip gradients 977 | if args.grad_clip is not None: 978 | clip_gradient(decoder_optimizer, args.grad_clip) 979 | if encoder_optimizer is not None: 980 | clip_gradient(encoder_optimizer, args.grad_clip) 981 | 982 | # Update weights 983 | decoder_optimizer.step() 984 | if encoder_optimizer is not None: 985 | encoder_optimizer.step() 986 | 987 | # Keep track of metrics 988 | top1 = accuracy(scores, targets, 1) 989 | losses.update(loss.item(), sum(decode_lengths)) 990 | top1accs.update(top1, sum(decode_lengths)) 991 | j += sub_seq_len 992 | batch_time.update(time.time() - start) 993 | 994 | start = time.time() 995 | 996 | # Print status 997 | if i % args.print_freq == 0: 998 | print('Epoch: [{0}][{1}/{2}]\t' 999 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 1000 | 'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t' 1001 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 1002 | 'Top-1 Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(epoch, i, len(train_loader), 1003 | batch_time=batch_time, 1004 | data_time=data_time, loss=losses, 1005 | top1=top1accs), file=sys.stderr) 1006 | sys.stderr.flush() 1007 | 1008 | class TagDecoder(DecoderWithAttention): 1009 | ''' 1010 | TagDecoder generates structure of the table 1011 | ''' 1012 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, 1013 | td_encode, decoder_cell=nn.LSTMCell, encoder_dim=512, 1014 | dropout=0.5, cnn_layer_stride=None, tag_H_grad=True): 1015 | """ 1016 | :param attention_dim: size of attention network 1017 | :param embed_dim: embedding size 1018 | :param decoder_dim: size of decoder's RNN 1019 | :param vocab_size: size of vocabulary 1020 | :param encoder_dim: feature size of encoded images 1021 | :param dropout: dropout 1022 | """ 1023 | super(TagDecoder, self).__init__( 1024 | attention_dim, 1025 | embed_dim, 1026 | decoder_dim, 1027 | vocab_size, 1028 | decoder_cell, 1029 | encoder_dim, 1030 | dropout) 1031 | self.td_encode = td_encode 1032 | self.tag_H_grad = tag_H_grad 1033 | if cnn_layer_stride is not None: 1034 | self.input_filter = resnet_block(cnn_layer_stride) 1035 | 1036 | def inference(self, encoder_out, word_map, max_steps=400, beam_size=5, return_attention=False): 1037 | """ 1038 | Inference on test images with beam search 1039 | """ 1040 | if hasattr(self, 'input_filter'): 1041 | encoder_out = self.input_filter(encoder_out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 1042 | enc_image_size = encoder_out.size(1) 1043 | encoder_dim = encoder_out.size(3) 1044 | 1045 | # Flatten encoding 1046 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim) 1047 | num_pixels = encoder_out.size(1) 1048 | 1049 | k = beam_size 1050 | vocab_size = len(word_map) 1051 | 1052 | # We'll treat the problem as having a batch size of k 1053 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) # (k, 
num_pixels, encoder_dim) 1054 | 1055 | # Tensor to store top k previous words at each step; now they're just 1056 | k_prev_words = torch.LongTensor([[word_map['']]] * k).to(device) # (k, 1) 1057 | 1058 | # Tensor to store top k sequences; now they're just 1059 | seqs = k_prev_words # (k, 1) 1060 | 1061 | # Tensor to store top k sequences' scores; now they're just 0 1062 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1) 1063 | 1064 | # Tensor to store top k sequences' alphas; now they're just 1s 1065 | if return_attention: 1066 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size) 1067 | 1068 | # Lists to store completed sequences, their alphas and scores 1069 | complete_seqs = list() 1070 | if return_attention: 1071 | complete_seqs_alpha = list() 1072 | complete_seqs_scores = list() 1073 | complete_seqs_tag_H = list() 1074 | 1075 | # Start decoding 1076 | step = 1 1077 | if isinstance(self.decode_step, nn.LSTMCell): 1078 | h, c = self.init_hidden_state(encoder_out) 1079 | else: 1080 | h = self.init_hidden_state(encoder_out) 1081 | tag_H = [[] for i in range(k)] 1082 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 1083 | while True: 1084 | embeddings = self.embedding(k_prev_words).squeeze(1) # (s, embed_dim) 1085 | 1086 | if return_attention: 1087 | awe, alpha = self.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels) 1088 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size) 1089 | else: 1090 | awe, _ = self.attention(encoder_out, h) # (s, encoder_dim), (s, num_pixels) 1091 | 1092 | gate = self.sigmoid(self.f_beta(h)) # gating scalar, (s, encoder_dim) 1093 | awe = gate * awe 1094 | 1095 | if isinstance(self.decode_step, nn.LSTMCell): 1096 | h, c = self.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim) 1097 | else: 1098 | h = self.decode_step(torch.cat([embeddings, awe], dim=1), h) # (s, decoder_dim) 1099 | 1100 | h = repackage_hidden(h) 1101 | if isinstance(self.decode_step, nn.LSTMCell): 1102 | c = repackage_hidden(c) 1103 | 1104 | for i, w in enumerate(k_prev_words): 1105 | if w[0].item() in (word_map[''], word_map['>']): 1106 | tag_H[i].append(h[i].unsqueeze(0)) 1107 | 1108 | scores = self.fc(h) # (s, vocab_size) 1109 | scores = F.log_softmax(scores, dim=1) 1110 | 1111 | # Add 1112 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size) 1113 | 1114 | # For the first step, all k points will have the same scores (since same k previous words, h, c) 1115 | if step == 1: 1116 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s) 1117 | else: 1118 | # Unroll and find top scores, and their unrolled indices 1119 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 1120 | 1121 | # Convert unrolled indices to actual indices of scores 1122 | prev_word_inds = top_k_words // vocab_size # (s) 1123 | next_word_inds = top_k_words % vocab_size # (s) 1124 | 1125 | 1126 | # Add new words to sequences, alphas 1127 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 1128 | if return_attention: 1129 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)], 1130 | dim=1) # (s, step+1, enc_image_size, enc_image_size) 1131 | 1132 | # Which sequences are incomplete (didn't reach )? 
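# (For each completed hypothesis, the tag decoder hidden states collected at
# cell-opening tags are kept in complete_seqs_tag_H, so that the cell decoder can
# later attend with one hidden state per cell.)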
1133 | incomplete_inds = [] 1134 | complete_inds = [] 1135 | for ind, next_word in enumerate(next_word_inds): 1136 | if next_word == word_map['']: 1137 | complete_inds.append(ind) 1138 | else: 1139 | incomplete_inds.append(ind) 1140 | 1141 | # Set aside complete sequences 1142 | if len(complete_inds) > 0: 1143 | complete_seqs.extend(seqs[complete_inds].tolist()) 1144 | if return_attention: 1145 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist()) 1146 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 1147 | complete_seqs_tag_H.extend([tag_H[i].copy() for i in prev_word_inds[complete_inds]]) 1148 | k -= len(complete_inds) # reduce beam length accordingly 1149 | 1150 | # Break if all sequences are complete 1151 | if k == 0: 1152 | break 1153 | # Break if things have been going on too long 1154 | if step > max_steps: 1155 | # If no complete sequence is generated, finish the incomplete 1156 | # sequences with 1157 | if not complete_seqs_scores: 1158 | complete_seqs = seqs.tolist() 1159 | for i in range(len(complete_seqs)): 1160 | complete_seqs[i].append(word_map['']) 1161 | if return_attention: 1162 | complete_seqs_alpha = seqs_alpha.tolist() 1163 | complete_seqs_scores = top_k_scores.tolist() 1164 | complete_seqs_tag_H = [tag_H[i].copy() for i in prev_word_inds] 1165 | break 1166 | 1167 | # Proceed with incomplete sequences 1168 | seqs = seqs[incomplete_inds] 1169 | if return_attention: 1170 | seqs_alpha = seqs_alpha[incomplete_inds] 1171 | tag_H = [tag_H[i].copy() for i in prev_word_inds[incomplete_inds]] 1172 | h = h[prev_word_inds[incomplete_inds]] 1173 | if isinstance(self.decode_step, nn.LSTMCell): 1174 | c = c[prev_word_inds[incomplete_inds]] 1175 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]] 1176 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 1177 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) 1178 | 1179 | step += 1 1180 | i = complete_seqs_scores.index(max(complete_seqs_scores)) 1181 | seq = complete_seqs[i] 1182 | if complete_seqs_tag_H[i]: 1183 | tag_H = torch.cat(complete_seqs_tag_H[i]).to(device) 1184 | else: 1185 | tag_H = torch.zeros(0).to(device) 1186 | if return_attention: 1187 | alphas = complete_seqs_alpha[i] 1188 | return seq, alphas, tag_H 1189 | else: 1190 | return seq, tag_H 1191 | 1192 | def forward(self, encoder_out, encoded_tags_sorted, tag_lengths, num_cells=None, max_tag_len=None): 1193 | # Flatten image 1194 | if hasattr(self, 'input_filter'): 1195 | encoder_out = self.input_filter(encoder_out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 1196 | batch_size = encoder_out.size(0) 1197 | encoder_dim = encoder_out.size(-1) 1198 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 1199 | num_pixels = encoder_out.size(1) 1200 | 1201 | # Embedding 1202 | embeddings = self.embedding(encoded_tags_sorted) # (batch_size, max_caption_length, embed_dim) 1203 | # We won't decode at the position, since we've finished generating as soon as we generate 1204 | # So, decoding lengths are actual lengths - 1 1205 | decode_lengths = (tag_lengths - 1).tolist() 1206 | max_decode_lengths = decode_lengths[0] if max_tag_len is None else max_tag_len 1207 | # Create tensors to hold word predicion scores and alphas 1208 | predictions = torch.zeros(batch_size, max_decode_lengths, self.vocab_size).to(device) 1209 | alphas = torch.zeros(batch_size, max_decode_lengths, num_pixels).to(device) 1210 | 1211 | if num_cells is not None: 1212 | # Create tensors to hold hidden state of tag decoder 
for cell decoder 1213 | tag_H = [torch.zeros(n.item(), self.decoder_dim).to(device) for n in num_cells] 1214 | pointer = torch.zeros(batch_size, dtype=torch.long).to(device) 1215 | 1216 | # Initialize LSTM state 1217 | if isinstance(self.decode_step, nn.LSTMCell): 1218 | h, c = self.init_hidden_state(encoder_out) 1219 | else: 1220 | h = self.init_hidden_state(encoder_out) 1221 | 1222 | # Decode table structure 1223 | for t in range(max_decode_lengths): 1224 | batch_size_t = sum([l > t for l in decode_lengths]) 1225 | if batch_size_t > 0: 1226 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 1227 | h[:batch_size_t]) 1228 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim) 1229 | attention_weighted_encoding = gate * attention_weighted_encoding 1230 | if isinstance(self.decode_step, nn.LSTMCell): 1231 | h, c = self.decode_step( 1232 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1), 1233 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim) 1234 | else: 1235 | h = self.decode_step( 1236 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1), 1237 | h[:batch_size_t]) # (batch_size_t, decoder_dim) 1238 | predictions[:batch_size_t, t, :] = self.fc(self.dropout(h)) # (batch_size_t, vocab_size) 1239 | alphas[:batch_size_t, t, :] = alpha 1240 | if num_cells is not None: 1241 | for i in range(batch_size_t): 1242 | if encoded_tags_sorted[i, t] in self.td_encode: 1243 | if self.tag_H_grad: 1244 | tag_H[i][pointer[i]] = h[i] 1245 | else: 1246 | tag_H[i][pointer[i]] = repackage_hidden(h[i]) 1247 | pointer[i] += 1 1248 | if num_cells is None: 1249 | return predictions, decode_lengths, alphas 1250 | else: 1251 | return predictions, decode_lengths, alphas, tag_H 1252 | 1253 | def train_epoch(self, train_loader, encoder, criterion, encoder_optimizer, decoder_optimizer, epoch, args): 1254 | """ 1255 | Performs one epoch's training. 1256 | 1257 | :param train_loader: DataLoader for training data 1258 | :param encoder: encoder model 1259 | :param criterion: loss layer 1260 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning) 1261 | :param decoder_optimizer: optimizer to update decoder's weights 1262 | :param epoch: epoch number 1263 | """ 1264 | self.train() # train mode (dropout and batchnorm is used) 1265 | encoder.train() 1266 | 1267 | batch_time = AverageMeter() # forward prop. + back prop. time 1268 | losses = AverageMeter() # loss (per word decoded) 1269 | top1accs = AverageMeter() # top1 accuracy 1270 | 1271 | start = time.time() 1272 | # Batches 1273 | for i, (imgs, tags, tag_lens) in enumerate(train_loader): 1274 | # Move to GPU, if available 1275 | imgs = imgs.to(device) 1276 | tags = tags.to(device) 1277 | tag_lens = tag_lens.to(device) 1278 | 1279 | # Flatten image 1280 | batch_size = imgs.size(0) 1281 | encoder_dim = imgs.size(-1) 1282 | imgs = imgs.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 1283 | tag_lens = tag_lens.squeeze(1) 1284 | 1285 | # Sort input data by decreasing lengths 1286 | tag_lens, sort_ind = tag_lens.sort(dim=0, descending=True) 1287 | imgs = imgs[sort_ind] 1288 | tags_sorted = tags[sort_ind] 1289 | 1290 | # Forward prop. 
1291 | imgs = encoder(imgs) 1292 | if hasattr(self, 'input_filter'): 1293 | imgs = self.input_filter(imgs.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 1294 | 1295 | scores_tag, decode_lengths_tag, alphas_tag = self( 1296 | imgs, tags_sorted, tag_lens) 1297 | 1298 | # Calculate tag loss 1299 | targets_tag = tags_sorted[:, 1:] 1300 | scores_tag = pack_padded_sequence(scores_tag, decode_lengths_tag, batch_first=True)[0] 1301 | targets_tag = pack_padded_sequence(targets_tag, decode_lengths_tag, batch_first=True)[0] 1302 | loss = criterion(scores_tag, targets_tag) 1303 | # Add doubly stochastic attention regularization 1304 | loss += args.alpha_c * ((1. - alphas_tag.sum(dim=1)) ** 2).mean() 1305 | top1 = accuracy(scores_tag, targets_tag, 1) 1306 | tag_count = sum(decode_lengths_tag) 1307 | losses.update(loss.item(), tag_count) 1308 | top1accs.update(top1, tag_count) 1309 | 1310 | # Back prop. 1311 | decoder_optimizer.zero_grad() 1312 | if encoder_optimizer is not None: 1313 | encoder_optimizer.zero_grad() 1314 | loss.backward() 1315 | 1316 | # Clip gradients 1317 | if args.grad_clip is not None: 1318 | clip_gradient(decoder_optimizer, args.grad_clip) 1319 | if encoder_optimizer is not None: 1320 | clip_gradient(encoder_optimizer, args.grad_clip) 1321 | 1322 | # Update weights 1323 | decoder_optimizer.step() 1324 | if encoder_optimizer is not None: 1325 | encoder_optimizer.step() 1326 | 1327 | batch_time.update(time.time() - start) 1328 | start = time.time() 1329 | 1330 | # Print status 1331 | if i % args.print_freq == 0: 1332 | print('Epoch: [{0}][{1}/{2}]\t' 1333 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 1334 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 1335 | 'Accuracy {top1.val:.3f} ({top1.avg:.3f})'.format(epoch, i, len(train_loader), 1336 | batch_time=batch_time, 1337 | loss=losses, 1338 | top1=top1accs), file=sys.stderr) 1339 | sys.stderr.flush() 1340 | 1341 | class BBoxLoss(nn.Module): 1342 | def __init__(self): 1343 | super(BBoxLoss, self).__init__() 1344 | self.CE = nn.CrossEntropyLoss() 1345 | 1346 | def bbox_loss(self, gt, pred): 1347 | center_loss = (gt[:, :2] - pred[:, :2]).square().sum(dim=1) 1348 | size_loss = (gt[:, 2:].sqrt() - pred[:, 2:].sqrt()).square().sum(dim=1) 1349 | return center_loss + size_loss 1350 | 1351 | def forward(self, gt, pred): 1352 | empty_loss = self.CE(pred[:, :2], gt[:, 0].long()) # Empty cell classification loss 1353 | bbox_loss = (gt[:, 0] * self.bbox_loss(gt[:, 1:], pred[:, 2:])).mean() # Only compute for non-empty cells 1354 | return empty_loss + bbox_loss 1355 | 1356 | class CellBBox(nn.Module): 1357 | """ 1358 | Regression network for bbox of table cell. 1359 | """ 1360 | 1361 | def __init__(self, tag_decoder_dim): 1362 | """ 1363 | :param tag_decoder_dim: size of tag decoder's RNN 1364 | """ 1365 | super(CellBBox, self).__init__() 1366 | # linear layers to predict bbox (x_c, y_c, w, h) 1367 | self.bbox = nn.Sequential( 1368 | nn.Linear(tag_decoder_dim, tag_decoder_dim), 1369 | nn.ReLU(), 1370 | nn.Linear(tag_decoder_dim, 4), 1371 | nn.Sigmoid() 1372 | ) 1373 | # linear layers to predict if a cell is empty 1374 | self.empty_cls = nn.Sequential( 1375 | nn.Linear(tag_decoder_dim, tag_decoder_dim), 1376 | nn.ReLU(), 1377 | nn.Linear(tag_decoder_dim, 2) 1378 | ) 1379 | 1380 | def forward(self, decoder_hidden): 1381 | """ 1382 | Forward propagation. 
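For each image in the batch, an empty/non-empty score and a bbox (x_c, y_c, w, h) are predicted for every cell from the corresponding tag decoder hidden states.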
1383 | :param decoder_hidden: tag decoder output, a list of batch_size tensors of dimension (num_cells, tag_decoder_dim)
1384 | """
1385 | batch_size = len(decoder_hidden)
1386 | output = []
1387 | for i in range(batch_size):
1388 | not_empty = self.empty_cls(decoder_hidden[i]) # (num_cells, 2)
1389 | bbox_pred = self.bbox(decoder_hidden[i]) # (num_cells, 4)
1390 | output.append(torch.cat([not_empty, bbox_pred], dim=1)) # (num_cells, 6)
1391 | return output
1392 |
1393 |
1394 | class BBoxLoss_Yolo(nn.Module):
1395 | def __init__(self, w_coor=5.0, w_noobj=0.5, image_size=(28, 28)):
1396 | super(BBoxLoss_Yolo, self).__init__()
1397 | self.w_coor = w_coor
1398 | self.w_noobj = w_noobj
1399 | self.image_size = image_size
1400 |
1401 | def IoU(self, pred, gt):
1402 | ''' Calculates IoU between prediction boxes and table cell
1403 | '''
1404 | pred_xmin = pred[:, 1::5] - pred[:, 3::5] / 2
1405 | pred_xmax = pred[:, 1::5] + pred[:, 3::5] / 2
1406 | pred_ymin = pred[:, 2::5] - pred[:, 4::5] / 2
1407 | pred_ymax = pred[:, 2::5] + pred[:, 4::5] / 2
1408 | gt_xmin = (gt[:, 1] - gt[:, 3] / 2).unsqueeze(1)
1409 | gt_xmax = (gt[:, 1] + gt[:, 3] / 2).unsqueeze(1)
1410 | gt_ymin = (gt[:, 2] - gt[:, 4] / 2).unsqueeze(1)
1411 | gt_ymax = (gt[:, 2] + gt[:, 4] / 2).unsqueeze(1)
1412 |
1413 | I_w = torch.max(torch.FloatTensor([0]), torch.min(pred_xmax, gt_xmax) - torch.max(pred_xmin, gt_xmin))
1414 | I_h = torch.max(torch.FloatTensor([0]), torch.min(pred_ymax, gt_ymax) - torch.max(pred_ymin, gt_ymin))
1415 | I = I_w * I_h
1416 | U = pred[:, 3::5] * pred[:, 4::5] + (gt[:, 3] * gt[:, 4]).unsqueeze(1) - I
1417 | IoU = I / (U + 1e-8) # Avoid dividing by 0
1418 | return IoU
1419 |
1420 | def find_responsible_box(self, pred, idx, gt):
1421 | ''' Finds which prediction box is responsible for the table cell
1422 | '''
1423 | pred = pred[idx[0], idx[1]]
1424 | IoU = self.IoU(pred, gt)
1425 | num_cells = gt.size(0)
1426 | IoU, responsible_box = torch.max(IoU, dim=1)
1427 | return responsible_box, IoU
1428 |
1429 | def forward(self, gt, pred):
1430 | '''
1431 | :param gt: ground truth (num_cells, 5)
1432 | :param pred: prediction of CellBBoxYolo (num_cells, num_pixels, 5 * num_bboxes_per_pixel)
1433 | '''
1434 | num_cells = gt.size(0)
1435 | image_width, image_height = self.image_size
1436 | non_empty_cell = gt[:, 0] == 1
1437 |
1438 | gt_non_empty, gt_empty = gt[non_empty_cell], gt[~non_empty_cell]
1439 | pred_non_empty, pred_empty = pred[non_empty_cell], pred[~non_empty_cell]
1440 | loss_empty = self.w_noobj * pred_empty[:, :, 0::5].square().sum()
1441 |
1442 | # Encode gt as Yolo format
1443 | # Find center pixel
1444 | x_c, y_c = torch.floor(gt_non_empty[:, 1] * image_width), torch.floor(gt_non_empty[:, 2] * image_height)
1445 | idx = (torch.LongTensor(torch.arange(gt_non_empty.size(0))), (x_c * image_width + y_c).long())
1446 |
1447 | # Compute offset
1448 | gt_non_empty[:, 1], gt_non_empty[:, 2] = gt_non_empty[:, 1] * image_width - x_c, gt_non_empty[:, 2] * image_height - y_c
1449 | gt_non_empty[:, 3], gt_non_empty[:, 4] = gt_non_empty[:, 3] * image_width, gt_non_empty[:, 4] * image_height
1450 |
1451 | responsible_box, IoU = self.find_responsible_box(pred_non_empty, idx, gt_non_empty)
1452 | responsible_box = responsible_box * 5
1453 | gt_non_empty[:, 0] = IoU
1454 | gt_non_empty[:, 3:5] = gt_non_empty[:, 3:5].sqrt()
1455 |
1456 | responsible_box = torch.cat((
1457 | pred_non_empty[idx[0], idx[1], responsible_box],
1458 | pred_non_empty[idx[0], idx[1], responsible_box + 1],
1459 | pred_non_empty[idx[0], idx[1], responsible_box + 2],
1460 | pred_non_empty[idx[0], idx[1], responsible_box + 3].sqrt(), 1461 | pred_non_empty[idx[0], idx[1], responsible_box + 4].sqrt() 1462 | ), dim=1) 1463 | 1464 | 1465 | loss_coor = self.w_coor * (responsible_box[:, 1:5] - gt_non_empty[:, 1:5]).square().sum() 1466 | loss_noobj = (responsible_box[:, 0] - gt_non_empty[:, 0]).square().sum() + \ 1467 | self.w_noobj * 0 + \ 1468 | loss_empty 1469 | 1470 | return loss_coor + loss_noobj 1471 | 1472 | 1473 | class CellBBoxYolo(nn.Module): 1474 | """ 1475 | NOT READY 1476 | Table cell detection network (based on the idea of Yolo). 1477 | """ 1478 | 1479 | def __init__(self, encoder_dim, tag_decoder_dim, feature_dim, num_bboxes_per_pixel=2): 1480 | """ 1481 | :param encoder_dim: feature size of encoded images 1482 | :param tag_decoder_dim: size of tag decoder's RNN 1483 | :param feature_dim: size of the features 1484 | """ 1485 | super(CellBBoxYolo, self).__init__() 1486 | self.encoder_att = nn.Linear(encoder_dim, feature_dim) # linear layer to transform encoded image 1487 | self.tag_decoder_att = nn.Linear(tag_decoder_dim, feature_dim) # linear layer to transform tag decoder output 1488 | self.bbox = nn.Linear(feature_dim, 5 * num_bboxes_per_pixel) # linear layer to predict bboxes [c, x_c, y_c, w, h] * num_bboxes_per_pixel 1489 | self.relu = nn.ReLU() 1490 | self.sigmoid = nn.Sigmoid() # sigmoid to scale bbox between 0 and 1 1491 | 1492 | def forward(self, encoder_out, decoder_hidden): 1493 | """ 1494 | Forward propagation. 1495 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 1496 | :param decoder_hidden: tag decoder output, a tensor of dimension [(num_cells, tag_decoder_dim)] * batch_size 1497 | :return: [(num_cells, 5 * num_bboxes_per_pixel)] * batch_size 1498 | """ 1499 | batch_size = encoder_out.size(0) 1500 | output = [] 1501 | for i in range(batch_size): 1502 | att1 = self.encoder_att(encoder_out[i].unsqueeze(0)) # (1, num_pixels, feature_dim) 1503 | att2 = self.tag_decoder_att(decoder_hidden[i]).unsqueeze(1) # (num_cells, 1, feature_dim) 1504 | att = self.relu(att1 + att2) # (num_cells, num_pixels, feature_dim) 1505 | bboxes = self.sigmoid(self.bbox(att)) # (num_cells, num_pixels, 5 * num_bboxes_per_pixel) 1506 | output.append(bboxes) 1507 | return output 1508 | 1509 | 1510 | class CellDecoder_baseline(nn.Module): 1511 | ''' 1512 | CellDecoder generates cell content 1513 | ''' 1514 | def __init__(self, attention_dim, embed_dim, tag_decoder_dim, decoder_dim, 1515 | vocab_size, decoder_cell=nn.LSTMCell, encoder_dim=512, 1516 | dropout=0.5, cnn_layer_stride=None): 1517 | """ 1518 | :param attention_dim: size of attention network 1519 | :param embed_dim: embedding size 1520 | :param tag_decoder_dim: size of tag decoder's RNN 1521 | :param decoder_dim: size of decoder's RNN 1522 | :param vocab_size: size of vocabulary 1523 | :param encoder_dim: feature size of encoded images 1524 | :param dropout: dropout 1525 | :param mini_batch_size: batch size of cells to reduce GPU memory usage 1526 | """ 1527 | super(CellDecoder_baseline, self).__init__() 1528 | 1529 | self.encoder_dim = encoder_dim 1530 | self.attention_dim = attention_dim 1531 | self.embed_dim = embed_dim 1532 | self.decoder_dim = decoder_dim 1533 | self.vocab_size = vocab_size 1534 | self.dropout = dropout 1535 | 1536 | self.attention = CellAttention(encoder_dim, tag_decoder_dim, decoder_dim, attention_dim) # attention network 1537 | 1538 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer 1539 | 1540 | 
self.decode_step = decoder_cell(embed_dim + encoder_dim, decoder_dim, bias=True) # decoder LSTMCell 1541 | 1542 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell 1543 | if isinstance(self.decode_step, nn.LSTMCell): 1544 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell 1545 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate 1546 | self.sigmoid = nn.Sigmoid() 1547 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary 1548 | self.dropout = nn.Dropout(p=self.dropout) 1549 | 1550 | if cnn_layer_stride is not None: 1551 | self.input_filter = resnet_block(cnn_layer_stride) 1552 | 1553 | self.init_weights() # initialize some layers with the uniform distribution 1554 | 1555 | def init_weights(self): 1556 | """ 1557 | Initializes some parameters with values from the uniform distribution, for easier convergence. 1558 | """ 1559 | self.embedding.weight.data.uniform_(-0.1, 0.1) 1560 | self.fc.bias.data.fill_(0) 1561 | self.fc.weight.data.uniform_(-0.1, 0.1) 1562 | 1563 | def init_hidden_state(self, encoder_out, batch_size): 1564 | """ 1565 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 1566 | 1567 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 1568 | :return: hidden state, cell state 1569 | """ 1570 | mean_encoder_out = encoder_out.mean(dim=1) 1571 | h = self.init_h(mean_encoder_out).expand(batch_size, -1) 1572 | if isinstance(self.decode_step, nn.LSTMCell): 1573 | c = self.init_c(mean_encoder_out).expand(batch_size, -1) 1574 | return h, c 1575 | else: 1576 | return h 1577 | 1578 | def inference(self, encoder_out, tag_H, word_map, max_steps=400, beam_size=5, return_attention=False): 1579 | """ 1580 | Inference on test images with beam search 1581 | """ 1582 | 1583 | if hasattr(self, 'input_filter'): 1584 | encoder_out = self.input_filter(encoder_out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 1585 | enc_image_size = encoder_out.size(1) 1586 | encoder_dim = encoder_out.size(3) 1587 | 1588 | # Flatten encoding 1589 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim) 1590 | 1591 | num_cells = tag_H.size(0) 1592 | cell_seqs = [] 1593 | if return_attention: 1594 | cell_alphas = [] 1595 | vocab_size = len(word_map) 1596 | 1597 | for c_id in range(num_cells): 1598 | k = beam_size 1599 | # Tensor to store top k previous words at each step; now they're just 1600 | k_prev_words = torch.LongTensor([[word_map['']]] * k).to(device) # (k, 1) 1601 | 1602 | # Tensor to store top k sequences; now they're just 1603 | seqs = k_prev_words # (k, 1) 1604 | 1605 | # Tensor to store top k sequences' scores; now they're just 0 1606 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1) 1607 | 1608 | if return_attention: 1609 | # Tensor to store top k sequences' alphas; now they're just 1s 1610 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size) 1611 | 1612 | # Lists to store completed sequences, their alphas and scores 1613 | complete_seqs = list() 1614 | if return_attention: 1615 | complete_seqs_alpha = list() 1616 | complete_seqs_scores = list() 1617 | 1618 | # Start decoding 1619 | step = 1 1620 | if isinstance(self.decode_step, nn.LSTMCell): 1621 | h, c = self.init_hidden_state(encoder_out, k) 1622 | else: 1623 | h = 
self.init_hidden_state(encoder_out, k) 1624 | 1625 | cell_tag_H = tag_H[c_id].expand(k, -1) 1626 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 1627 | while True: 1628 | embeddings = self.embedding(k_prev_words).squeeze(1) # (s, embed_dim) 1629 | if return_attention: 1630 | awe, alpha = self.attention(encoder_out, cell_tag_H, h) # (s, encoder_dim), (s, num_pixels) 1631 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size) 1632 | else: 1633 | awe, _ = self.attention(encoder_out, cell_tag_H, h) # (s, encoder_dim), (s, num_pixels) 1634 | 1635 | gate = self.sigmoid(self.f_beta(h)) # gating scalar, (s, encoder_dim) 1636 | awe = gate * awe 1637 | 1638 | if isinstance(self.decode_step, nn.LSTMCell): 1639 | h, c = self.decode_step(torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim) 1640 | else: 1641 | h = self.decode_step(torch.cat([embeddings, awe], dim=1), h) # (s, decoder_dim) 1642 | 1643 | h = repackage_hidden(h) 1644 | if isinstance(self.decode_step, nn.LSTMCell): 1645 | c = repackage_hidden(c) 1646 | 1647 | scores = self.fc(h) # (s, vocab_size) 1648 | scores = F.log_softmax(scores, dim=1) 1649 | 1650 | # Add 1651 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size) 1652 | 1653 | # For the first step, all k points will have the same scores (since same k previous words, h, c) 1654 | if step == 1: 1655 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s) 1656 | else: 1657 | # Unroll and find top scores, and their unrolled indices 1658 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 1659 | 1660 | # Convert unrolled indices to actual indices of scores 1661 | prev_word_inds = top_k_words // vocab_size # (s) 1662 | next_word_inds = top_k_words % vocab_size # (s) 1663 | 1664 | # Add new words to sequences, alphas 1665 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 1666 | if return_attention: 1667 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)], 1668 | dim=1) # (s, step+1, enc_image_size, enc_image_size) 1669 | 1670 | # Which sequences are incomplete (didn't reach )? 
1671 | incomplete_inds = [] 1672 | complete_inds = [] 1673 | for ind, next_word in enumerate(next_word_inds): 1674 | if next_word == word_map['']: 1675 | complete_inds.append(ind) 1676 | else: 1677 | incomplete_inds.append(ind) 1678 | 1679 | # Set aside complete sequences 1680 | if len(complete_inds) > 0: 1681 | complete_seqs.extend(seqs[complete_inds].tolist()) 1682 | if return_attention: 1683 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist()) 1684 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 1685 | k -= len(complete_inds) # reduce beam length accordingly 1686 | 1687 | # Break if all sequences are complete 1688 | if k == 0: 1689 | break 1690 | # Break if things have been going on too long 1691 | if step > max_steps: 1692 | # If no complete sequence is generated, finish the incomplete 1693 | # sequences with 1694 | if not complete_seqs_scores: 1695 | complete_seqs = seqs.tolist() 1696 | for i in range(len(complete_seqs)): 1697 | complete_seqs[i].append(word_map['']) 1698 | if return_attention: 1699 | complete_seqs_alpha = seqs_alpha.tolist() 1700 | complete_seqs_scores = top_k_scores.tolist() 1701 | break 1702 | 1703 | # Proceed with incomplete sequences 1704 | seqs = seqs[incomplete_inds] 1705 | if return_attention: 1706 | seqs_alpha = seqs_alpha[incomplete_inds] 1707 | cell_tag_H = cell_tag_H[prev_word_inds[incomplete_inds]] 1708 | h = h[prev_word_inds[incomplete_inds]] 1709 | if isinstance(self.decode_step, nn.LSTMCell): 1710 | c = c[prev_word_inds[incomplete_inds]] 1711 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 1712 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) 1713 | 1714 | step += 1 1715 | i = complete_seqs_scores.index(max(complete_seqs_scores)) 1716 | cell_seqs.append(complete_seqs[i]) 1717 | if return_attention: 1718 | cell_alphas.append(complete_seqs_alpha[i]) 1719 | if return_attention: 1720 | return cell_seqs, cell_alphas 1721 | else: 1722 | return cell_seqs 1723 | 1724 | def forward(self, encoder_out, encoded_cells_sorted, cell_lengths, tag_H): 1725 | """ 1726 | Forward propagation. 
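Cells are decoded image by image: for every cell, attention is conditioned on both the encoded image and the tag decoder hidden state emitted for that cell (tag_H), and all cells of an image are decoded together as one mini-batch.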
1727 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 1728 | :param encoded_cells_sorted: encoded cells, a list of batch_size tensors of dimension (num_cells, max_cell_length) 1729 | :param tag_H: hidden state from TagDeoder, a list of batch_size tensors of dimension (num_cells, TagDecoder's decoder_dim) 1730 | :param cell_lengths: caption lengths, a list of batch_size tensor of dimension (num_cells, 1) 1731 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights 1732 | """ 1733 | if hasattr(self, 'input_filter'): 1734 | encoder_out = self.input_filter(encoder_out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 1735 | 1736 | # Flatten image 1737 | batch_size = encoder_out.size(0) 1738 | encoder_dim = encoder_out.size(-1) 1739 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 1740 | num_pixels = encoder_out.size(1) 1741 | 1742 | # Decode cell content 1743 | predictions_cell = [] 1744 | alphas_cell = [] 1745 | decode_lengths_cell = [] 1746 | for i in range(batch_size): 1747 | num_cells = cell_lengths[i].size(0) 1748 | embeddings = self.embedding(encoded_cells_sorted[i]) 1749 | decode_lengths = (cell_lengths[i] - 1).tolist() 1750 | max_decode_lengths = decode_lengths[0] 1751 | predictions = torch.zeros(num_cells, max_decode_lengths, self.vocab_size).to(device) 1752 | alphas = torch.zeros(num_cells, max_decode_lengths, num_pixels).to(device) 1753 | if isinstance(self.decode_step, nn.LSTMCell): 1754 | h, c = self.init_hidden_state(encoder_out[i].unsqueeze(0), num_cells) 1755 | else: 1756 | h = self.init_hidden_state(encoder_out[i].unsqueeze(0), num_cells) 1757 | for t in range(max_decode_lengths): 1758 | batch_size_t = sum([l > t for l in decode_lengths]) 1759 | attention_weighted_encoding, alpha = self.attention(encoder_out[i].unsqueeze(0), 1760 | tag_H[i][:batch_size_t], 1761 | h[:batch_size_t]) 1762 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim) 1763 | attention_weighted_encoding = gate * attention_weighted_encoding 1764 | if isinstance(self.decode_step, nn.LSTMCell): 1765 | h, c = self.decode_step( 1766 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1), 1767 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim) 1768 | else: 1769 | h = self.decode_step( 1770 | torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1), 1771 | h[:batch_size_t]) # (batch_size_t, decoder_dim) 1772 | predictions[:batch_size_t, t, :] = self.fc(self.dropout(h)) # (batch_size_t, vocab_size) 1773 | alphas[:batch_size_t, t, :] = alpha 1774 | predictions_cell.append(predictions) 1775 | alphas_cell.append(alphas) 1776 | decode_lengths_cell.append(decode_lengths) 1777 | return predictions_cell, decode_lengths_cell, alphas_cell 1778 | 1779 | class CellDecoder(nn.Module): 1780 | ''' 1781 | CellDecoder generates cell content 1782 | ''' 1783 | def __init__(self, attention_dim, embed_dim, tag_decoder_dim, language_dim, 1784 | decoder_dim, vocab_size, decoder_cell=nn.LSTMCell, 1785 | encoder_dim=512, dropout=0.5, cnn_layer_stride=None): 1786 | """ 1787 | :param attention_dim: size of attention network 1788 | :param embed_dim: embedding size 1789 | :param tag_decoder_dim: size of tag decoder's RNN 1790 | :param language_dim: size of language model's RNN 1791 | :param decoder_dim: size of decoder's RNN 1792 | :param vocab_size: size of vocabulary 1793 | :param 
encoder_dim: feature size of encoded images 1794 | :param dropout: dropout 1795 | :param mini_batch_size: batch size of cells to reduce GPU memory usage 1796 | """ 1797 | super(CellDecoder, self).__init__() 1798 | 1799 | self.encoder_dim = encoder_dim 1800 | self.attention_dim = attention_dim 1801 | self.embed_dim = embed_dim 1802 | self.language_dim = language_dim 1803 | self.decoder_dim = decoder_dim 1804 | self.vocab_size = vocab_size 1805 | self.dropout = dropout 1806 | 1807 | self.attention = CellAttention(encoder_dim, tag_decoder_dim, language_dim, attention_dim) # attention network 1808 | 1809 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer 1810 | 1811 | self.decode_step_LM = decoder_cell(embed_dim, language_dim, bias=True) # language model LSTMCell 1812 | 1813 | self.decode_step_pred = decoder_cell(encoder_dim, decoder_dim, bias=True) # decoding LSTMCell 1814 | self.init_h = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial hidden state of LSTMCell 1815 | if isinstance(self.decode_step_pred, nn.LSTMCell): 1816 | self.init_c = nn.Linear(encoder_dim, decoder_dim) # linear layer to find initial cell state of LSTMCell 1817 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) # linear layer to create a sigmoid-activated gate 1818 | self.sigmoid = nn.Sigmoid() 1819 | self.fc = nn.Linear(decoder_dim, vocab_size) # linear layer to find scores over vocabulary 1820 | self.dropout = nn.Dropout(p=self.dropout) 1821 | 1822 | if cnn_layer_stride is not None: 1823 | self.input_filter = resnet_block(cnn_layer_stride) 1824 | 1825 | self.init_weights() # initialize some layers with the uniform distribution 1826 | 1827 | def init_weights(self): 1828 | """ 1829 | Initializes some parameters with values from the uniform distribution, for easier convergence. 1830 | """ 1831 | self.embedding.weight.data.uniform_(-0.1, 0.1) 1832 | self.fc.bias.data.fill_(0) 1833 | self.fc.weight.data.uniform_(-0.1, 0.1) 1834 | 1835 | def init_hidden_state(self, encoder_out, batch_size): 1836 | """ 1837 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 
1838 | 1839 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 1840 | :return: hidden state, cell state 1841 | """ 1842 | mean_encoder_out = encoder_out.mean(dim=1) 1843 | h_LM = torch.zeros(batch_size, self.language_dim).to(device) 1844 | h_pred = self.init_h(mean_encoder_out).expand(batch_size, -1) 1845 | if isinstance(self.decode_step_pred, nn.LSTMCell): 1846 | c_LM = torch.zeros(batch_size, self.language_dim).to(device) 1847 | c_pred = self.init_c(mean_encoder_out).expand(batch_size, -1) 1848 | return h_LM, c_LM, h_pred, c_pred 1849 | else: 1850 | return h_LM, h_pred 1851 | 1852 | def inference(self, encoder_out, tag_H, word_map, max_steps=400, beam_size=5, return_attention=False): 1853 | """ 1854 | Inference on test images with beam search 1855 | """ 1856 | if hasattr(self, 'input_filter'): 1857 | encoder_out = self.input_filter(encoder_out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 1858 | 1859 | enc_image_size = encoder_out.size(1) 1860 | encoder_dim = encoder_out.size(3) 1861 | 1862 | # Flatten encoding 1863 | encoder_out = encoder_out.view(1, -1, encoder_dim) # (1, num_pixels, encoder_dim) 1864 | 1865 | num_cells = tag_H.size(0) 1866 | cell_seqs = [] 1867 | if return_attention: 1868 | cell_alphas = [] 1869 | vocab_size = len(word_map) 1870 | 1871 | for c in range(num_cells): 1872 | k = beam_size 1873 | # Tensor to store top k previous words at each step; now they're just 1874 | k_prev_words = torch.LongTensor([[word_map['']]] * k).to(device) # (k, 1) 1875 | 1876 | # Tensor to store top k sequences; now they're just 1877 | seqs = k_prev_words # (k, 1) 1878 | 1879 | # Tensor to store top k sequences' scores; now they're just 0 1880 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1) 1881 | 1882 | if return_attention: 1883 | # Tensor to store top k sequences' alphas; now they're just 1s 1884 | seqs_alpha = torch.ones(k, 1, enc_image_size, enc_image_size).to(device) # (k, 1, enc_image_size, enc_image_size) 1885 | 1886 | # Lists to store completed sequences, their alphas and scores 1887 | complete_seqs = list() 1888 | if return_attention: 1889 | complete_seqs_alpha = list() 1890 | complete_seqs_scores = list() 1891 | 1892 | # Start decoding 1893 | step = 1 1894 | if isinstance(self.decode_step_pred, nn.LSTMCell): 1895 | h_LM, c_LM, h_cell, c_cell = self.init_hidden_state(encoder_out, k) 1896 | else: 1897 | h_LM, h_cell = self.init_hidden_state(encoder_out, k) 1898 | 1899 | cell_tag_H = tag_H[c].expand(k, -1) 1900 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 1901 | while True: 1902 | embeddings = self.embedding(k_prev_words).squeeze(1) # (s, embed_dim) 1903 | 1904 | if isinstance(self.decode_step_LM, nn.LSTMCell): 1905 | h_LM, c_LM = self.decode_step_LM(embeddings, (h_LM, c_LM)) 1906 | else: 1907 | h_LM = self.decode_step_LM(embeddings, h_LM) 1908 | 1909 | if return_attention: 1910 | awe, alpha = self.attention( 1911 | encoder_out, 1912 | cell_tag_H, 1913 | h_LM) 1914 | alpha = alpha.view(-1, enc_image_size, enc_image_size) # (s, enc_image_size, enc_image_size) 1915 | else: 1916 | awe, _ = self.attention( 1917 | encoder_out, 1918 | cell_tag_H, 1919 | h_LM) 1920 | gate = self.sigmoid(self.f_beta(h_cell)) # gating scalar, (s, encoder_dim) 1921 | awe = gate * awe 1922 | 1923 | if isinstance(self.decode_step_pred, nn.LSTMCell): 1924 | h_cell, c_cell = self.decode_step_pred(awe, (h_cell, c_cell)) # (batch_size_t, decoder_dim) 1925 | else: 1926 | h_cell = self.decode_step_pred(awe, 
h_cell) # (batch_size_t, decoder_dim) 1927 | 1928 | h_LM = repackage_hidden(h_LM) 1929 | h_cell = repackage_hidden(h_cell) 1930 | if isinstance(self.decode_step_pred, nn.LSTMCell): 1931 | c_LM = repackage_hidden(c_LM) 1932 | c_cell = repackage_hidden(c_cell) 1933 | 1934 | scores = self.fc(h_cell) # (s, vocab_size) 1935 | scores = F.log_softmax(scores, dim=1) 1936 | 1937 | # Add 1938 | scores = top_k_scores.expand_as(scores) + scores # (s, vocab_size) 1939 | 1940 | # For the first step, all k points will have the same scores (since same k previous words, h, c) 1941 | if step == 1: 1942 | top_k_scores, top_k_words = scores[0].topk(k, 0, True, True) # (s) 1943 | else: 1944 | # Unroll and find top scores, and their unrolled indices 1945 | top_k_scores, top_k_words = scores.view(-1).topk(k, 0, True, True) # (s) 1946 | 1947 | # Convert unrolled indices to actual indices of scores 1948 | prev_word_inds = top_k_words // vocab_size # (s) 1949 | next_word_inds = top_k_words % vocab_size # (s) 1950 | 1951 | # Add new words to sequences, alphas 1952 | seqs = torch.cat([seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 1953 | if return_attention: 1954 | seqs_alpha = torch.cat([seqs_alpha[prev_word_inds], alpha[prev_word_inds].unsqueeze(1)], 1955 | dim=1) # (s, step+1, enc_image_size, enc_image_size) 1956 | 1957 | # Which sequences are incomplete (didn't reach <end>)? 1958 | incomplete_inds = [] 1959 | complete_inds = [] 1960 | for ind, next_word in enumerate(next_word_inds): 1961 | if next_word == word_map['<end>']: 1962 | complete_inds.append(ind) 1963 | else: 1964 | incomplete_inds.append(ind) 1965 | 1966 | # Set aside complete sequences 1967 | if len(complete_inds) > 0: 1968 | complete_seqs.extend(seqs[complete_inds].tolist()) 1969 | if return_attention: 1970 | complete_seqs_alpha.extend(seqs_alpha[complete_inds].tolist()) 1971 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 1972 | k -= len(complete_inds) # reduce beam length accordingly 1973 | 1974 | # Break if all sequences are complete 1975 | if k == 0: 1976 | break 1977 | # Break if things have been going on too long 1978 | if step > max_steps: 1979 | # If no complete sequence is generated, finish the incomplete 1980 | # sequences with <end> 1981 | if not complete_seqs_scores: 1982 | complete_seqs = seqs.tolist() 1983 | for i in range(len(complete_seqs)): 1984 | complete_seqs[i].append(word_map['<end>']) 1985 | if return_attention: 1986 | complete_seqs_alpha = seqs_alpha.tolist() 1987 | complete_seqs_scores = top_k_scores.tolist() 1988 | break 1989 | 1990 | # Proceed with incomplete sequences 1991 | seqs = seqs[incomplete_inds] 1992 | if return_attention: 1993 | seqs_alpha = seqs_alpha[incomplete_inds] 1994 | cell_tag_H = cell_tag_H[prev_word_inds[incomplete_inds]] 1995 | h_LM = h_LM[prev_word_inds[incomplete_inds]] 1996 | h_cell = h_cell[prev_word_inds[incomplete_inds]] 1997 | if isinstance(self.decode_step_pred, nn.LSTMCell): 1998 | c_LM = c_LM[prev_word_inds[incomplete_inds]] 1999 | c_cell = c_cell[prev_word_inds[incomplete_inds]] 2000 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 2001 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) 2002 | 2003 | step += 1 2004 | i = complete_seqs_scores.index(max(complete_seqs_scores)) 2005 | cell_seqs.append(complete_seqs[i]) 2006 | if return_attention: 2007 | cell_alphas.append(complete_seqs_alpha[i]) 2008 | if return_attention: 2009 | return cell_seqs, cell_alphas 2010 | else: 2011 | return cell_seqs 2012 | 2013 | def forward(self, encoder_out, encoded_cells_sorted,
cell_lengths, tag_H): 2014 | """ 2015 | Forward propagation. 2016 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 2017 | :param encoded_cells_sorted: encoded cells, a list of batch_size tensors of dimension (num_cells, max_cell_length) 2018 | :param tag_H: hidden state from TagDeoder, a list of batch_size tensors of dimension (num_cells, TagDecoder's decoder_dim) 2019 | :param cell_lengths: caption lengths, a list of batch_size tensor of dimension (num_cells, 1) 2020 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights 2021 | """ 2022 | if hasattr(self, 'input_filter'): 2023 | encoder_out = self.input_filter(encoder_out.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) 2024 | 2025 | # Flatten image 2026 | batch_size = encoder_out.size(0) 2027 | encoder_dim = encoder_out.size(-1) 2028 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 2029 | num_pixels = encoder_out.size(1) 2030 | 2031 | # Decode cell content 2032 | predictions_cell = [] 2033 | alphas_cell = [] 2034 | decode_lengths_cell = [] 2035 | for i in range(batch_size): 2036 | num_cells = cell_lengths[i].size(0) 2037 | embeddings = self.embedding(encoded_cells_sorted[i]) 2038 | decode_lengths = (cell_lengths[i] - 1).tolist() 2039 | max_decode_lengths = decode_lengths[0] 2040 | predictions = torch.zeros(num_cells, max_decode_lengths, self.vocab_size).to(device) 2041 | alphas = torch.zeros(num_cells, max_decode_lengths, num_pixels).to(device) 2042 | if isinstance(self.decode_step_pred, nn.LSTMCell): 2043 | h_LM, c_LM, h_cell, c_cell = self.init_hidden_state(encoder_out[i].unsqueeze(0), num_cells) 2044 | else: 2045 | h_LM, h_cell = self.init_hidden_state(encoder_out[i].unsqueeze(0), num_cells) 2046 | for t in range(max_decode_lengths): 2047 | batch_size_t = sum([l > t for l in decode_lengths]) 2048 | # Language LSTM 2049 | if isinstance(self.decode_step_LM, nn.LSTMCell): 2050 | h_LM, c_LM = self.decode_step_LM( 2051 | embeddings[:batch_size_t, t, :], 2052 | (h_LM[:batch_size_t], c_LM[:batch_size_t])) # (batch_size_t, decoder_dim) 2053 | else: 2054 | h_LM = self.decode_step_LM( 2055 | embeddings[:batch_size_t, t, :], 2056 | h_LM[:batch_size_t]) # (batch_size_t, decoder_dim) 2057 | 2058 | # Attention 2059 | attention_weighted_encoding, alpha = self.attention( 2060 | encoder_out[i].unsqueeze(0), tag_H[i][:batch_size_t], 2061 | h_LM) 2062 | # Decoder LSTM 2063 | gate = self.sigmoid(self.f_beta(h_cell[:batch_size_t])) # gating scalar, (batch_size_t, encoder_dim) 2064 | attention_weighted_encoding = gate * attention_weighted_encoding 2065 | if isinstance(self.decode_step_pred, nn.LSTMCell): 2066 | h_cell, c_cell = self.decode_step_pred( 2067 | attention_weighted_encoding, 2068 | (h_cell[:batch_size_t], c_cell[:batch_size_t])) # (batch_size_t, decoder_dim) 2069 | else: 2070 | h_cell = self.decode_step_pred( 2071 | attention_weighted_encoding, 2072 | h_cell[:batch_size_t]) # (batch_size_t, decoder_dim) 2073 | predictions[:batch_size_t, t, :] = self.fc(self.dropout(h_cell)) # (batch_size_t, vocab_size) 2074 | alphas[:batch_size_t, t, :] = alpha 2075 | predictions_cell.append(predictions) 2076 | alphas_cell.append(alphas) 2077 | decode_lengths_cell.append(decode_lengths) 2078 | 2079 | return predictions_cell, decode_lengths_cell, alphas_cell 2080 | 2081 | class DualDecoder(nn.Module): 2082 | """ 2083 | Dual decoder model: 2084 | first decoder generates structure of the table 2085 | second decoder generates 
cell content 2086 | """ 2087 | def __init__(self, tag_attention_dim, cell_attention_dim, tag_embed_dim, cell_embed_dim, 2088 | tag_decoder_dim, language_dim, cell_decoder_dim, 2089 | tag_vocab_size, cell_vocab_size, td_encode, 2090 | decoder_cell=nn.LSTMCell, encoder_dim=512, dropout=0.5, 2091 | cell_decoder_type=1, 2092 | cnn_layer_stride=None, tag_H_grad=True, predict_content=True, predict_bbox=False): 2093 | """ 2094 | :param tag_attention_dim: size of attention network for tags 2095 | :param cell_attention_dim: size of attention network for cells 2096 | :param tag_embed_dim: embedding size of tags 2097 | :param cell_embed_dim: embedding size of cell content 2098 | :param tag_decoder_dim: size of tag decoder's RNN 2099 | :param language_dim: size of language model's RNN 2100 | :param cell_decoder_dim: size of cell decoder's RNN 2101 | :param tag_vocab_size: size of tag vocabulary 2102 | :param cell_vocab_size: size of cellvocabulary 2103 | :param td_encode: encodings for ('', ' >') 2104 | :param encoder_dim: feature size of encoded images 2105 | :param dropout: dropout 2106 | :param mini_batch_size: batch size of cells to reduce GPU memory usage 2107 | """ 2108 | super(DualDecoder, self).__init__() 2109 | 2110 | self.tag_attention_dim = tag_attention_dim 2111 | self.cell_attention_dim = cell_attention_dim 2112 | self.tag_embed_dim = tag_embed_dim 2113 | self.cell_embed_dim = cell_embed_dim 2114 | self.tag_decoder_dim = tag_decoder_dim 2115 | self.language_dim = language_dim 2116 | self.cell_decoder_dim = cell_decoder_dim 2117 | self.tag_vocab_size = tag_vocab_size 2118 | self.cell_vocab_size = cell_vocab_size 2119 | self.decoder_cell = decoder_cell 2120 | self.encoder_dim = encoder_dim 2121 | self.dropout = dropout 2122 | self.td_encode = td_encode 2123 | self.tag_H_grad = tag_H_grad 2124 | self.predict_content = predict_content 2125 | self.predict_bbox = predict_bbox 2126 | self.relu_tag = nn.ReLU() 2127 | self.relu_cell = nn.ReLU() 2128 | 2129 | self.tag_decoder = TagDecoder( 2130 | tag_attention_dim, 2131 | tag_embed_dim, 2132 | tag_decoder_dim, 2133 | tag_vocab_size, 2134 | td_encode, 2135 | decoder_cell, 2136 | encoder_dim, 2137 | dropout, 2138 | cnn_layer_stride['tag'] if isinstance(cnn_layer_stride, dict) else None, 2139 | self.tag_H_grad) 2140 | if cell_decoder_type == 1: 2141 | self.cell_decoder = CellDecoder_baseline( 2142 | cell_attention_dim, 2143 | cell_embed_dim, 2144 | tag_decoder_dim, 2145 | cell_decoder_dim, 2146 | cell_vocab_size, 2147 | decoder_cell, 2148 | encoder_dim, 2149 | dropout, 2150 | cnn_layer_stride['cell'] if isinstance(cnn_layer_stride, dict) else None) 2151 | elif cell_decoder_type == 2: 2152 | self.cell_decoder = CellDecoder( 2153 | cell_attention_dim, 2154 | cell_embed_dim, 2155 | tag_decoder_dim, 2156 | language_dim, 2157 | cell_decoder_dim, 2158 | cell_vocab_size, 2159 | decoder_cell, 2160 | encoder_dim, 2161 | dropout, 2162 | cnn_layer_stride['cell'] if isinstance(cnn_layer_stride, dict) else None) 2163 | self.bbox_loss = BBoxLoss() 2164 | self.cell_bbox_regressor = CellBBox(tag_decoder_dim) 2165 | 2166 | if torch.cuda.device_count() > 1: 2167 | self.tag_decoder = MyDataParallel(self.tag_decoder) 2168 | self.cell_decoder = MyDataParallel(self.cell_decoder) 2169 | self.cell_bbox_regressor = MyDataParallel(self.cell_bbox_regressor) 2170 | 2171 | def load_pretrained_tag_decoder(self, tag_decoder): 2172 | self.tag_decoder = tag_decoder 2173 | 2174 | def fine_tune_tag_decoder(self, fine_tune=False): 2175 | for p in self.tag_decoder.parameters(): 2176 | 
p.requires_grad = fine_tune 2177 | 2178 | def inference(self, encoder_out, word_map, 2179 | max_steps={'tag': 400, 'cell': 200}, 2180 | beam_size={'tag': 5, 'cell': 5}, 2181 | return_attention=False): 2182 | """ 2183 | Inference on test images with beam search 2184 | """ 2185 | res = self.tag_decoder.inference( 2186 | encoder_out, 2187 | word_map['word_map_tag'], 2188 | max_steps['tag'], 2189 | beam_size['tag'], 2190 | return_attention=return_attention 2191 | ) 2192 | if res is not None: 2193 | output, tag_H = res[:-1], res[-1] 2194 | if self.predict_content: 2195 | cell_res = self.cell_decoder.inference( 2196 | encoder_out, 2197 | tag_H, 2198 | word_map['word_map_cell'], 2199 | max_steps['cell'], 2200 | beam_size['cell'], 2201 | return_attention=return_attention 2202 | ) 2203 | if return_attention: 2204 | cell_seqs, cell_alphas = cell_res 2205 | output += (cell_seqs, cell_alphas) 2206 | else: 2207 | cell_seqs = cell_res 2208 | output += (cell_seqs,) 2209 | if self.predict_bbox: 2210 | cell_bbox = self.cell_bbox_regressor( 2211 | encoder_out, 2212 | tag_H 2213 | ) 2214 | output += (cell_bbox,) 2215 | return output 2216 | else: 2217 | return None 2218 | 2219 | def forward(self, encoder_out, encoded_tags_sorted, tag_lengths, cells=None, cell_lens=None, num_cells=None): 2220 | """ 2221 | Forward propagation. 2222 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 2223 | :param encoded_tags_sorted: encoded tags, a tensor of dimension (batch_size, max_tag_length) 2224 | :param tag_lengths: caption lengths, a tensor of dimension (batch_size, 1) 2225 | :param encoded_cells: encoded cells, a list of batch_size tensors of dimension (num_cells, max_cell_length) 2226 | :param cell_lengths: caption lengths, a list of batch_size tensor of dimension (num_cells, 1) 2227 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights 2228 | """ 2229 | batch_size = encoder_out.size(0) 2230 | N_GPUS = torch.cuda.device_count() 2231 | if N_GPUS > 1 and N_GPUS != batch_size: 2232 | # WHen multiple GPUs are available, rearrange the samples 2233 | # in the batch so that partition is more balanced. This 2234 | # increases training speed and reduce the chance of 2235 | # GPU memory overflow. 
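# Illustrative sketch of the rearrangement below (toy values, not from the original
# source): the batch is already sorted by decreasing tag length, so a column-major
# flatten deals long and short samples alternately to each GPU chunk. For example,
# with batch_size = 8 and N_GPUS = 2:
#
#   >>> import numpy as np
#   >>> np.arange(8, dtype=int).reshape(-1, 2).flatten('F')
#   array([0, 2, 4, 6, 1, 3, 5, 7])
#
# DataParallel then scatters samples 0, 2, 4, 6 to the first GPU and 1, 3, 5, 7 to
# the second, so both see a similar mix of long and short sequences; restore_inds at
# the end of forward() applies the inverse permutation before the loss is computed.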
2236 | balance_inds = np.arange(np.ceil(batch_size / N_GPUS) * N_GPUS, dtype=int).reshape(-1, N_GPUS).flatten('F')[:batch_size] 2237 | encoder_out = encoder_out[balance_inds] 2238 | encoded_tags_sorted = encoded_tags_sorted[balance_inds] 2239 | tag_lengths = tag_lengths[balance_inds] 2240 | if num_cells is not None: 2241 | num_cells = num_cells[balance_inds] 2242 | if self.predict_content: 2243 | cells = [cells[ind] for ind in balance_inds] 2244 | cell_lens = [cell_lens[ind] for ind in balance_inds] 2245 | 2246 | output = self.tag_decoder( 2247 | encoder_out, 2248 | encoded_tags_sorted, 2249 | tag_lengths, 2250 | num_cells=num_cells if self.predict_content or self.predict_bbox else None, 2251 | max_tag_len=(tag_lengths[0] - 1).item() 2252 | ) 2253 | 2254 | if self.predict_content or self.predict_bbox: 2255 | tag_H = output[-1] 2256 | if self.predict_bbox: 2257 | predictions_cell_bboxes = self.cell_bbox_regressor( 2258 | encoder_out, 2259 | tag_H 2260 | ) 2261 | 2262 | if self.predict_content: 2263 | # Sort cells of each sample by decreasing length 2264 | for j in range(len(cells)): 2265 | cell_lens[j], s_ind = cell_lens[j].sort(dim=0, descending=True) 2266 | cells[j] = cells[j][s_ind] 2267 | tag_H[j] = tag_H[j][s_ind] 2268 | 2269 | predictions_cell, decode_lengths_cell, alphas_cell = self.cell_decoder( 2270 | encoder_out, 2271 | cells, 2272 | cell_lens, 2273 | tag_H 2274 | ) 2275 | 2276 | output = output[:3] 2277 | if self.predict_content: 2278 | output += (predictions_cell, decode_lengths_cell, alphas_cell, cells) 2279 | if self.predict_bbox: 2280 | output += (predictions_cell_bboxes,) 2281 | 2282 | if N_GPUS > 1 and N_GPUS != batch_size: 2283 | # Restore the correct order of samples in the batch to compute 2284 | # the correct loss 2285 | restore_inds = np.arange(np.ceil(batch_size / N_GPUS) * N_GPUS, dtype=int).reshape(N_GPUS, -1).flatten('F')[:batch_size] 2286 | output = tuple([item[ind] for ind in restore_inds] if isinstance(item, list) else item[restore_inds] for item in output) 2287 | return output 2288 | 2289 | def train_epoch(self, train_loader, encoder, criterion, 2290 | encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer, 2291 | epoch, args): 2292 | """ 2293 | Performs one epoch's training. 2294 | 2295 | :param train_loader: DataLoader for training data 2296 | :param encoder: encoder model 2297 | :param criterion: loss layer 2298 | :param encoder_optimizer: optimizer to update encoder's weights (if fine-tuning) 2299 | :param tag_decoder_optimizer: optimizer to update tag decoder's weights 2300 | :param cell_decoder_optimizer: optimizer to update cell decoder's weights 2301 | :param epoch: epoch number 2302 | """ 2303 | self.train() # train mode (dropout and batchnorm is used) 2304 | encoder.train() 2305 | 2306 | batch_time = AverageMeter() # forward prop. + back prop. 
time 2307 | losses_tag = AverageMeter() # loss (per word decoded) 2308 | losses_total = AverageMeter() # total loss (per batch) 2309 | top1accs_tag = AverageMeter() # top1 accuracy 2310 | if self.predict_content: 2311 | losses_cell = AverageMeter() # loss (per word decoded) 2312 | top1accs_cell = AverageMeter() # top1 accuracy 2313 | if self.predict_bbox: 2314 | losses_cell_bbox = AverageMeter() # cell bbox loss 2315 | 2316 | start = time.time() 2317 | # Batches 2318 | train_loader.shuffle() 2319 | for i, batch in enumerate(train_loader): 2320 | try: 2321 | imgs, tags, tag_lens, num_cells = batch[:4] 2322 | # Move to GPU, if available 2323 | imgs = imgs.to(device) 2324 | tags = tags.to(device) 2325 | tag_lens = tag_lens.to(device) 2326 | num_cells = num_cells.to(device) 2327 | if self.predict_content: 2328 | cells, cell_lens = batch[4:6] 2329 | cells = [c.to(device) for c in cells] 2330 | cell_lens = [c.to(device) for c in cell_lens] 2331 | else: 2332 | cells = None 2333 | cell_lens = None 2334 | 2335 | if self.predict_bbox: 2336 | cell_bboxes = batch[-1] 2337 | cell_bboxes = [c.to(device) for c in cell_bboxes] 2338 | 2339 | # Forward prop. 2340 | imgs = encoder(imgs) 2341 | 2342 | # Flatten image 2343 | batch_size = imgs.size(0) 2344 | # encoder_dim = imgs.size(-1) 2345 | # imgs = imgs.view(batch_size, -1, encoder_dim) # (batch_size, num_pixels, encoder_dim) 2346 | tag_lens = tag_lens.squeeze(1) 2347 | 2348 | # Sort input data by decreasing tag lengths 2349 | tag_lens, sort_ind = tag_lens.sort(dim=0, descending=True) 2350 | imgs = imgs[sort_ind] 2351 | tags_sorted = tags[sort_ind] 2352 | num_cells = num_cells[sort_ind] 2353 | if self.predict_content: 2354 | cells = [cells[ind] for ind in sort_ind] 2355 | cell_lens = [cell_lens[ind] for ind in sort_ind] 2356 | if self.predict_bbox: 2357 | cell_bboxes = [cell_bboxes[ind] for ind in sort_ind] 2358 | 2359 | output = self(imgs, tags_sorted, tag_lens, cells, cell_lens, num_cells) 2360 | 2361 | scores_tag, decode_lengths_tag, alphas_tag = output[:3] 2362 | if self.predict_content: 2363 | scores_cell, decode_lengths_cell, alphas_cell, cells = output[3:7] 2364 | if self.predict_bbox: 2365 | predictions_cell_bboxes = output[-1] 2366 | 2367 | # Gather results to the same GPU 2368 | if torch.cuda.device_count() > 1: 2369 | if self.predict_content: 2370 | for s, cell in zip(range(len(scores_cell)), cells): 2371 | if scores_cell[s].get_device() != cell.get_device(): 2372 | scores_cell[s] = scores_cell[s].to(device) 2373 | alphas_cell[s] = alphas_cell[s].to(device) 2374 | if self.predict_bbox: 2375 | for s, cell_bbox in zip(range(len(predictions_cell_bboxes)), cell_bboxes): 2376 | if predictions_cell_bboxes[s].get_device() != cell_bbox.get_device(): 2377 | predictions_cell_bboxes[s] = predictions_cell_bboxes[s].to(device) 2378 | 2379 | # Calculate tag loss 2380 | targets_tag = tags_sorted[:, 1:] 2381 | scores_tag = pack_padded_sequence(scores_tag, decode_lengths_tag, batch_first=True)[0] 2382 | targets_tag = pack_padded_sequence(targets_tag, decode_lengths_tag, batch_first=True)[0] 2383 | loss_tag = criterion['tag'](scores_tag, targets_tag) 2384 | # Add doubly stochastic attention regularization 2385 | # loss_tag += args.alpha_c * ((1. - alphas_tag.sum(dim=1)) ** 2).mean() 2386 | loss_tag += args.alpha_tag * (self.relu_tag(1.
- alphas_tag.sum(dim=1)) ** 2).mean() 2387 | loss = args.tag_loss_weight * loss_tag 2388 | top1_tag = accuracy(scores_tag, targets_tag, 1) 2389 | tag_count = sum(decode_lengths_tag) 2390 | losses_tag.update(loss_tag.item(), tag_count) 2391 | top1accs_tag.update(top1_tag, tag_count) 2392 | 2393 | # Calculate cell loss 2394 | if self.predict_content and args.cell_loss_weight > 0: 2395 | loss_cell = 0. 2396 | reg_alphas_cell = 0 2397 | for scores, gt, decode_lengths, alpha in zip(scores_cell, cells, decode_lengths_cell, alphas_cell): 2398 | targets = gt[:, 1:] 2399 | scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0] 2400 | targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0] 2401 | __loss_cell = criterion['cell'](scores, targets) 2402 | # __loss_cell += args.alpha_c * ((1. - alpha.sum(dim=1)) ** 2).mean() 2403 | reg_alphas_cell += args.alpha_cell * (self.relu_cell(1. - alpha.sum(dim=(0, 1))) ** 2).mean() 2404 | top1_cell = accuracy(scores, targets, 1) 2405 | cell_count = sum(decode_lengths) 2406 | losses_cell.update(__loss_cell.item(), cell_count) 2407 | top1accs_cell.update(top1_cell, cell_count) 2408 | loss_cell += __loss_cell 2409 | loss_cell /= batch_size 2410 | loss_cell += reg_alphas_cell / batch_size 2411 | loss += args.cell_loss_weight * loss_cell 2412 | # Calculate cell bbox loss 2413 | if self.predict_bbox and args.cell_bbox_loss_weight > 0: 2414 | loss_cell_bbox = 0. 2415 | for pred, gt in zip(predictions_cell_bboxes, cell_bboxes): 2416 | __loss_cell_bbox = self.bbox_loss(gt, pred) 2417 | losses_cell_bbox.update(__loss_cell_bbox.item(), pred.size(0)) 2418 | loss_cell_bbox += __loss_cell_bbox 2419 | loss_cell_bbox /= batch_size 2420 | loss += args.cell_bbox_loss_weight * loss_cell_bbox 2421 | 2422 | losses_total.update(loss.item(), 1) 2423 | 2424 | # Back prop. 
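# Schematically (a summary added for clarity, not from the original comments), the scalar
# backpropagated below is the weighted sum
#
#   loss = tag_loss_weight  * [CE(tag scores, tag targets) + alpha_tag  * R(alphas_tag)]
#        + cell_loss_weight * [batch-mean CE over cell sequences + alpha_cell * R(alphas_cell)]   (when predict_content is set)
#        + cell_bbox_loss_weight * bbox loss                                                      (when predict_bbox is set)
#
# where R(.) is the doubly stochastic attention regularizer computed above: it penalizes
# encoder pixels whose attention weights sum to less than 1 over the decoding steps, and
# the ReLU keeps pixels that are attended more than once from being penalized.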
2425 | if encoder_optimizer is not None: 2426 | encoder_optimizer.zero_grad() 2427 | if tag_decoder_optimizer is not None: 2428 | tag_decoder_optimizer.zero_grad() 2429 | if self.predict_content: 2430 | cell_decoder_optimizer.zero_grad() 2431 | if self.predict_bbox: 2432 | cell_bbox_regressor_optimizer.zero_grad() 2433 | loss.backward() 2434 | 2435 | # Clip gradients 2436 | if args.grad_clip is not None: 2437 | if encoder_optimizer is not None: 2438 | clip_gradient(encoder_optimizer, args.grad_clip) 2439 | if tag_decoder_optimizer is not None: 2440 | clip_gradient(tag_decoder_optimizer, args.grad_clip) 2441 | if self.predict_content: 2442 | clip_gradient(cell_decoder_optimizer, args.grad_clip) 2443 | if self.predict_bbox: 2444 | clip_gradient(cell_bbox_regressor_optimizer, args.grad_clip) 2445 | 2446 | # Update weights 2447 | if encoder_optimizer is not None: 2448 | encoder_optimizer.step() 2449 | if tag_decoder_optimizer is not None: 2450 | tag_decoder_optimizer.step() 2451 | if self.predict_content: 2452 | cell_decoder_optimizer.step() 2453 | if self.predict_bbox: 2454 | cell_bbox_regressor_optimizer.step() 2455 | 2456 | batch_time.update(time.time() - start) 2457 | start = time.time() 2458 | 2459 | # Print status 2460 | if i % args.print_freq == 0: 2461 | verbose = 'Epoch: [{0}][{1}/{2}]\t'.format(epoch, i, len(train_loader)) + \ 2462 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'.format(batch_time=batch_time) + \ 2463 | 'Loss_total {loss_total.val:.4f} ({loss_total.avg:.4f})\t'.format(loss_total=losses_total) + \ 2464 | 'Loss_tag {loss_tag.val:.4f} ({loss_tag.avg:.4f})\t'.format(loss_tag=losses_tag) + \ 2465 | 'Acc_tag {top1_tag.val:.3f} ({top1_tag.avg:.3f})\t'.format(top1_tag=top1accs_tag) 2466 | if self.predict_content: 2467 | verbose += 'Loss_cell {loss_cell.val:.4f} ({loss_cell.avg:.4f})\t'.format(loss_cell=losses_cell) + \ 2468 | 'Acc_cell {top1_cell.val:.3f} ({top1_cell.avg:.3f})\t'.format(top1_cell=top1accs_cell) 2469 | if self.predict_bbox: 2470 | verbose += 'Loss_cell_bbox {loss_cell_bbox.val:.4f} ({loss_cell_bbox.avg:.4f})\t'.format(loss_cell_bbox=losses_cell_bbox) 2471 | 2472 | print(verbose, file=sys.stderr) 2473 | sys.stderr.flush() 2474 | 2475 | batch_time.reset() 2476 | losses_total.reset() 2477 | losses_tag.reset() 2478 | top1accs_tag.reset() 2479 | if self.predict_content: 2480 | losses_cell.reset() 2481 | top1accs_cell.reset() 2482 | if self.predict_bbox: 2483 | losses_cell_bbox.reset() 2484 | except Exception as e: 2485 | raise 2486 | -------------------------------------------------------------------------------- /parallel.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from concurrent.futures import ProcessPoolExecutor, as_completed 3 | 4 | def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0): 5 | """ 6 | A parallel version of the map function with a progress bar. 7 | 8 | Args: 9 | array (array-like): An array to iterate over. 10 | function (function): A python function to apply to the elements of array 11 | n_jobs (int, default=16): The number of cores to use 12 | use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of 13 | keyword arguments to function 14 | front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job. 15 | Useful for catching bugs 16 | Returns: 17 | [function(array[0]), function(array[1]), ...] 
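    Example:
        A minimal sketch of the serial path: with n_jobs=1 the function reduces to a
        plain list comprehension, so it also works with lambdas that cannot be pickled
        for the process pool.

        >>> parallel_process([1, 2, 3], lambda x: x * x, n_jobs=1)
        [1, 4, 9]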
18 | """ 19 | # We run the first few iterations serially to catch bugs 20 | if front_num > 0: 21 | front = [function(**a) if use_kwargs else function(a) for a in array[:front_num]] 22 | else: 23 | front = [] 24 | # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging. 25 | if n_jobs == 1: 26 | return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:])] 27 | # Assemble the workers 28 | with ProcessPoolExecutor(max_workers=n_jobs) as pool: 29 | # Pass the elements of array into function 30 | if use_kwargs: 31 | futures = [pool.submit(function, **a) for a in array[front_num:]] 32 | else: 33 | futures = [pool.submit(function, a) for a in array[front_num:]] 34 | kwargs = { 35 | 'total': len(futures), 36 | 'unit': 'it', 37 | 'unit_scale': True, 38 | 'leave': True 39 | } 40 | # Print out the progress as tasks complete 41 | for f in tqdm(as_completed(futures), **kwargs): 42 | pass 43 | out = [] 44 | # Get the results from the futures. 45 | for i, future in tqdm(enumerate(futures)): 46 | try: 47 | out.append(future.result()) 48 | except Exception as e: 49 | out.append(e) 50 | return front + out 51 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Prepares data files for training and validating EDD model. The following files 3 | are generated in : 4 | - TRAIN_IMAGES_.h5 # Training images 5 | - TRAIN_TAGS_.json # Training structural tokens 6 | - TRAIN_TAGLENS_.json # Length of training structural tokens 7 | - TRAIN_CELLS_.json # Training cell tokens 8 | - TRAIN_CELLLENS_.json # Length of training cell tokens 9 | - TRAIN_CELLBBOXES_.json # Training cell bboxes 10 | - VAL.json # Validation ground truth 11 | - WORDMAP_.json # Vocab 12 | 13 | is formatted according to input args (keep_AR, max_tag_len, ...) 
14 | ''' 15 | import json 16 | import jsonlines 17 | from tqdm import tqdm 18 | import argparse 19 | from collections import Counter 20 | import os 21 | from PIL import Image 22 | import h5py 23 | import numpy as np 24 | from utils import image_rescale 25 | from html import escape 26 | from lxml import html 27 | 28 | def is_valid(img): 29 | if len(img['html']['structure']['tokens']) > args.max_tag_len: 30 | return False 31 | for cell in img['html']['cells']: 32 | if len(cell['tokens']) > args.max_cell_len: 33 | return False 34 | with Image.open(os.path.join(args.image_dir, img['split'], img['filename'])) as im: 35 | if im.width > args.max_image_size or im.height > args.max_image_size: 36 | return False 37 | return True 38 | 39 | def scale(bbox, orig_size): 40 | ''' Normalizes bbox to 0 - 1 41 | ''' 42 | if bbox[0] == 0: 43 | return bbox 44 | else: 45 | x = float((bbox[3] + bbox[1]) / 2) / orig_size[0] # x center 46 | y = float((bbox[4] + bbox[2]) / 2) / orig_size[1] # y center 47 | width = float(bbox[3] - bbox[1]) / orig_size[0] 48 | height = float(bbox[4] - bbox[2]) / orig_size[1] 49 | return [1, x, y, width, height] 50 | 51 | def format_html(img): 52 | ''' Formats HTML code from tokenized annotation of img 53 | ''' 54 | tag_len = len(img['html']['structure']['tokens']) 55 | cell_len_max = max([len(c['tokens']) for c in img['html']['cells']]) 56 | HTML = img['html']['structure']['tokens'].copy() 57 | to_insert = [i for i, tag in enumerate(HTML) if tag in ('', '>')] 58 | for i, cell in zip(to_insert[::-1], img['html']['cells'][::-1]): 59 | if cell: 60 | cell = ''.join([escape(token) if len(token) == 1 else token for token in cell['tokens']]) 61 | HTML.insert(i + 1, cell) 62 | HTML = '%s
' % ''.join(HTML) 63 | root = html.fromstring(HTML) 64 | for td, cell in zip(root.iter('td'), img['html']['cells']): 65 | if 'bbox' in cell: 66 | bbox = cell['bbox'] 67 | td.attrib['x'] = str(bbox[0]) 68 | td.attrib['y'] = str(bbox[1]) 69 | td.attrib['width'] = str(bbox[2] - bbox[0]) 70 | td.attrib['height'] = str(bbox[3] - bbox[1]) 71 | HTML = html.tostring(root, encoding='utf-8').decode() 72 | return HTML, tag_len, cell_len_max 73 | 74 | 75 | parser = argparse.ArgumentParser(description='Prepares data files for training EDD') 76 | parser.add_argument('--annotation', type=str, help='path to annotation file') 77 | parser.add_argument('--image_dir', type=str, help='path to image folder') 78 | parser.add_argument('--out_dir', type=str, help='path to folder to save data files') 79 | parser.add_argument('--min_word_freq', default=5, type=int, help='minimium frequency for a token to be included in vocab') 80 | parser.add_argument('--max_tag_len', default=300, type=int, help='maximium number of structural tokens for a sample to be included') 81 | parser.add_argument('--max_cell_len', default=100, type=int, help='maximium number tokens in a cell for a sample to be included') 82 | parser.add_argument('--max_image_size', default=512, type=int, help='maximium image width/height a sample to be included') 83 | parser.add_argument('--image_size', default=448, type=int, help='target image rescaling size') 84 | parser.add_argument('--keep_AR', default=False, action='store_true', help='keep aspect ratio and pad with zeros when rescaling images') 85 | 86 | args = parser.parse_args() 87 | 88 | # Read image paths and captions for each image 89 | dataset = 'PubTabNet' 90 | train_image_paths = [] 91 | train_image_tags = [] 92 | train_image_cells = [] 93 | train_image_cell_bboxes = [] 94 | val_gt = dict() 95 | word_freq_tag = Counter() 96 | 97 | word_freq_cell = Counter() 98 | with jsonlines.open(args.annotation, 'r') as reader: 99 | for img in tqdm(reader): 100 | if img['split'] == 'train': 101 | if is_valid(img): 102 | tags = [] 103 | cells = [] 104 | cell_bboxes = [] 105 | word_freq_tag.update(img['html']['structure']['tokens']) 106 | tags.append(img['html']['structure']['tokens']) 107 | for cell in img['html']['cells']: 108 | word_freq_cell.update(cell['tokens']) 109 | cells.append(cell['tokens']) 110 | if 'bbox' in cell: 111 | cell_bboxes.append([1] + cell['bbox']) 112 | else: 113 | cell_bboxes.append([0, 0, 0, 0, 0]) 114 | 115 | path = os.path.join(args.image_dir, img['split'], img['filename']) 116 | 117 | train_image_paths.append(path) 118 | train_image_tags.append(tags) 119 | train_image_cells.append(cells) 120 | train_image_cell_bboxes.append(cell_bboxes) 121 | elif img['split'] == 'val': 122 | HTML, tag_len, cell_len_max = format_html(img) 123 | with Image.open(os.path.join(args.image_dir, img['split'], img['filename'])) as im: 124 | val_gt[img['filename']] = { 125 | 'html': HTML, 126 | 'tag_len': tag_len, 127 | 'cell_len_max': cell_len_max, 128 | 'width': im.width, 129 | 'height': im.height, 130 | 'type': 'complex' if '>' in img['html']['structure']['tokens'] else 'simple' 131 | } 132 | 133 | 134 | if not os.path.exists(args.out_dir): 135 | os.makedirs(args.out_dir) 136 | 137 | # Save ground truth html of validation set 138 | with open(os.path.join(args.out_dir, 'VAL.json'), 'w') as j: 139 | json.dump(val_gt, j) 140 | 141 | # Sanity check 142 | assert len(train_image_paths) == len(train_image_tags) 143 | 144 | # Create a base/root name for all output files 145 | base_filename = dataset + '_' + \ 146 
| str(args.keep_AR) + '_keep_AR_' + \ 147 | str(args.max_tag_len) + '_max_tag_len_' + \ 148 | str(args.max_cell_len) + '_max_cell_len_' + \ 149 | str(args.max_image_size) + '_max_image_size' 150 | 151 | words_tag = [w for w in word_freq_tag.keys() if word_freq_tag[w] >= args.min_word_freq] 152 | words_cell = [w for w in word_freq_cell.keys() if word_freq_cell[w] >= args.min_word_freq] 153 | 154 | word_map_tag = {k: v + 1 for v, k in enumerate(words_tag)} 155 | word_map_tag['<unk>'] = len(word_map_tag) + 1 156 | word_map_tag['<start>'] = len(word_map_tag) + 1 157 | word_map_tag['<end>'] = len(word_map_tag) + 1 158 | word_map_tag['<pad>'] = 0 159 | 160 | word_map_cell = {k: v + 1 for v, k in enumerate(words_cell)} 161 | word_map_cell['<unk>'] = len(word_map_cell) + 1 162 | word_map_cell['<start>'] = len(word_map_cell) + 1 163 | word_map_cell['<end>'] = len(word_map_cell) + 1 164 | word_map_cell['<pad>'] = 0 165 | 166 | # Save word map to a JSON 167 | with open(os.path.join(args.out_dir, 'WORDMAP_' + base_filename + '.json'), 'w') as j: 168 | json.dump({"word_map_tag": word_map_tag, "word_map_cell": word_map_cell}, j) 169 | 170 | with h5py.File(os.path.join(args.out_dir, 'TRAIN_IMAGES_' + base_filename + '.hdf5'), 'a') as h: 171 | dataset_name = 'images' 172 | 173 | # Check if the dataset already exists and delete it if it does 174 | if dataset_name in h: 175 | del h[dataset_name] 176 | 177 | # Create dataset inside HDF5 file to store images 178 | images = h.create_dataset(dataset_name, (len(train_image_paths), 3, args.image_size, args.image_size), dtype='uint8') 179 | 180 | enc_tags = [] 181 | tag_lens = [] 182 | enc_cells = [] 183 | cell_lens = [] 184 | cell_bboxes = [] 185 | 186 | for i, path in enumerate(tqdm(train_image_paths)): 187 | # Read images 188 | img, orig_size = image_rescale(train_image_paths[i], args.image_size, args.keep_AR, return_size=True) 189 | assert img.shape == (3, args.image_size, args.image_size) 190 | assert np.max(img) <= 255 191 | 192 | # Save image to HDF5 file 193 | images[i] = img 194 | 195 | for tag in train_image_tags[i]: 196 | # Encode captions 197 | enc_tag = [word_map_tag['<start>']] + [word_map_tag.get(word, word_map_tag['<unk>']) for word in tag] + \ 198 | [word_map_tag['<end>']] + [word_map_tag['<pad>']] * (args.max_tag_len - len(tag)) 199 | # Find caption lengths 200 | tag_len = len(tag) + 2 201 | 202 | enc_tags.append(enc_tag) 203 | tag_lens.append(tag_len) 204 | 205 | __enc_cell = [] 206 | __cell_len = [] 207 | for cell in train_image_cells[i]: 208 | # Encode captions 209 | enc_cell = [word_map_cell['<start>']] + [word_map_cell.get(word, word_map_cell['<unk>']) for word in cell] + \ 210 | [word_map_cell['<end>']] + [word_map_cell['<pad>']] * (args.max_cell_len - len(cell)) 211 | # Find caption lengths 212 | cell_len = len(cell) + 2 213 | 214 | __enc_cell.append(enc_cell) 215 | __cell_len.append(cell_len) 216 | enc_cells.append(__enc_cell) 217 | cell_lens.append(__cell_len) 218 | 219 | __cell_bbox = [] 220 | for bbox in train_image_cell_bboxes[i]: 221 | __cell_bbox.append(scale(bbox, orig_size)) 222 | cell_bboxes.append(__cell_bbox) 223 | 224 | # Save encoded captions and their lengths to JSON files 225 | with open(os.path.join(args.out_dir, 'TRAIN_TAGS_' + base_filename + '.json'), 'w') as j: 226 | json.dump(enc_tags, j) 227 | 228 | with open(os.path.join(args.out_dir, 'TRAIN_TAGLENS_' + base_filename + '.json'), 'w') as j: 229 | json.dump(tag_lens, j) 230 | 231 | with open(os.path.join(args.out_dir, 'TRAIN_CELLS_' + base_filename + '.json'), 'w') as j: 232 | json.dump(enc_cells, j) 233 | 234 | with open(os.path.join(args.out_dir,
'TRAIN_CELLLENS_' + base_filename + '.json'), 'w') as j: 235 | json.dump(cell_lens, j) 236 | 237 | with open(os.path.join(args.out_dir, 'TRAIN_CELLBBOXES_' + base_filename + '.json'), 'w') as j: 238 | json.dump(cell_bboxes, j) 239 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | Distance>=0.1.3 2 | apted>=1.0.3 3 | torch>=1.0 4 | torchvision>=0.4.2 5 | 6 | jsonlines 7 | tqdm 8 | Pillow 9 | h5py 10 | lxml 11 | -------------------------------------------------------------------------------- /train_dual_decoder.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Trains encoder-dual-decoder model 3 | ''' 4 | import torch.backends.cudnn as cudnn 5 | import torch.optim 6 | import torch.utils.data 7 | import torchvision.transforms as transforms 8 | from torch import nn 9 | from models import Encoder, DualDecoder 10 | from datasets import * 11 | from utils import * 12 | import argparse 13 | import sys 14 | from glob import glob 15 | import time 16 | 17 | def create_model(): 18 | encoder = Encoder(args.encoded_image_size, 19 | use_RNN=args.use_RNN, 20 | rnn_size=args.encoder_RNN_size, 21 | last_layer_stride=args.cnn_stride if isinstance(args.cnn_stride, int) else None) 22 | encoder.fine_tune(args.fine_tune_encoder) 23 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()), 24 | lr=args.encoder_lr) if args.fine_tune_encoder else None 25 | 26 | decoder = DualDecoder(tag_attention_dim=args.tag_attention_dim, 27 | cell_attention_dim=args.cell_attention_dim, 28 | tag_embed_dim=args.tag_embed_dim, 29 | cell_embed_dim=args.cell_embed_dim, 30 | tag_decoder_dim=args.tag_decoder_dim, 31 | language_dim=args.language_dim, 32 | cell_decoder_dim=args.cell_decoder_dim, 33 | tag_vocab_size=len(word_map['word_map_tag']), 34 | cell_vocab_size=len(word_map['word_map_cell']), 35 | td_encode=(word_map['word_map_tag'][''], word_map['word_map_tag']['>']), 36 | decoder_cell=nn.LSTMCell if args.decoder_cell == 'LSTM' else nn.GRUCell, 37 | encoder_dim=512, 38 | dropout=args.dropout, 39 | cell_decoder_type=args.cell_decoder_type, 40 | cnn_layer_stride=args.cnn_stride if isinstance(args.cnn_stride, dict) else None, 41 | tag_H_grad=not args.detach, 42 | predict_content=args.predict_content, 43 | predict_bbox=args.predict_bbox) 44 | decoder.fine_tune_tag_decoder(args.fine_tune_tag_decoder) 45 | tag_decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.tag_decoder.parameters()), 46 | lr=args.tag_decoder_lr) if args.fine_tune_tag_decoder else None 47 | 48 | cell_decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.cell_decoder.parameters()), 49 | lr=args.cell_decoder_lr) if args.predict_content else None 50 | cell_bbox_regressor_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.cell_bbox_regressor.parameters()), 51 | lr=args.cell_bbox_regressor_lr) if args.predict_bbox else None 52 | return encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer 53 | 54 | def load_checkpoint(checkpoint): 55 | # Wait until model file exists 56 | if not os.path.isfile(checkpoint): 57 | while not os.path.isfile(checkpoint): 58 | print('Model not found, retry in 10 minutes', file=sys.stderr) 59 | sys.stderr.flush() 60 | time.sleep(600) 61 | # Make sure model file is saved 
completely 62 | time.sleep(10) 63 | 64 | checkpoint = torch.load(checkpoint) 65 | start_epoch = checkpoint['epoch'] + 1 66 | 67 | encoder = checkpoint['encoder'] 68 | encoder_optimizer = checkpoint['encoder_optimizer'] 69 | encoder.fine_tune(args.fine_tune_encoder) 70 | if args.fine_tune_encoder: 71 | if encoder_optimizer is None: 72 | encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()), 73 | lr=args.encoder_lr) 74 | elif encoder_optimizer.param_groups[0]['lr'] != args.encoder_lr: 75 | change_learning_rate(encoder_optimizer, args.encoder_lr) 76 | print('Encoder LR changed to %f' % args.encoder_lr, file=sys.stderr) 77 | sys.stderr.flush() 78 | 79 | decoder = checkpoint['decoder'] 80 | decoder.tag_H_grad = not args.detach 81 | decoder.tag_decoder.tag_H_grad = not args.detach 82 | decoder.predict_content = args.predict_content 83 | decoder.predict_bbox = args.predict_bbox 84 | 85 | tag_decoder_optimizer = checkpoint['tag_decoder_optimizer'] 86 | decoder.fine_tune_tag_decoder(args.fine_tune_tag_decoder) 87 | if args.fine_tune_tag_decoder: 88 | if tag_decoder_optimizer is None: 89 | tag_decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.tag_decoder.parameters()), 90 | lr=args.tag_decoder_lr) 91 | elif tag_decoder_optimizer.param_groups[0]['lr'] != args.tag_decoder_lr: 92 | change_learning_rate(tag_decoder_optimizer, args.tag_decoder_lr) 93 | print('Tag Decoder LR changed to %f' % args.tag_decoder_lr, file=sys.stderr) 94 | sys.stderr.flush() 95 | 96 | cell_decoder_optimizer = checkpoint['cell_decoder_optimizer'] 97 | if args.predict_content: 98 | if cell_decoder_optimizer is None: 99 | cell_decoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.cell_decoder.parameters()), 100 | lr=args.cell_decoder_lr) 101 | elif cell_decoder_optimizer.param_groups[0]['lr'] != args.cell_decoder_lr: 102 | change_learning_rate(cell_decoder_optimizer, args.cell_decoder_lr) 103 | print('Cell Decoder LR changed to %f' % args.cell_decoder_lr, file=sys.stderr) 104 | sys.stderr.flush() 105 | 106 | cell_bbox_regressor_optimizer = checkpoint['cell_bbox_regressor_optimizer'] 107 | if args.predict_bbox: 108 | if cell_bbox_regressor_optimizer is None: 109 | cell_bbox_regressor_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, decoder.cell_bbox_regressor.parameters()), 110 | lr=args.cell_bbox_regressor_lr) 111 | elif cell_bbox_regressor_optimizer.param_groups[0]['lr'] != args.cell_bbox_regressor_lr: 112 | change_learning_rate(cell_bbox_regressor_optimizer, args.cell_bbox_regressor_lr) 113 | print('Cell bbox regressor LR changed to %f' % args.cell_bbox_regressor_lr, file=sys.stderr) 114 | sys.stderr.flush() 115 | 116 | return start_epoch, encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer 117 | 118 | 119 | if __name__ == '__main__': 120 | parser = argparse.ArgumentParser(description='Train encoder-dual-decoder table2html model') 121 | parser.add_argument('--cnn_stride', default=2, type=json.loads, help='stride for last CNN layer in encoder') 122 | parser.add_argument('--tag_embed_dim', default=16, type=int, help='embedding dimension') 123 | parser.add_argument('--cell_embed_dim', default=16, type=int, help='embedding dimension') 124 | parser.add_argument('--encoded_image_size', default=14, type=int, help='encoded image size') 125 | parser.add_argument('--tag_attention_dim', default=512, type=int, help='tag attention dimension') 126 | 
parser.add_argument('--cell_attention_dim', default=512, type=int, help='tag attention dimension') 127 | parser.add_argument('--language_dim', default=512, type=int, help='language model dimension') 128 | parser.add_argument('--tag_decoder_dim', default=512, type=int, help='tag decoder dimension') 129 | parser.add_argument('--cell_decoder_dim', default=512, type=int, help='cell decoder dimension') 130 | parser.add_argument('--dropout', default=0.5, type=float, help='dropout') 131 | parser.add_argument('--epochs', default=10, type=int, help='epochs to train') 132 | parser.add_argument('--batch_size', default=16, type=int, help='batch size') 133 | parser.add_argument('--encoder_lr', default=0.001, type=float, help='encoder learning rate') 134 | parser.add_argument('--tag_decoder_lr', default=0.001, type=float, help='tag decoder learning rate') 135 | parser.add_argument('--cell_decoder_lr', default=0.001, type=float, help='cell decoder learning rate') 136 | parser.add_argument('--cell_bbox_regressor_lr', default=0.001, type=float, help='cell bbox regressor learning rate') 137 | parser.add_argument('--grad_clip', default=5., type=float, help='clip gradients at an absolute value') 138 | parser.add_argument('--alpha_tag', default=0., type=float, help='regularization parameter in tag decoder for doubly stochastic attention') 139 | parser.add_argument('--alpha_cell', default=0., type=float, help='regularization parameter in cell decoder for doubly stochastic attention') 140 | parser.add_argument('--tag_loss_weight', default=0.5, type=float, help='weight of tag loss') 141 | parser.add_argument('--cell_loss_weight', default=0.5, type=float, help='weight of cell content loss') 142 | parser.add_argument('--cell_bbox_loss_weight', default=0.0, type=float, help='weight of cell bbox loss') 143 | parser.add_argument('--print_freq', default=100, type=int, help='verbose frequency') 144 | parser.add_argument('--fine_tune_encoder', dest='fine_tune_encoder', action='store_true', help='fine-tune encoder') 145 | parser.add_argument('--fine_tune_tag_decoder', dest='fine_tune_tag_decoder', action='store_true', help='fine-tune tag decoder') 146 | parser.add_argument('--cell_decoder_type', default=1, type=int, help='Type of cell decoder (1: baseline, 2: with LM)') 147 | parser.add_argument('--decoder_cell', default='LSTM', type=str, help='RNN Cell (LSTM or GRU)') 148 | parser.add_argument('--use_RNN', dest='use_RNN', action='store_true', help='transform image features with LSTM') 149 | parser.add_argument('--detach', dest='detach', default=False, action='store_true', help='detach the hidden state between structure and cell decoders') 150 | parser.add_argument('--encoder_RNN_size', default=512, type=int, help='LSTM size for the encoder') 151 | parser.add_argument('--checkpoint', default=None, type=str, help='path to checkpoint') 152 | parser.add_argument('--data_folder', default='data/pubmed_dual', type=str, help='path to folder with data files saved by create_input_files.py') 153 | parser.add_argument('--data_name', default='pubmed_False_keep_AR_300_max_tag_len_100_max_cell_len_512_max_image_size', type=str, help='base name shared by data files') 154 | parser.add_argument('--out_dir', type=str, help='path to save checkpoints') 155 | parser.add_argument('--resume', dest='resume', action='store_true', help='Resume from latest checkpoint if exists') 156 | parser.add_argument('--predict_content', dest='predict_content', default=False, action='store_true', help='Predict cell content') 157 | 
parser.add_argument('--predict_bbox', dest='predict_bbox', default=False, action='store_true', help='Predict cell bbox') 158 | 159 | args = parser.parse_args() 160 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # sets device for model and PyTorch tensors 161 | cudnn.benchmark = True # set to True only if inputs to the model are fixed size; otherwise there is a lot of computational overhead 162 | 163 | # Read word map 164 | word_map_file = os.path.join(args.data_folder, 'WORDMAP_' + args.data_name + '.json') 165 | with open(word_map_file, 'r') as j: 166 | word_map = json.load(j) 167 | 168 | # Initialize / load checkpoint 169 | if args.resume: 170 | existing_ckps = glob(os.path.join(args.out_dir, args.data_name, 'checkpoint_*.pth.tar')) 171 | existing_ckps = [ckp for ckp in existing_ckps if len(os.path.basename(ckp).split('_')) == 2] 172 | if existing_ckps: 173 | existing_ckps = sorted(existing_ckps, key=lambda x: int(os.path.basename(x).split('.')[0].split('_')[1])) 174 | latest_ckp = existing_ckps[-1] 175 | if args.checkpoint is not None: 176 | latest_epoch = int(os.path.basename(latest_ckp).split('.')[0].split('_')[1]) 177 | checkpoint_epoch = int(os.path.basename(args.checkpoint).split('.')[0].split('_')[1]) 178 | if latest_epoch > checkpoint_epoch: 179 | print('Resume from latest checkpoint: %s' % latest_ckp, file=sys.stderr) 180 | start_epoch, encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = load_checkpoint(latest_ckp) 181 | else: 182 | print('Start from checkpoint: %s' % args.checkpoint, file=sys.stderr) 183 | start_epoch, encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = load_checkpoint(args.checkpoint) 184 | else: 185 | print('Resume from latest checkpoint: %s' % latest_ckp, file=sys.stderr) 186 | start_epoch, encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = load_checkpoint(latest_ckp) 187 | elif args.checkpoint is not None: 188 | start_epoch, encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = load_checkpoint(args.checkpoint) 189 | else: 190 | encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = create_model() 191 | start_epoch = 0 192 | else: 193 | if args.checkpoint is not None: 194 | start_epoch, encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = load_checkpoint(args.checkpoint) 195 | else: 196 | encoder, decoder, encoder_optimizer, tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer = create_model() 197 | start_epoch = 0 198 | 199 | # Move to GPU, if available 200 | if torch.cuda.device_count() > 1: 201 | print('Using %d GPUs' % torch.cuda.device_count(), file=sys.stderr) 202 | if not hasattr(encoder, 'module'): 203 | print('Parallelize encoder', file=sys.stderr) 204 | encoder = MyDataParallel(encoder) 205 | if not hasattr(decoder.tag_decoder, 'module'): 206 | print('Parallelize tag decoder', file=sys.stderr) 207 | decoder.tag_decoder = MyDataParallel(decoder.tag_decoder) 208 | if not hasattr(decoder.cell_decoder, 'module'): 209 | print('Parallelize cell decoder', file=sys.stderr) 210 | decoder.cell_decoder = MyDataParallel(decoder.cell_decoder) 211 | decoder = decoder.to(device) 212 | encoder = encoder.to(device) 213 | 214 | # Loss function 215 | criterion = {'tag': nn.CrossEntropyLoss().to(device), 216 | 'cell': nn.CrossEntropyLoss().to(device)} 217 | 218 | # mean and std of PubMed Central table images 219 | normalize = transforms.Normalize(mean=[0.94247851, 0.94254675, 0.94292611], 220 | std=[0.17910956, 0.17940403, 0.17931663]) 221 | mode = 'tag' 222 | if args.predict_content: 223 | mode += '+cell' 224 | if args.predict_bbox: 225 | mode += '+bbox' 226 | train_loader = TagCellDataset(args.data_folder, args.data_name, 'TRAIN', 227 | batch_size=args.batch_size, mode=mode, 228 | transform=transforms.Compose([normalize])) 229 | 230 | # Epochs 231 | for epoch in range(start_epoch, args.epochs): 232 | # One epoch's training 233 | decoder.train_epoch(train_loader=train_loader, 234 | encoder=encoder, 235 | criterion=criterion, 236 | encoder_optimizer=encoder_optimizer, 237 | tag_decoder_optimizer=tag_decoder_optimizer, 238 | cell_decoder_optimizer=cell_decoder_optimizer, 239 | cell_bbox_regressor_optimizer=cell_bbox_regressor_optimizer, 240 | epoch=epoch, 241 | args=args) 242 | 243 | # Save checkpoint 244 | save_checkpoint_dual(args.out_dir, args.data_name, epoch, encoder, decoder, encoder_optimizer, 245 | tag_decoder_optimizer, cell_decoder_optimizer, cell_bbox_regressor_optimizer) 246 |
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn.parallel._functions import Scatter, Gather 6 | from PIL import Image, ImageOps 7 | from math import ceil 8 | 9 | def scatter(inputs, target_gpus, dim=0): 10 | r""" 11 | Slices tensors into approximately equal chunks and 12 | distributes them across given GPUs. Duplicates 13 | references to objects that are not tensors. 14 | """ 15 | def scatter_map(obj): 16 | if isinstance(obj, torch.Tensor): 17 | return Scatter.apply(target_gpus, None, dim, obj) 18 | if isinstance(obj, tuple) and len(obj) > 0: 19 | return list(zip(*map(scatter_map, obj))) 20 | if isinstance(obj, list) and len(obj) > 0: 21 | per_gpu = ceil(len(obj) / len(target_gpus)) 22 | partition = [obj[k * per_gpu: min(len(obj), (k + 1) * per_gpu)] for k, _ in enumerate(target_gpus)] 23 | for i, target in zip(range(len(partition)), target_gpus): 24 | for j in range(len(partition[i])): 25 | partition[i][j] = partition[i][j].to(torch.device('cuda:%d' % target)) 26 | return partition 27 | # return list(map(list, zip(*map(scatter_map, obj)))) 28 | if isinstance(obj, dict) and len(obj) > 0: 29 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 30 | return [obj for targets in target_gpus] 31 | 32 | # After scatter_map is called, a scatter_map cell will exist. This cell 33 | # has a reference to the actual function scatter_map, which has references 34 | # to a closure that has a reference to the scatter_map cell (because the 35 | # fn is recursive).
To avoid this reference cycle, we set the function to 36 | # None, clearing the cell 37 | try: 38 | return scatter_map(inputs) 39 | finally: 40 | scatter_map = None 41 | 42 | def scatter_kwargs(inputs, kwargs, target_gpus, dim=0): 43 | r"""Scatter with support for kwargs dictionary""" 44 | inputs = scatter(inputs, target_gpus, dim) if inputs else [] 45 | kwargs = scatter(kwargs, target_gpus, dim) if kwargs else [] 46 | if len(inputs) < len(kwargs): 47 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 48 | elif len(kwargs) < len(inputs): 49 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 50 | inputs = tuple(inputs) 51 | kwargs = tuple(kwargs) 52 | return inputs, kwargs 53 | 54 | def gather(outputs, target_device, dim=0): 55 | r""" 56 | Gathers tensors from different GPUs on a specified device 57 | (-1 means the CPU). 58 | """ 59 | def gather_map(outputs): 60 | out = outputs[0] 61 | if isinstance(out, torch.Tensor): 62 | return Gather.apply(target_device, dim, *outputs) 63 | if out is None: 64 | return None 65 | if isinstance(out, dict): 66 | if not all((len(out) == len(d) for d in outputs)): 67 | raise ValueError('All dicts must have the same number of keys') 68 | return type(out)(((k, gather_map([d[k] for d in outputs])) 69 | for k in out)) 70 | if isinstance(out, list): 71 | return [item for output in outputs for item in output] 72 | return type(out)(map(gather_map, zip(*outputs))) 73 | 74 | # Recursive function calls like this create reference cycles. 75 | # Setting the function to None clears the refcycle. 76 | try: 77 | return gather_map(outputs) 78 | finally: 79 | gather_map = None 80 | 81 | class MyDataParallel(nn.DataParallel): 82 | def __init__(self, model): 83 | super(MyDataParallel, self).__init__(model) 84 | 85 | def __getattr__(self, name): 86 | try: 87 | return super(MyDataParallel, self).__getattr__(name) 88 | except AttributeError: 89 | return getattr(self.module, name) 90 | 91 | def scatter(self, inputs, kwargs, device_ids): 92 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 93 | 94 | def gather(self, outputs, output_device): 95 | return gather(outputs, output_device, dim=self.dim) 96 | 97 | def init_embedding(embeddings): 98 | """ 99 | Fills embedding tensor with values from the uniform distribution. 100 | 101 | :param embeddings: embedding tensor 102 | """ 103 | bias = np.sqrt(3.0 / embeddings.size(1)) 104 | torch.nn.init.uniform_(embeddings, -bias, bias) 105 | 106 | 107 | def load_embeddings(emb_file, word_map): 108 | """ 109 | Creates an embedding tensor for the specified word map, for loading into the model. 
110 | 111 | :param emb_file: file containing embeddings (stored in GloVe format) 112 | :param word_map: word map 113 | :return: embeddings in the same order as the words in the word map, dimension of embeddings 114 | """ 115 | 116 | # Find embedding dimension 117 | with open(emb_file, 'r') as f: 118 | emb_dim = len(f.readline().split(' ')) - 1 119 | 120 | vocab = set(word_map.keys()) 121 | 122 | # Create tensor to hold embeddings, initialize 123 | embeddings = torch.FloatTensor(len(vocab), emb_dim) 124 | init_embedding(embeddings) 125 | 126 | # Read embedding file 127 | print("\nLoading embeddings...") 128 | for line in open(emb_file, 'r'): 129 | line = line.split(' ') 130 | 131 | emb_word = line[0] 132 | embedding = list(map(lambda t: float(t), filter(lambda n: n and not n.isspace(), line[1:]))) 133 | 134 | # Ignore word if not in train_vocab 135 | if emb_word not in vocab: 136 | continue 137 | 138 | embeddings[word_map[emb_word]] = torch.FloatTensor(embedding) 139 | 140 | return embeddings, emb_dim 141 | 142 | 143 | def clip_gradient(optimizer, grad_clip): 144 | """ 145 | Clips gradients computed during backpropagation to avoid explosion of gradients. 146 | 147 | :param optimizer: optimizer with the gradients to be clipped 148 | :param grad_clip: clip value 149 | """ 150 | for group in optimizer.param_groups: 151 | for param in group['params']: 152 | if param.grad is not None: 153 | param.grad.data.clamp_(-grad_clip, grad_clip) 154 | 155 | 156 | def save_checkpoint(out_dir, data_name, epoch, encoder, decoder, encoder_optimizer, decoder_optimizer): 157 | """ 158 | Saves model checkpoint. 159 | :param out_dir: output dir 160 | :param data_name: base name of processed dataset 161 | :param epoch: epoch number 162 | :param encoder: encoder model 163 | :param decoder: decoder model 164 | :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning 165 | :param decoder_optimizer: optimizer to update decoder's weights 166 | """ 167 | state = {'epoch': epoch, 168 | 'encoder': encoder, 169 | 'decoder': decoder, 170 | 'encoder_optimizer': encoder_optimizer, 171 | 'decoder_optimizer': decoder_optimizer} 172 | filename = 'checkpoint_' + str(epoch) + '.pth.tar' 173 | try: 174 | if not os.path.exists(os.path.join(out_dir, data_name)): 175 | os.makedirs(os.path.join(out_dir, data_name)) 176 | torch.save(state, os.path.join(out_dir, data_name, filename)) 177 | except Exception: 178 | torch.save(state, os.path.join(os.environ['RESULT_DIR'], filename)) 179 | 180 | def save_checkpoint_dual(out_dir, data_name, epoch, 181 | encoder, decoder, encoder_optimizer, 182 | tag_decoder_optimizer, cell_decoder_optimizer, 183 | cell_bbox_regressor_optimizer): 184 | """ 185 | Saves EDD model checkpoint. 
186 | :param out_dir: output dir 187 | :param data_name: base name of processed dataset 188 | :param epoch: epoch number 189 | :param encoder: encoder model 190 | :param decoder: decoder model 191 | :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning 192 | :param tag_decoder_optimizer: optimizer to update tag decoder's weights 193 | :param cell_decoder_optimizer: optimizer to update cell decoder's weights 194 | :param cell_bbox_regressor_optimizer: optimizer to update cell bbox regressor's weights 195 | """ 196 | state = {'epoch': epoch, 197 | 'encoder': encoder, 198 | 'decoder': decoder, 199 | 'encoder_optimizer': encoder_optimizer, 200 | 'tag_decoder_optimizer': tag_decoder_optimizer, 201 | 'cell_decoder_optimizer': cell_decoder_optimizer, 202 | 'cell_bbox_regressor_optimizer': cell_bbox_regressor_optimizer} 203 | filename = 'checkpoint_' + str(epoch) + '.pth.tar' 204 | if not os.path.exists(os.path.join(out_dir, data_name)): 205 | os.makedirs(os.path.join(out_dir, data_name)) 206 | torch.save(state, os.path.join(out_dir, data_name, filename)) 207 | 208 | class AverageMeter(object): 209 | """ 210 | Keeps track of most recent, average, sum, and count of a metric. 211 | """ 212 | 213 | def __init__(self): 214 | self.reset() 215 | 216 | def reset(self): 217 | self.val = 0 218 | self.avg = 0 219 | self.sum = 0 220 | self.count = 0 221 | 222 | def update(self, val, n=1): 223 | self.val = val 224 | self.sum += val * n 225 | self.count += n 226 | self.avg = self.sum / self.count 227 | 228 | 229 | def adjust_learning_rate(optimizer, shrink_factor): 230 | """ 231 | Shrinks learning rate by a specified factor. 232 | 233 | :param optimizer: optimizer whose learning rate must be shrunk. 234 | :param shrink_factor: factor in interval (0, 1) to multiply learning rate with. 235 | """ 236 | 237 | print("\nDECAYING learning rate.") 238 | for param_group in optimizer.param_groups: 239 | param_group['lr'] = param_group['lr'] * shrink_factor 240 | print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 241 | 242 | def change_learning_rate(optimizer, new_lr): 243 | """ 244 | Changes learning rate to a specified value. 245 | 246 | :param optimizer: optimizer whose learning rate must be changed. 247 | :param new_lr: new learning rate.
248 | """ 249 | for param_group in optimizer.param_groups: 250 | param_group['lr'] = new_lr 251 | 252 | def image_resize(imagepath, image_size, keep_AR=True): 253 | with Image.open(imagepath) as im: 254 | old_size = im.size # old_size is a (width, height) tuple 255 | if keep_AR: 256 | ratio = float(image_size) / max(old_size) 257 | new_size = tuple([int(x * ratio) for x in old_size]) 258 | im = im.resize(new_size, Image.Resampling.LANCZOS) 259 | delta_w = image_size - new_size[0] 260 | delta_h = image_size - new_size[1] 261 | padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2)) 262 | new_im = ImageOps.expand(im, padding) 263 | else: 264 | new_im = im.resize((image_size, image_size), Image.Resampling.LANCZOS) 265 | return new_im, old_size 266 | 267 | def image_rescale(imagepath, image_size, keep_AR=True, transpose=True, return_size=False): 268 | new_im, old_size = image_resize(imagepath, image_size, keep_AR) 269 | img = np.array(new_im) 270 | if img.shape[2] > 3: 271 | img = img[:, :, :3] 272 | if transpose: 273 | img = img.transpose(2, 0, 1) 274 | if return_size: 275 | return img, old_size 276 | else: 277 | return img 278 | 279 | def accuracy(scores, targets, k): 280 | """ 281 | Computes top-k accuracy, from predicted and true labels. 282 | 283 | :param scores: scores from the model 284 | :param targets: true labels 285 | :param k: k in top-k accuracy 286 | :return: top-k accuracy 287 | """ 288 | 289 | batch_size = targets.size(0) 290 | _, ind = scores.topk(k, 1, True, True) 291 | correct = ind.eq(targets.view(-1, 1).expand_as(ind)) 292 | correct_total = correct.view(-1).float().sum() # 0D tensor 293 | return correct_total.item() * (100.0 / batch_size) 294 | 295 | 296 | if __name__ == '__main__': 297 | pass 298 | --------------------------------------------------------------------------------