├── metric ├── __init__.py ├── parallel.py └── metric_score.py ├── .gitignore ├── backbones ├── __init__.py ├── utils.py └── resnet.py ├── requirements.txt ├── README.md ├── LICENSE ├── create_input.py ├── run_infer.sh ├── run_train.sh ├── distributed.py ├── dataset.py ├── utils.py ├── inference.py ├── train.py └── models.py /metric/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | apted 2 | distance 3 | lxml 4 | tqdm 5 | h5py 6 | scipy==1.1.0 7 | jsonlines 8 | html3 9 | beautifulsoup4 10 | torch==1.6.0 11 | apex 12 | -------------------------------------------------------------------------------- /backbones/utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | from torch.hub import load_state_dict_from_url 3 | except ImportError: 4 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EDD-third-party 2 | A third-party implementation of the Encoder Dual Decoder (EDD) method for table recognition. 3 | **Paper:** [*Image-based table recognition: data, model, and evaluation*][paper] 4 | **Official implementation: [link][EDD_orig_repo_link]** 5 | # Requirements 6 | ```bash 7 | pip install -r requirements.txt 8 | ``` 9 | 10 | # Training and testing on PubTabNet 11 | ### Prepare training data & Training 12 | ```bash 13 | bash run_train.sh 14 | ``` 15 | ### Prepare inference data & Inference with beam search 16 | ```bash 17 | bash run_infer.sh 18 | ``` 19 | ### Model parameters 20 | A model trained with the settings shown in `run_train.sh` can be downloaded from [Google Drive][model]. 21 | 22 | 23 | [EDD_orig_repo_link]:https://github.com/ibm-aur-nlp/EDD 24 | [paper]:https://arxiv.org/pdf/1911.10683.pdf 25 | [model]:https://drive.google.com/file/d/1e2SJ-3A5k0Q4ouaTPfmzindZ3ksjcXh9/view?usp=sharing -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Daquan Lin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /create_input.py: -------------------------------------------------------------------------------- 1 | from utils import create_input_files 2 | import argparse 3 | 4 | 5 | if __name__ == '__main__': 6 | # Create input files (along with word map) 7 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 8 | parser.add_argument("--image_folder", type=str, default="pubtabnet", 9 | help="Source table images' folder.") 10 | parser.add_argument('--output_folder', type=str, default='output_w_none_399k_memory_effi', 11 | help='Output folder to save processed data') 12 | # Training 13 | parser.add_argument("--max_len_token_structure", type=int, default=300, 14 | help="Maximal length of structure's token") 15 | parser.add_argument("--max_len_token_cell", type=int, default=100, 16 | help="Maximal length of each cell's token.") 17 | parser.add_argument("--image_size", type=int, default=80000, 18 | help="Maximal image's height and width.") 19 | args = parser.parse_args() 20 | 21 | create_input_files(image_folder=args.image_folder, 22 | output_folder=args.output_folder, 23 | max_len_token_structure=args.max_len_token_structure, 24 | max_len_token_cell=args.max_len_token_cell, 25 | image_size=args.image_size) 26 | -------------------------------------------------------------------------------- /run_infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # Create dataset for validation and test 5 | src_pubtabnet_data_path='pubtabnet' 6 | DATA_FOLDER="output_w_none_all" 7 | 8 | if [[ -d "${DATA_FOLDER}" ]]; then 9 | echo "${DATA_FOLDER} is existing" 10 | else 11 | echo "${DATA_FOLDER} is not existing" 12 | echo "Create dataset ..." 
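# Note: the very large --max_len_token_structure / --max_len_token_cell values passed
# below effectively disable the length filtering done in utils.create_input_files, so
# no validation/test sample is skipped or clipped when building the inference dataset
# (run_train.sh, by contrast, uses much smaller limits).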
13 | python create_input.py \ 14 | --image_folder $src_pubtabnet_data_path \ 15 | --output_folder $DATA_FOLDER \ 16 | --max_len_token_structure 999999999 \ 17 | --max_len_token_cell 9999999999 18 | fi 19 | 20 | MODEL=$1 21 | SPLIT=$2 22 | backbone="resnext101_32x8d" 23 | word_map_structure="${DATA_FOLDER}/WORDMAP_STRUCTURE.json" 24 | word_map_cell="${DATA_FOLDER}/WORDMAP_CELL.json" 25 | beam_size_structure=3 26 | beam_size_cell=3 27 | T=0.65 28 | 29 | offset=1150 30 | 31 | #for gpu_id in `seq 0 7` 32 | for gpu_id in 0 33 | do 34 | CUDA_VISIBLE_DEVICES=$((gpu_id)) nohup python inference.py \ 35 | --model $MODEL \ 36 | --data_folder ${DATA_FOLDER} \ 37 | --word_map_structure $word_map_structure \ 38 | --word_map_cell $word_map_cell \ 39 | --beam_size_structure $beam_size_structure \ 40 | --beam_size_cell $beam_size_cell \ 41 | --max_seq_len_structure 1536 \ 42 | --max_seq_len_cell 300 \ 43 | --backbone $backbone \ 44 | --EDD_type "S1S1" \ 45 | --image_size 640 \ 46 | --split $SPLIT \ 47 | --print_freq 100 \ 48 | --all \ 49 | --T $T \ 50 | --rank_method "sum" \ 51 | --start_idx $((gpu_id*$offset)) \ 52 | --offset $offset > log_${SPLIT}_all_part_$((gpu_id*$offset))_$((gpu_id*$offset+$offset)).txt & 53 | done 54 | 55 | -------------------------------------------------------------------------------- /run_train.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | #max_structure=300 5 | max_structure_len=1024 6 | max_cell_len=100 7 | 8 | #max_structure=1536 9 | # Create dataset for train and validation 10 | src_pubtabnet_data_path='pubtabnet' 11 | DATA_FOLDER="output_w_none_stru_${max_structure_len}_cellClip_100" 12 | 13 | if [[ -d "${DATA_FOLDER}" ]]; then 14 | echo "${DATA_FOLDER} is existing" 15 | else 16 | echo "${DATA_FOLDER} is not existing" 17 | echo "Create dataset ..." 18 | python create_input.py \ 19 | --image_folder $src_pubtabnet_data_path \ 20 | --output_folder $DATA_FOLDER \ 21 | --max_len_token_structure ${max_structure_len} \ 22 | --max_len_token_cell ${max_cell_len} 23 | fi 24 | 25 | model_dir='checkpoints' 26 | backbone='resnext101_32x8d' 27 | image_size=640 28 | #backbone='resnet18' 29 | #image_size=448 30 | 31 | #STAGE='structure' 32 | STAGE='cell' 33 | 34 | PYTHON_FILE=train.py 35 | 36 | if [[ "$STAGE" == "structure" ]]; then 37 | hyper_loss=1.0 38 | else 39 | hyper_loss=0.5 40 | fi 41 | echo "$hyper_loss" 42 | 43 | 44 | EDD_type='S1S1' 45 | #EDD_type='S2S2' 46 | GPUS_PER_NODE=2 47 | # Change for multinode config 48 | MASTER_ADDR=localhost 49 | MASTER_PORT=6000 50 | NNODES=1 51 | NODE_RANK=0 52 | WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) 53 | 54 | DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT" 55 | 56 | CMD="python -m torch.distributed.launch $DISTRIBUTED_ARGS \ 57 | $PYTHON_FILE \ 58 | --data_folder ${DATA_FOLDER} \ 59 | --num_epochs 26 \ 60 | --batch_size 1 \ 61 | --learning_rate 1e-3 \ 62 | --model_dir $model_dir \ 63 | --backbone $backbone \ 64 | --EDD_type $EDD_type \ 65 | --stage $STAGE \ 66 | --hyper_loss $hyper_loss \ 67 | --first_epoch 1 \ 68 | --second_epoch 1 \ 69 | --print_freq 1 \ 70 | --grad_clip 5.0 \ 71 | --image_size $image_size \ 72 | --max_len_token_structure $max_structure_len " 73 | 74 | if [[ ! 
-z $1 ]]; then 75 | CMD+="--pretrained_model_path $1 --resume " 76 | fi 77 | 78 | $CMD 79 | -------------------------------------------------------------------------------- /metric/parallel.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from concurrent.futures import ProcessPoolExecutor, as_completed 3 | 4 | 5 | def parallel_process(array, function, n_jobs=16, use_kwargs=False, front_num=0, is_tqdm=False): 6 | """ 7 | A parallel version of the map function with a progress bar. 8 | Args: 9 | array (array-like): An array to iterate over. 10 | function (function): A python function to apply to the elements of array 11 | n_jobs (int, default=16): The number of cores to use 12 | use_kwargs (boolean, default=False): Whether to consider the elements of array as dictionaries of 13 | keyword arguments to function 14 | front_num (int, default=3): The number of iterations to run serially before kicking off the parallel job. 15 | Useful for catching bugs 16 | is_tqdm (boolean, default=False): Show tqdm bar when true 17 | Returns: 18 | [function(array[0]), function(array[1]), ...] 19 | """ 20 | # We run the first few iterations serially to catch bugs 21 | if front_num > 0: 22 | front = [function(**a) if use_kwargs else function(a) 23 | for a in array[:front_num]] 24 | else: 25 | front = [] 26 | # If we set n_jobs to 1, just run a list comprehension. This is useful for benchmarking and debugging. 27 | if n_jobs == 1: 28 | return front + [function(**a) if use_kwargs else function(a) for a in tqdm(array[front_num:], disable=(not is_tqdm))] 29 | # Assemble the workers 30 | with ProcessPoolExecutor(max_workers=n_jobs) as pool: 31 | # Pass the elements of array into function 32 | if use_kwargs: 33 | futures = [pool.submit(function, **a) for a in array[front_num:]] 34 | else: 35 | futures = [pool.submit(function, a) for a in array[front_num:]] 36 | kwargs = { 37 | 'total': len(futures), 38 | 'unit': 'it', 39 | 'unit_scale': True, 40 | 'leave': True 41 | } 42 | # Print out the progress as tasks complete 43 | for f in tqdm(as_completed(futures), disable=(not is_tqdm), **kwargs): 44 | pass 45 | out = [] 46 | # Get the results from the futures. 47 | for i, future in tqdm(enumerate(futures), disable=(not is_tqdm)): 48 | try: 49 | out.append(future.result()) 50 | except Exception as e: 51 | out.append(e) 52 | return front + out 53 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
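# The helpers below implement a manual gradient all-reduce: parameters with gradients
# are bucketed by tensor dtype, the gradients are flattened, summed across ranks with
# torch.distributed.all_reduce, rescaled by the world size, and copied back in place.
# A rough usage sketch inside a training step (model/optimizer/loss are placeholders,
# and torch.distributed must already be initialised via init_process_group):
#
#     loss.backward()
#     allreduce_params(model)            # or: allreduce_params_opt(optimizer)
#     optimizer.step()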
15 | 16 | import torch 17 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 18 | import torch.distributed as dist 19 | from torch.nn.modules import Module 20 | from torch.autograd import Variable 21 | 22 | 23 | def allreduce_params(model, reduce_after=False, no_scale=False, fp32_allreduce=True): 24 | buckets = {} 25 | for name, param in model.named_parameters(): 26 | if param.requires_grad and param.grad is not None: 27 | tp = (param.data.type()) 28 | if tp not in buckets: 29 | buckets[tp] = [] 30 | buckets[tp].append(param) 31 | 32 | for tp in buckets: 33 | bucket = buckets[tp] 34 | grads = [param.grad.data for param in bucket] 35 | coalesced = _flatten_dense_tensors(grads) 36 | if fp32_allreduce: 37 | coalesced = coalesced.float() 38 | if not no_scale and not reduce_after: 39 | coalesced /= dist.get_world_size() 40 | dist.all_reduce(coalesced) 41 | torch.cuda.synchronize() 42 | if not no_scale and reduce_after: 43 | coalesced /= dist.get_world_size() 44 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 45 | buf.copy_(synced) 46 | 47 | 48 | def allreduce_params_opt(optimizer, reduce_after=False, no_scale=False, fp32_allreduce=False): 49 | buckets = {} 50 | for group in optimizer.param_groups: 51 | for param in group['params']: 52 | if param.requires_grad and param.grad is not None: 53 | tp = (param.data.type()) 54 | if tp not in buckets: 55 | buckets[tp] = [] 56 | buckets[tp].append(param) 57 | 58 | for tp in buckets: 59 | bucket = buckets[tp] 60 | grads = [param.grad.data for param in bucket] 61 | coalesced = _flatten_dense_tensors(grads) 62 | if fp32_allreduce: 63 | coalesced = coalesced.float() 64 | if not no_scale and not reduce_after: 65 | coalesced /= dist.get_world_size() 66 | dist.all_reduce(coalesced) 67 | torch.cuda.synchronize() 68 | if not no_scale and reduce_after: 69 | coalesced /= dist.get_world_size() 70 | for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): 71 | buf.copy_(synced) 72 | 73 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import h5py 4 | import json 5 | import os 6 | import numpy as np 7 | from scipy.misc import imread, imresize 8 | 9 | 10 | class CaptionDataset(Dataset): 11 | """ 12 | A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches. 
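    Each item from __getitem__ is a tuple (img, caption_structure, caplen_structure,
    captions_cell, caplen_cell, number_cell_per_image); the structure caption is padded
    to max_len_token_structure + 2 tokens and each cell caption to max_len_token_cell + 2.
    A minimal usage sketch (the normalization transform is an assumption, mirroring the
    one used in inference.py):

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_set = CaptionDataset(data_folder, 'train',
                                   transform=transforms.Compose([normalize]))
        loader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=True)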
13 | """ 14 | 15 | def __init__(self, data_folder, split, transform=None, 16 | max_len_token_structure=300, max_len_token_cell=100, image_size=448): 17 | """ 18 | :param data_folder: folder where data files are stored 19 | :param split: split, one of 'TRAIN', 'VAL', or 'TEST' 20 | :param transform: image transform pipeline 21 | """ 22 | self.split = split 23 | self.image_size = image_size 24 | self.max_len_token_structure = max_len_token_structure 25 | self.max_len_token_cell = max_len_token_cell 26 | 27 | # Load image path 28 | if isinstance(self.split, list): 29 | pass 30 | elif isinstance(self.split, str): 31 | self.split = [self.split] 32 | self.img_paths = [] 33 | self.captions_structure = [] 34 | self.caplens_structure = [] 35 | self.captions_cell = [] 36 | self.caplens_cell = [] 37 | self.number_cell_per_images = [] 38 | 39 | for tmp_split in self.split: 40 | with open(os.path.join(data_folder, tmp_split + '_IMAGE_PATHS.txt'), 'r') as f: 41 | for line in f: 42 | self.img_paths.append(line.strip()) 43 | 44 | print("Split: %s, number of images: %d" % (tmp_split, len(self.img_paths))) 45 | 46 | # Load encoded captions structure 47 | with open(os.path.join(data_folder, tmp_split + '_CAPTIONS_STRUCTURE' + '.json'), 'r') as j: 48 | self.captions_structure.extend(json.load(j)) 49 | 50 | # Load caption structure length (completely into memory) 51 | with open(os.path.join(data_folder, tmp_split + '_CAPLENS_STRUCTURE' + '.json'), 'r') as j: 52 | self.caplens_structure.extend(json.load(j)) 53 | 54 | # Load encoded captions cell 55 | with open(os.path.join(data_folder, tmp_split + '_CAPTIONS_CELL' + '.json'), 'r') as j: 56 | self.captions_cell.extend(json.load(j)) 57 | 58 | # Load caption cell length 59 | with open(os.path.join(data_folder, tmp_split + '_CAPLENS_CELL' + '.json'), 'r') as j: 60 | self.caplens_cell.extend(json.load(j)) 61 | 62 | with open(os.path.join(data_folder, tmp_split + "_NUMBER_CELLS_PER_IMAGE.json"), "r") as j: 63 | self.number_cell_per_images.extend(json.load(j)) 64 | 65 | self.max_cells_per_images = max(self.number_cell_per_images) 66 | # PyTorch transformation pipeline for the image (normalizing, etc.) 67 | self.transform = transform 68 | 69 | # Total number of data image 70 | sored_caplens_structure = [(i, it) for i, it in enumerate(self.caplens_structure)] 71 | sored_caplens_structure.sort(key=lambda x: x[1], reverse=True) 72 | self.idx_mp = [it[0] for i, it in enumerate(sored_caplens_structure)] 73 | self.dataset_size = len(self.idx_mp) 74 | 75 | def __getitem__(self, idx): 76 | # The Nth caption structure corresponds to the Nth image 77 | # Load image 78 | i = self.idx_mp[idx] 79 | img = imread(self.img_paths[i]) 80 | if len(img.shape) == 2: 81 | img = img[:, :, np.newaxis] 82 | img = np.concatenate([img, img, img], axis=2) 83 | img = imresize( 84 | img, (self.image_size, self.image_size), interp="cubic") 85 | img = img.transpose(2, 0, 1) 86 | img = torch.FloatTensor(img / 255.) 
87 | 88 | if self.transform is not None: 89 | img = self.transform(img) 90 | 91 | # padding caption structure, 1 dimension 92 | captions_structure = self.captions_structure[i] 93 | captions_structure += [0] * (self.max_len_token_structure + 2 - len(captions_structure)) 94 | 95 | caption_structure = torch.LongTensor(captions_structure) 96 | caplen_structure = torch.LongTensor([self.caplens_structure[i]]) 97 | 98 | # padding caption cell, 2 dimension 99 | captions_cell = self.captions_cell[i] 100 | caplen_cell = self.caplens_cell[i] 101 | 102 | captions_cell = [it + [0]*(self.max_len_token_cell + 2 - len(it)) 103 | for it in captions_cell] 104 | padding_enc_caption_cell = [[0]*(self.max_len_token_cell + 2) 105 | for x in range(self.max_cells_per_images - len(captions_cell))] 106 | padding_len_caption_cell = [0] * (self.max_cells_per_images - len(captions_cell)) 107 | captions_cell += padding_enc_caption_cell 108 | caplen_cell += padding_len_caption_cell 109 | 110 | captions_cell = torch.LongTensor(captions_cell) 111 | caplen_cell = torch.LongTensor(caplen_cell) 112 | 113 | number_cell_per_image = torch.LongTensor( 114 | [self.number_cell_per_images[i]]) 115 | 116 | return img, caption_structure, caplen_structure, captions_cell, caplen_cell, number_cell_per_image 117 | 118 | def __len__(self): 119 | return self.dataset_size 120 | 121 | -------------------------------------------------------------------------------- /metric/metric_score.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 IBM 2 | # Author: peter.zhong@au1.ibm.com 3 | # 4 | # This is free software; you can redistribute it and/or modify 5 | # it under the terms of the Apache 2.0 License. 6 | # 7 | # This software is distributed in the hope that it will be useful, 8 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | # Apache 2.0 License for more details. 11 | 12 | import distance 13 | from apted import APTED, Config 14 | from apted.helpers import Tree 15 | from lxml import etree, html 16 | from collections import deque 17 | from metric.parallel import parallel_process 18 | from tqdm import tqdm 19 | 20 | 21 | class TableTree(Tree): 22 | def __init__(self, tag, colspan=None, rowspan=None, content=None, *children): 23 | self.tag = tag 24 | self.colspan = colspan 25 | self.rowspan = rowspan 26 | self.content = content 27 | self.children = list(children) 28 | 29 | def bracket(self): 30 | """Show tree using brackets notation""" 31 | if self.tag == 'td': 32 | result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % \ 33 | (self.tag, self.colspan, self.rowspan, self.content) 34 | else: 35 | result = '"tag": %s' % self.tag 36 | for child in self.children: 37 | result += child.bracket() 38 | return "{{{}}}".format(result) 39 | 40 | 41 | class CustomConfig(Config): 42 | @staticmethod 43 | def maximum(*sequences): 44 | """Get maximum possible value 45 | """ 46 | return max(map(len, sequences)) 47 | 48 | def normalized_distance(self, *sequences): 49 | """Get distance from 0 to 1 50 | """ 51 | return float(distance.levenshtein(*sequences)) / self.maximum(*sequences) 52 | 53 | def rename(self, node1, node2): 54 | """Compares attributes of trees""" 55 | if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan): 56 | return 1. 
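        # Rename cost used by APTED: 1.0 whenever tag, colspan or rowspan differ
        # (handled above); for two matching 'td' nodes the cost is the normalized
        # Levenshtein distance between their token lists; otherwise the nodes are
        # treated as identical and the cost is 0.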
57 | if node1.tag == 'td': 58 | if node1.content or node2.content: 59 | return self.normalized_distance(node1.content, node2.content) 60 | return 0. 61 | 62 | 63 | class TEDS(object): 64 | ''' Tree Edit Distance based Similarity 65 | ''' 66 | 67 | def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None): 68 | assert isinstance(n_jobs, int) and ( 69 | n_jobs >= 1), 'n_jobs must be an integer greater than or equal to 1' 70 | self.structure_only = structure_only 71 | self.n_jobs = n_jobs 72 | self.ignore_nodes = ignore_nodes 73 | self.__tokens__ = [] 74 | 75 | def tokenize(self, node): 76 | ''' Tokenizes table cells 77 | ''' 78 | self.__tokens__.append('<%s>' % node.tag) 79 | if node.text is not None: 80 | self.__tokens__ += list(node.text) 81 | for n in node.getchildren(): 82 | self.tokenize(n) 83 | if node.tag != 'unk': 84 | self.__tokens__.append('</%s>' % node.tag) 85 | if node.tag != 'td' and node.tail is not None: 86 | self.__tokens__ += list(node.tail) 87 | 88 | def load_html_tree(self, node, parent=None): 89 | ''' Converts HTML tree to the format required by apted 90 | ''' 91 | global __tokens__ 92 | if node.tag == 'td': 93 | if self.structure_only: 94 | cell = [] 95 | else: 96 | self.__tokens__ = [] 97 | self.tokenize(node) 98 | cell = self.__tokens__[1:-1].copy() 99 | new_node = TableTree(node.tag, 100 | int(node.attrib.get('colspan', '1')), 101 | int(node.attrib.get('rowspan', '1')), 102 | cell, *deque()) 103 | else: 104 | new_node = TableTree(node.tag, None, None, None, *deque()) 105 | if parent is not None: 106 | parent.children.append(new_node) 107 | if node.tag != 'td': 108 | for n in node.getchildren(): 109 | self.load_html_tree(n, new_node) 110 | if parent is None: 111 | return new_node 112 | 113 | def evaluate(self, pred, true): 114 | ''' Computes TEDS score between the prediction and the ground truth of a 115 | given sample 116 | ''' 117 | if (not pred) or (not true): 118 | return 0.0 119 | parser = html.HTMLParser(remove_comments=True, encoding='utf-8') 120 | pred = html.fromstring(pred, parser=parser) 121 | true = html.fromstring(true, parser=parser) 122 | if pred.xpath('body/table') and true.xpath('body/table'): 123 | pred = pred.xpath('body/table')[0] 124 | true = true.xpath('body/table')[0] 125 | if self.ignore_nodes: 126 | etree.strip_tags(pred, *self.ignore_nodes) 127 | etree.strip_tags(true, *self.ignore_nodes) 128 | n_nodes_pred = len(pred.xpath(".//*")) 129 | n_nodes_true = len(true.xpath(".//*")) 130 | n_nodes = max(n_nodes_pred, n_nodes_true) 131 | tree_pred = self.load_html_tree(pred) 132 | tree_true = self.load_html_tree(true) 133 | distance = APTED(tree_pred, tree_true, 134 | CustomConfig()).compute_edit_distance() 135 | return 1.0 - (float(distance) / n_nodes) 136 | else: 137 | return 0.0 138 | 139 | def batch_evaluate(self, pred_json, true_json, is_tqdm=True): 140 | ''' Computes TEDS score between the prediction and the ground truth of 141 | a batch of samples 142 | @params pred_json: {'FILENAME': 'HTML CODE', ...} 143 | @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...} 144 | @params is_tqdm: boolean, show tqdm bar when true 145 | @output: {'FILENAME': 'TEDS SCORE', ...} 146 | ''' 147 | samples = true_json.keys() 148 | if self.n_jobs == 1: 149 | scores = [self.evaluate(pred_json.get( 150 | filename, ''), true_json[filename]['html']) for filename in tqdm(samples, disable=(not is_tqdm))] 151 | else: 152 | inputs = [{'pred': pred_json.get( 153 | filename, ''), 'true': true_json[filename]['html']} for filename in samples] 154 | scores = 
parallel_process( 155 | inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1, is_tqdm=is_tqdm) 156 | scores = dict(zip(samples, scores)) 157 | return scores 158 | 159 | def batch_evaluate_html(self, pred_htmls, true_htmls, is_tqdm=True): 160 | ''' Computes TEDS score between the prediction and the ground truth of 161 | a batch of samples 162 | ''' 163 | if self.n_jobs == 1: 164 | scores = [self.evaluate(pred_html, true_html) for ( 165 | pred_html, true_html) in zip(pred_htmls, true_htmls)] 166 | else: 167 | inputs = [{"pred": pred_html, "true": true_html} for( 168 | pred_html, true_html) in zip(pred_htmls, true_htmls)] 169 | 170 | scores = parallel_process( 171 | inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1, is_tqdm=is_tqdm) 172 | return scores 173 | 174 | 175 | if __name__ == '__main__': 176 | import json 177 | import pprint 178 | with open('sample_pred.json') as fp: 179 | pred_json = json.load(fp) 180 | with open('sample_gt.json') as fp: 181 | true_json = json.load(fp) 182 | teds = TEDS(n_jobs=4) 183 | scores = teds.batch_evaluate(pred_json, true_json) 184 | pp = pprint.PrettyPrinter() 185 | pp.pprint(scores) 186 | -------------------------------------------------------------------------------- /backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | from .utils import load_state_dict_from_url 5 | from typing import Type, Any, Callable, Union, List, Optional 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 10 | 'wide_resnet50_2', 'wide_resnet101_2'] 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 20 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 21 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 22 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=dilation, groups=groups, bias=False, dilation=dilation) 30 | 31 | 32 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 33 | """1x1 convolution""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion: int = 1 39 | 40 | def __init__( 41 | self, 42 | inplanes: int, 43 | planes: int, 44 | stride: int = 1, 45 | downsample: Optional[nn.Module] = None, 46 | groups: int = 1, 47 | base_width: int = 64, 48 | dilation: int = 1, 49 | norm_layer: Optional[Callable[..., nn.Module]] = None 50 | ) -> None: 51 | super(BasicBlock, self).__init__() 52 | if norm_layer is None: 53 | norm_layer = nn.BatchNorm2d 54 | if groups != 1 or base_width != 64: 55 | 
raise ValueError('BasicBlock only supports groups=1 and base_width=64') 56 | if dilation > 1: 57 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 58 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 59 | self.conv1 = conv3x3(inplanes, planes, stride) 60 | self.bn1 = norm_layer(planes) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.conv2 = conv3x3(planes, planes) 63 | self.bn2 = norm_layer(planes) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x: Tensor) -> Tensor: 68 | identity = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | 77 | if self.downsample is not None: 78 | identity = self.downsample(x) 79 | 80 | out += identity 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class Bottleneck(nn.Module): 87 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 88 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 89 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 90 | # This variant is also known as ResNet V1.5 and improves accuracy according to 91 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 92 | 93 | expansion: int = 4 94 | 95 | def __init__( 96 | self, 97 | inplanes: int, 98 | planes: int, 99 | stride: int = 1, 100 | downsample: Optional[nn.Module] = None, 101 | groups: int = 1, 102 | base_width: int = 64, 103 | dilation: int = 1, 104 | norm_layer: Optional[Callable[..., nn.Module]] = None 105 | ) -> None: 106 | super(Bottleneck, self).__init__() 107 | if norm_layer is None: 108 | norm_layer = nn.BatchNorm2d 109 | width = int(planes * (base_width / 64.)) * groups 110 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 111 | self.conv1 = conv1x1(inplanes, width) 112 | self.bn1 = norm_layer(width) 113 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 114 | self.bn2 = norm_layer(width) 115 | self.conv3 = conv1x1(width, planes * self.expansion) 116 | self.bn3 = norm_layer(planes * self.expansion) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.downsample = downsample 119 | self.stride = stride 120 | 121 | def forward(self, x: Tensor) -> Tensor: 122 | identity = x 123 | 124 | out = self.conv1(x) 125 | out = self.bn1(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv2(out) 129 | out = self.bn2(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv3(out) 133 | out = self.bn3(out) 134 | 135 | if self.downsample is not None: 136 | identity = self.downsample(x) 137 | 138 | out += identity 139 | out = self.relu(out) 140 | 141 | return out 142 | 143 | 144 | class ResNet(nn.Module): 145 | 146 | def __init__( 147 | self, 148 | block: Type[Union[BasicBlock, Bottleneck]], 149 | layers: List[int], 150 | num_classes: int = 1000, 151 | zero_init_residual: bool = False, 152 | groups: int = 1, 153 | width_per_group: int = 64, 154 | replace_stride_with_dilation: Optional[List[bool]] = None, 155 | norm_layer: Optional[Callable[..., nn.Module]] = None, 156 | last_stride: int = 2 157 | ) -> None: 158 | super(ResNet, self).__init__() 159 | if norm_layer is None: 160 | norm_layer = nn.BatchNorm2d 161 | self._norm_layer = norm_layer 162 | 163 | self.inplanes = 64 164 | self.dilation = 1 165 | self.last_stride = last_stride 166 | if replace_stride_with_dilation is None: 167 | # each 
element in the tuple indicates if we should replace 168 | # the 2x2 stride with a dilated convolution instead 169 | replace_stride_with_dilation = [False, False, False] 170 | if len(replace_stride_with_dilation) != 3: 171 | raise ValueError("replace_stride_with_dilation should be None " 172 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 173 | self.groups = groups 174 | self.base_width = width_per_group 175 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 176 | bias=False) 177 | self.bn1 = norm_layer(self.inplanes) 178 | self.relu = nn.ReLU(inplace=True) 179 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 180 | self.layer1 = self._make_layer(block, 64, layers[0]) 181 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 182 | dilate=replace_stride_with_dilation[0]) 183 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 184 | dilate=replace_stride_with_dilation[1]) 185 | self.layer4 = self._make_layer(block, 512, layers[3], stride=self.last_stride, 186 | dilate=replace_stride_with_dilation[2]) 187 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 188 | self.fc = nn.Linear(512 * block.expansion, num_classes) 189 | 190 | for m in self.modules(): 191 | if isinstance(m, nn.Conv2d): 192 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 193 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 194 | nn.init.constant_(m.weight, 1) 195 | nn.init.constant_(m.bias, 0) 196 | 197 | # Zero-initialize the last BN in each residual branch, 198 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 199 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 200 | if zero_init_residual: 201 | for m in self.modules(): 202 | if isinstance(m, Bottleneck): 203 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 204 | elif isinstance(m, BasicBlock): 205 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 206 | 207 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 208 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 209 | norm_layer = self._norm_layer 210 | downsample = None 211 | previous_dilation = self.dilation 212 | if dilate: 213 | self.dilation *= stride 214 | stride = 1 215 | if stride != 1 or self.inplanes != planes * block.expansion: 216 | downsample = nn.Sequential( 217 | conv1x1(self.inplanes, planes * block.expansion, stride), 218 | norm_layer(planes * block.expansion), 219 | ) 220 | 221 | layers = [] 222 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 223 | self.base_width, previous_dilation, norm_layer)) 224 | self.inplanes = planes * block.expansion 225 | for _ in range(1, blocks): 226 | layers.append(block(self.inplanes, planes, groups=self.groups, 227 | base_width=self.base_width, dilation=self.dilation, 228 | norm_layer=norm_layer)) 229 | 230 | return nn.Sequential(*layers) 231 | 232 | def _forward_impl(self, x: Tensor) -> Tensor: 233 | # See note [TorchScript super()] 234 | x = self.conv1(x) 235 | x = self.bn1(x) 236 | x = self.relu(x) 237 | x = self.maxpool(x) 238 | 239 | x = self.layer1(x) 240 | x = self.layer2(x) 241 | x = self.layer3(x) 242 | x = self.layer4(x) 243 | 244 | x = self.avgpool(x) 245 | x = torch.flatten(x, 1) 246 | x = self.fc(x) 247 | 248 | return x 249 | 250 | def forward(self, x: Tensor) -> Tensor: 251 | return self._forward_impl(x) 252 | 253 | 254 | def _resnet( 255 | arch: str, 256 
| block: Type[Union[BasicBlock, Bottleneck]], 257 | layers: List[int], 258 | pretrained: bool, 259 | progress: bool, 260 | **kwargs: Any 261 | ) -> ResNet: 262 | model = ResNet(block, layers, **kwargs) 263 | if pretrained: 264 | state_dict = load_state_dict_from_url(model_urls[arch], 265 | progress=progress) 266 | model.load_state_dict(state_dict) 267 | return model 268 | 269 | 270 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 271 | r"""ResNet-18 model from 272 | `"Deep Residual Learning for Image Recognition" `_. 273 | Args: 274 | pretrained (bool): If True, returns a model pre-trained on ImageNet 275 | progress (bool): If True, displays a progress bar of the download to stderr 276 | """ 277 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 278 | **kwargs) 279 | 280 | 281 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 282 | r"""ResNet-34 model from 283 | `"Deep Residual Learning for Image Recognition" `_. 284 | Args: 285 | pretrained (bool): If True, returns a model pre-trained on ImageNet 286 | progress (bool): If True, displays a progress bar of the download to stderr 287 | """ 288 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 289 | **kwargs) 290 | 291 | 292 | def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 293 | r"""ResNet-50 model from 294 | `"Deep Residual Learning for Image Recognition" `_. 295 | Args: 296 | pretrained (bool): If True, returns a model pre-trained on ImageNet 297 | progress (bool): If True, displays a progress bar of the download to stderr 298 | """ 299 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 300 | **kwargs) 301 | 302 | 303 | def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 304 | r"""ResNet-101 model from 305 | `"Deep Residual Learning for Image Recognition" `_. 306 | Args: 307 | pretrained (bool): If True, returns a model pre-trained on ImageNet 308 | progress (bool): If True, displays a progress bar of the download to stderr 309 | """ 310 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 311 | **kwargs) 312 | 313 | 314 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 315 | r"""ResNet-152 model from 316 | `"Deep Residual Learning for Image Recognition" `_. 317 | Args: 318 | pretrained (bool): If True, returns a model pre-trained on ImageNet 319 | progress (bool): If True, displays a progress bar of the download to stderr 320 | """ 321 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 322 | **kwargs) 323 | 324 | 325 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 326 | r"""ResNeXt-50 32x4d model from 327 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 328 | Args: 329 | pretrained (bool): If True, returns a model pre-trained on ImageNet 330 | progress (bool): If True, displays a progress bar of the download to stderr 331 | """ 332 | kwargs['groups'] = 32 333 | kwargs['width_per_group'] = 4 334 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 335 | pretrained, progress, **kwargs) 336 | 337 | 338 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 339 | r"""ResNeXt-101 32x8d model from 340 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 
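    This is the backbone selected by run_train.sh and run_infer.sh. A minimal
    construction sketch (how models.py wires it into the EDD encoder is not shown in
    this repository listing, so this is only illustrative):

        backbone = resnext101_32x8d(pretrained=True)
        # last_stride=1 would keep a larger final feature map (the default stride is 2)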
341 | Args: 342 | pretrained (bool): If True, returns a model pre-trained on ImageNet 343 | progress (bool): If True, displays a progress bar of the download to stderr 344 | """ 345 | kwargs['groups'] = 32 346 | kwargs['width_per_group'] = 8 347 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 348 | pretrained, progress, **kwargs) 349 | 350 | 351 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 352 | r"""Wide ResNet-50-2 model from 353 | `"Wide Residual Networks" `_. 354 | The model is the same as ResNet except for the bottleneck number of channels 355 | which is twice larger in every block. The number of channels in outer 1x1 356 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 357 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 358 | Args: 359 | pretrained (bool): If True, returns a model pre-trained on ImageNet 360 | progress (bool): If True, displays a progress bar of the download to stderr 361 | """ 362 | kwargs['width_per_group'] = 64 * 2 363 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 364 | pretrained, progress, **kwargs) 365 | 366 | 367 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 368 | r"""Wide ResNet-101-2 model from 369 | `"Wide Residual Networks" `_. 370 | The model is the same as ResNet except for the bottleneck number of channels 371 | which is twice larger in every block. The number of channels in outer 1x1 372 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 373 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 374 | Args: 375 | pretrained (bool): If True, returns a model pre-trained on ImageNet 376 | progress (bool): If True, displays a progress bar of the download to stderr 377 | """ 378 | kwargs['width_per_group'] = 64 * 2 379 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 380 | pretrained, progress, **kwargs) 381 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import torch 5 | import torch.distributed as dist 6 | 7 | from scipy.misc import imread, imresize 8 | from tqdm import tqdm 9 | from collections import Counter 10 | import random 11 | import jsonlines 12 | from bs4 import BeautifulSoup as bs 13 | from html import escape 14 | 15 | 16 | def check_longest_cell(cells): 17 | length_cells = [len(cell["tokens"]) for cell in cells] 18 | return max(length_cells) 19 | 20 | 21 | def create_input_files(image_folder="pubtabnet", output_folder="output_w_none_399k", 22 | max_len_token_structure=300, 23 | max_len_token_cell=100, 24 | image_size=512 25 | ): 26 | """ 27 | Creates input files for training, validation, and test data. 
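    Reads <image_folder>/PubTabNet_2.0.0.jsonl together with the per-split image folders
    and writes to output_folder: WORDMAP_STRUCTURE.json and WORDMAP_CELL.json (token
    vocabularies), a <split>_IMAGE_PATHS.txt for every split, and, for the train and val
    splits, <split>_CAPTIONS_STRUCTURE.json, <split>_CAPLENS_STRUCTURE.json,
    <split>_CAPTIONS_CELL.json, <split>_CAPLENS_CELL.json and
    <split>_NUMBER_CELLS_PER_IMAGE.json.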
28 | 29 | :param image_folder: folder with downloaded images 30 | :param output_folder: folder to save files 31 | :param max_len_token_structure: don't sample captions_structure longer than this length 32 | :param max_len_token_cell: sample captions_cell longer than this length will be clipped 33 | """ 34 | print("create_input .....") 35 | with open(os.path.join(image_folder, "PubTabNet_2.0.0.jsonl"), 'r') as reader: 36 | imgs = list(reader) 37 | if not os.path.exists(output_folder): 38 | os.makedirs(output_folder) 39 | 40 | # Read image paths and captions for each image 41 | train_image_captions_structure = [] 42 | train_image_captions_cells = [] 43 | train_image_paths = [] 44 | 45 | valid_image_captions_structure = [] 46 | valid_image_captions_cells = [] 47 | valid_image_paths = [] 48 | 49 | test_image_captions_structure = [] 50 | test_image_captions_cells = [] 51 | test_image_paths = [] 52 | word_freq_structure = Counter() 53 | word_freq_cells = Counter() 54 | 55 | max_number_imgs_train = 100000000 56 | max_number_imgs_val = 1000000 57 | 58 | total_number_imgs_train = 0 59 | total_number_imgs_val = 0 60 | total_number_imgs_test = 0 61 | 62 | for (index, image) in tqdm(enumerate(imgs)): 63 | img = eval(image) 64 | word_freq_structure.update(img["html"]["structure"]["tokens"]) 65 | 66 | for cell in img["html"]["cells"]: 67 | if len(cell["tokens"]) == 0: # an empty cell 68 | cell["tokens"].append('') 69 | word_freq_cells.update(cell["tokens"]) 70 | 71 | captions_structure = [] 72 | caption_cells = [] 73 | path = os.path.join("{}/{}".format(image_folder, 74 | img["split"]), img['filename']) 75 | 76 | captions_structure.append(img["html"]["structure"]['tokens']) # List 77 | 78 | if img["split"] == "train" and total_number_imgs_train < max_number_imgs_train: 79 | if len(img["html"]["structure"]["tokens"]) <= max_len_token_structure: 80 | # img_pic = imread(path) 81 | # if img_pic.shape[0] <= image_size and img_pic.shape[1] <= image_size: 82 | for cell in img["html"]["cells"]: 83 | caption_cells.append(cell["tokens"][:max_len_token_cell]) 84 | train_image_captions_structure.append(captions_structure) # List[List] 85 | train_image_captions_cells.append(caption_cells) # List[List[List]] 86 | train_image_paths.append(path) 87 | total_number_imgs_train += 1 88 | else: 89 | continue 90 | elif img["split"] == "val" and total_number_imgs_val < max_number_imgs_val: 91 | if len(img["html"]["structure"]["tokens"]) <= max_len_token_structure: 92 | for cell in img["html"]["cells"]: 93 | caption_cells.append(cell["tokens"][:max_len_token_cell]) 94 | valid_image_captions_structure.append(captions_structure) 95 | valid_image_captions_cells.append(caption_cells) 96 | valid_image_paths.append(path) 97 | total_number_imgs_val += 1 98 | elif img["split"] == "test": 99 | test_image_captions_structure.append(captions_structure) 100 | test_image_captions_cells.append(caption_cells) 101 | test_image_paths.append(path) 102 | total_number_imgs_test += 1 103 | else: 104 | continue 105 | print("Total number imgs for train: ", total_number_imgs_train) 106 | print("Total number imgs for val: ", total_number_imgs_val) 107 | print("Total number imgs for test: ", total_number_imgs_test) 108 | 109 | # create vocabluary structure 110 | words_structure = [w for w in word_freq_structure.keys()] 111 | word_map_structure = {k: v + 1 for v, k in enumerate(words_structure)} 112 | word_map_structure[''] = len(word_map_structure) + 1 113 | word_map_structure[''] = len(word_map_structure) + 1 114 | word_map_structure[''] = 
len(word_map_structure) + 1 115 | word_map_structure[''] = 0 116 | 117 | # create vocabluary cells 118 | words_cell = [w for w in word_freq_cells.keys()] 119 | word_map_cell = {k: v + 1 for v, k in enumerate(words_cell)} 120 | word_map_cell[''] = len(word_map_cell) + 1 121 | word_map_cell[''] = len(word_map_cell) + 1 122 | word_map_cell[''] = len(word_map_cell) + 1 123 | word_map_cell[''] = 0 124 | 125 | # save vocabluary to json 126 | with open(os.path.join(output_folder, 'WORDMAP_' + "STRUCTURE" + '.json'), 'w') as j: 127 | json.dump(word_map_structure, j) 128 | 129 | with open(os.path.join(output_folder, 'WORDMAP_' + "CELL" + '.json'), 'w') as j: 130 | json.dump(word_map_cell, j) 131 | 132 | for impaths, imcaps_structure, imcaps_cell, split in [(train_image_paths, 133 | train_image_captions_structure, 134 | train_image_captions_cells, 135 | 'train'), 136 | (valid_image_paths, 137 | valid_image_captions_structure, 138 | valid_image_captions_cells, 139 | 'val'), 140 | (test_image_paths, 141 | test_image_captions_structure, 142 | test_image_captions_cells, 143 | 'test')]: 144 | 145 | if len(imcaps_structure) == 0 and split in ['train', 'val']: 146 | continue 147 | 148 | with open(os.path.join(output_folder, split + '_IMAGE_PATHS.txt'), 'a') as f: 149 | print("\nReading %s images and captions, storing to file...\n" % split) 150 | enc_captions_structure = [] 151 | enc_captions_cells = [] 152 | cap_structure_len = [] 153 | cap_cell_len = [] 154 | number_cell_per_images = [] 155 | for i, path in enumerate(tqdm(impaths)): 156 | captions_structure = imcaps_structure[i] 157 | captions_cell = imcaps_cell[i] 158 | f.write(impaths[i]+'\n') 159 | 160 | # encode caption cell and structure 161 | for j, c in enumerate(captions_structure): 162 | enc_c = [word_map_structure['']] + \ 163 | [word_map_structure.get(word, word_map_structure['']) for word in c] + \ 164 | [word_map_structure['']] 165 | c_len = len(c) + 2 166 | enc_captions_structure.append(enc_c) 167 | cap_structure_len.append(c_len) 168 | 169 | # for each img have many cell captions 170 | each_enc_captions_cell = [] 171 | each_cap_cell_len = [] 172 | for j, c in enumerate(captions_cell): 173 | enc_c = [word_map_cell['']] + \ 174 | [word_map_cell.get(word, word_map_cell['']) for word in c] + \ 175 | [word_map_cell['']] 176 | c_len = len(c) + 2 177 | each_enc_captions_cell.append(enc_c) 178 | each_cap_cell_len.append(c_len) 179 | 180 | # save encoding cell in per image 181 | enc_captions_cells.append(each_enc_captions_cell) 182 | cap_cell_len.append(each_cap_cell_len) 183 | number_cell_per_images.append(len(captions_cell)) 184 | if split == 'train' or split == 'val': 185 | with open(os.path.join(output_folder, split + '_CAPTIONS_STRUCTURE' + '.json'), 'w') as j: 186 | json.dump(enc_captions_structure, j) 187 | with open(os.path.join(output_folder, split + '_CAPLENS_STRUCTURE' + '.json'), 'w') as j: 188 | json.dump(cap_structure_len, j) 189 | with open(os.path.join(output_folder, split + '_CAPTIONS_CELL' + '.json'), 'w') as j: 190 | json.dump(enc_captions_cells, j) 191 | with open(os.path.join(output_folder, split + '_CAPLENS_CELL' + '.json'), 'w') as j: 192 | json.dump(cap_cell_len, j) 193 | with open(os.path.join(output_folder, split + '_NUMBER_CELLS_PER_IMAGE' + '.json'), 'w') as j: 194 | json.dump(number_cell_per_images, j) 195 | 196 | 197 | def id_to_word(vocabluary): 198 | id2word = {value: key for key, value in vocabluary.items()} 199 | return id2word 200 | 201 | 202 | def init_embedding(embeddings): 203 | """ 204 | Fills embedding tensor 
with values from the uniform distribution. 205 | 206 | :param embeddings: embedding tensor 207 | """ 208 | bias = np.sqrt(3.0 / embeddings.size(1)) 209 | torch.nn.init.uniform_(embeddings, -bias, bias) 210 | 211 | 212 | def load_embeddings(emb_file, word_map): 213 | """ 214 | Creates an embedding tensor for the specified word map, for loading into the model. 215 | 216 | :param emb_file: file containing embeddings (stored in GloVe format) 217 | :param word_map: word map 218 | :return: embeddings in the same order as the words in the word map, dimension of embeddings 219 | """ 220 | 221 | # Find embedding dimension 222 | with open(emb_file, 'r') as f: 223 | emb_dim = len(f.readline().split(' ')) - 1 224 | 225 | vocab = set(word_map.keys()) 226 | 227 | # Create tensor to hold embeddings, initialize 228 | embeddings = torch.FloatTensor(len(vocab), emb_dim) 229 | init_embedding(embeddings) 230 | 231 | # Read embedding file 232 | print("\nLoading embeddings...") 233 | for line in open(emb_file, 'r'): 234 | line = line.split(' ') 235 | 236 | emb_word = line[0] 237 | embedding = list(map(lambda t: float(t), filter( 238 | lambda n: n and not n.isspace(), line[1:]))) 239 | 240 | # Ignore word if not in train_vocab 241 | if emb_word not in vocab: 242 | continue 243 | 244 | embeddings[word_map[emb_word]] = torch.FloatTensor(embedding) 245 | 246 | return embeddings, emb_dim 247 | 248 | 249 | def clip_gradient(optimizer, grad_clip): 250 | """ 251 | Clips gradients computed during backpropagation to avoid explosion of gradients. 252 | 253 | :param optimizer: optimizer with the gradients to be clipped 254 | :param grad_clip: clip value 255 | """ 256 | for group in optimizer.param_groups: 257 | for param in group['params']: 258 | if param.grad is not None: 259 | if param.grad.data.max().item() > grad_clip or param.grad.data.min().item() < -grad_clip: 260 | print("Clip gradient......") 261 | param.grad.data.clamp_(-grad_clip, grad_clip) 262 | 263 | 264 | def save_tmp_grad(optimizer, filename): 265 | save_list = [] 266 | for group in optimizer.param_groups: 267 | # torch.save([(p, p.grad.data) for p in group['params'] if p.grad is not None], filename) 268 | 269 | # torch.save([p for p in model.parameters() if p.requires_grad], "model_"+filename) 270 | for param in group['params']: 271 | if param.grad is not None: 272 | print(param.name, param.grad.data.size()) 273 | save_list.append(param.grad.data.cpu().numpy()) 274 | # torch.save(save_list, filename) 275 | np.save(filename, np.array(save_list)) 276 | 277 | 278 | def save_checkpoint(epoch, epochs_since_improvement, encoder, decoder_structure, decoder_cell, 279 | encoder_optimizer, decoder_structure_optimizer, decoder_cell_optimizer, recent_ted_score, is_best): 280 | """ 281 | Saves model checkpoint. 282 | 283 | :param data_name: base name of processed dataset 284 | :param epoch: epoch number 285 | :param epochs_since_improvement: number of epochs since last improvement in BLEU-4 score 286 | :param encoder: encoder model 287 | :param decoder: decoder model 288 | :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning 289 | :param decoder_optimizer: optimizer to update decoder's weights 290 | :param recent_ted_score: validation TED score for this epoch 291 | :param is_best: is this checkpoint the best so far? 
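    Unlike save_on_master further below (which stores state_dicts), this helper saves
    the model and optimizer objects themselves to 'checkpoint_table.pth.tar', plus a
    'BEST_'-prefixed copy when is_best is True. Such a checkpoint can be reloaded with,
    for example:

        checkpoint = torch.load('BEST_checkpoint_table.pth.tar')
        encoder = checkpoint['encoder']
        decoder_structure = checkpoint['decoder_structure']
        decoder_cell = checkpoint['decoder_cell']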
292 | """ 293 | state = {'epoch': epoch, 294 | 'epochs_since_improvement': epochs_since_improvement, 295 | 'ted_score': recent_ted_score, 296 | 'encoder': encoder, 297 | 'decoder_structure': decoder_structure, 298 | 'encoder_optimizer': encoder_optimizer, 299 | 'decoder_structure_optimizer': decoder_structure_optimizer, 300 | 'decoder_cell': decoder_cell, 301 | 'decoder_cell_optimizer': decoder_cell_optimizer, 302 | } 303 | filename = 'checkpoint_table' + '.pth.tar' 304 | torch.save(state, filename) 305 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint 306 | if is_best: 307 | torch.save(state, 'BEST_' + filename) 308 | 309 | 310 | def create_html(html_code): 311 | return ''' 312 | 313 | 314 | 320 | 321 | 322 | 323 | %s 324 |
325 | 326 | ''' % html_code 327 | 328 | 329 | class AverageMeter(object): 330 | """ 331 | Keeps track of most recent, average, sum, and count of a metric. 332 | """ 333 | 334 | def __init__(self): 335 | self.reset() 336 | 337 | def reset(self): 338 | self.val = 0 339 | self.avg = 0 340 | self.sum = 0 341 | self.count = 0 342 | 343 | def update(self, val, n=1): 344 | self.val = val 345 | self.sum += val * n 346 | self.count += n 347 | self.avg = self.sum / self.count 348 | 349 | 350 | def adjust_learning_rate(optimizer, lr): 351 | """ 352 | Shrinks learning rate by a specified factor. 353 | 354 | :param optimizer: optimizer whose learning rate must be shrunk. 355 | :param lr: new lr. 356 | """ 357 | 358 | # print("\nDECAYING learning rate.") 359 | for param_group in optimizer.param_groups: 360 | param_group['lr'] = lr 361 | # optimizer.state_dict()['param_groups'][0]['lr'] = lr 362 | # print("The new learning rate is %f\n" % (optimizer.param_groups[0]['lr'],)) 363 | 364 | 365 | def accuracy(output, target, topk=(1,)): 366 | """ 367 | Computes the accuracy over the k top predictions for the specified values of k 368 | 369 | :return List[Tensor]: in order to all reduce 370 | """ 371 | with torch.no_grad(): 372 | maxk = max(topk) 373 | batch_size = target.size(0) 374 | 375 | _, pred = output.topk(maxk, 1, True, True) 376 | pred = pred.t() 377 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 378 | 379 | res = [] 380 | for k in topk: 381 | correct_k = correct[:k].reshape(-1).float().sum() 382 | res.append(correct_k * (100.0 / batch_size)) 383 | return res 384 | 385 | 386 | def format_html(html): 387 | ''' Formats HTML code from tokenized annotation of img 388 | ''' 389 | 390 | html_code = ''' 391 | 392 | 393 | 399 | 400 | 401 | 402 | %s 403 |
404 | 405 | ''' % html 406 | 407 | # prettify the html 408 | soup = bs(html_code) 409 | html_code = soup.prettify() 410 | return html_code 411 | 412 | 413 | def convertId2wordSentence(id2word, idwords): 414 | words = [id2word[idword] for idword in idwords] 415 | words = [word for word in words if word != "" and word != "" and word != ""] 416 | words = "".join(words) 417 | return words 418 | 419 | 420 | def mean_loss(loss_list): 421 | loss_tensor = torch.stack(loss_list) 422 | loss_tensor_mean = torch.mean(loss_tensor) 423 | return loss_tensor_mean 424 | 425 | 426 | def is_dist_avail_and_initialized(): 427 | if not dist.is_available(): 428 | return False 429 | if not dist.is_initialized(): 430 | return False 431 | return True 432 | 433 | 434 | def get_world_size(): 435 | if not is_dist_avail_and_initialized(): 436 | return 1 437 | return dist.get_world_size() 438 | 439 | 440 | def get_rank(): 441 | if not is_dist_avail_and_initialized(): 442 | return 0 443 | return dist.get_rank() 444 | 445 | 446 | def is_main_process(): 447 | return get_rank() == 0 448 | 449 | 450 | def save_on_master(epoch, epochs_since_improvement, model, 451 | optimizer, recent_ted_score, is_best, filename): 452 | """ 453 | Saves model checkpoint. 454 | 455 | :param data_name: base name of processed dataset 456 | :param epoch: epoch number 457 | :param epochs_since_improvement: number of epochs since last improvement in BLEU-4 score 458 | :param encoder: encoder model 459 | :param decoder: decoder model 460 | :param encoder_optimizer: optimizer to update encoder's weights, if fine-tuning 461 | :param decoder_optimizer: optimizer to update decoder's weights 462 | :param recent_ted_score: validation TED score for this epoch 463 | :param is_best: is this checkpoint the best so far? 464 | :param filename: checkpoint's name 465 | """ 466 | if is_main_process(): 467 | state = {'epoch': epoch, 468 | 'epochs_since_improvement': epochs_since_improvement, 469 | 'ted_score': recent_ted_score, 470 | 'model': model.module.state_dict(), 471 | 'optimizer': optimizer.state_dict(), 472 | } 473 | torch.save(state, filename) 474 | # If this checkpoint is the best so far, store a copy so it doesn't get overwritten by a worse checkpoint 475 | if is_best: 476 | best_filename = filename.split('/') 477 | best_filename[-1] = 'BEST_' + best_filename[-1] 478 | best_filename = '/'.join(best_filename) 479 | torch.save(state, best_filename) 480 | 481 | 482 | def reduce_dict(input_dict, average=True): 483 | """ 484 | Args: 485 | input_dict (dict): all the values will be reduced 486 | average (bool): whether to do average or sum 487 | Reduce the values in the dictionary from all processes so that all processes 488 | have the averaged results. Returns a dict with the same fields as 489 | input_dict, after reduction. 
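    Typical use is averaging per-rank scalars for logging, e.g. (the keys and loss
    tensors are illustrative):

        loss_dict = {'loss_structure': loss_s.detach(), 'loss_cell': loss_c.detach()}
        loss_dict_reduced = reduce_dict(loss_dict)  # no-op on a single process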
490 | """ 491 | world_size = get_world_size() 492 | if world_size < 2: 493 | return input_dict 494 | with torch.no_grad(): 495 | names = [] 496 | values = [] 497 | # sort the keys so that they are consistent across processes 498 | for k in sorted(input_dict.keys()): 499 | names.append(k) 500 | values.append(input_dict[k]) 501 | values = torch.stack(values, dim=0) 502 | dist.all_reduce(values) 503 | if average: 504 | values /= world_size 505 | reduced_dict = {k: v for k, v in zip(names, values)} 506 | return reduced_dict 507 | 508 | 509 | def set_random_seeds(random_seed=0): 510 | torch.manual_seed(random_seed) 511 | torch.backends.cudnn.deterministic = True 512 | torch.backends.cudnn.benchmark = False 513 | np.random.seed(random_seed) 514 | random.seed(random_seed) 515 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import json 6 | import torchvision.transforms as transforms 7 | import skimage.transform 8 | import argparse 9 | from scipy.misc import imread, imresize 10 | from PIL import Image 11 | from utils import * 12 | from models import EDD 13 | from metric.metric_score import TEDS 14 | from tqdm import tqdm 15 | from collections import OrderedDict 16 | import pandas as pd 17 | 18 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 19 | 20 | 21 | def encoderImage(encoder, image_path, image_size): 22 | img = imread(image_path) 23 | if len(img.shape) == 2: 24 | img = img[:, :, np.newaxis] 25 | img = np.concatenate([img, img, img], axis=2) 26 | img = imresize(img, (image_size, image_size)) 27 | img = img.transpose(2, 0, 1) 28 | img = img / 255. 
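# NOTE: scipy.misc.imread / imresize used above are deprecated and were removed from
# modern SciPy releases, which is why an old SciPy is required here. A minimal,
# equivalent loader built on Pillow (Image is already imported in this file) could
# look like the sketch below; `load_rgb_square` is an illustrative name, not part of
# this repo:
#
#   def load_rgb_square(image_path, image_size):
#       """Read an image as RGB, resize to a square, return CHW floats in [0, 1]."""
#       img = Image.open(image_path).convert('RGB')
#       img = img.resize((image_size, image_size), Image.BILINEAR)
#       return np.asarray(img, dtype=np.float32).transpose(2, 0, 1) / 255.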
29 | img = torch.FloatTensor(img).to(device) 30 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 31 | std=[0.229, 0.224, 0.225]) 32 | transform = transforms.Compose([normalize]) 33 | image = transform(img) # 34 | 35 | # Encode 36 | image = image.unsqueeze(0) # (1, 3, img_size, img_size) 37 | # (1, enc_image_size, enc_image_size, encoder_dim) 38 | encoder_out_structure, encoder_out_cell = encoder(image) 39 | return encoder_out_structure, encoder_out_cell 40 | 41 | 42 | def structure_image_beam_search(encoder_out_structure, decoder, word_map, structure_weight=None, 43 | beam_size=3, max_seq_len=300, rank_method='sum', T=0.6): 44 | is_overflow = False 45 | k = beam_size 46 | vocab_size = len(word_map) 47 | decoder_structure_dim = 256 48 | # Read image and process 49 | encoder_dim = encoder_out_structure.size(3) # 512 50 | 51 | # Flatten encoding 52 | # (1, num_pixels, encoder_dim) 53 | encoder_out_structure = encoder_out_structure.view(1, -1, encoder_dim) 54 | num_pixels = encoder_out_structure.size(1) 55 | 56 | # We'll treat the problem as having a batch size of k 57 | # (k, num_pixels, encoder_dim) 58 | encoder_out_structure = encoder_out_structure.expand(k, num_pixels, encoder_dim) 59 | 60 | # Tensor to store top k previous words at each step decode structure; construct just 61 | k_prev_words = torch.LongTensor( 62 | [[word_map['']]] * k).to(device) # (k, 1) 63 | 64 | # Tensor to store top k sequences; now they're just 65 | seqs = k_prev_words # (k, 1) 66 | 67 | # Tensor to store top k sequences' scores; now they're just 0 68 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1) 69 | 70 | # tensor save hidden state and after filter to choice hidden state to pass cell decoder 71 | seqs_hidden_states = torch.zeros(k, 1, decoder_structure_dim).to(device) 72 | 73 | # Lists to store completed sequences, their alphas and scores, hidden 74 | complete_seqs = list() 75 | complete_seqs_scores = list() 76 | complete_seqs_hiddens = list() 77 | 78 | # start decoding 79 | step = 1 80 | h, c = decoder.init_hidden_state(encoder_out_structure) # h, c: (k, decoder_structure_dim) 81 | 82 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 83 | while True: 84 | embeddings = decoder.embedding( 85 | k_prev_words).squeeze(1) # (s, embed_dim) (s, 16) 86 | 87 | # (s, encoder_dim), (s, num_pixels) 88 | awe, alpha = decoder.attention(encoder_out_structure, h) 89 | 90 | # gating scalar, (s, encoder_dim) 91 | gate = decoder.sigmoid(decoder.f_beta(h)) 92 | awe = gate * awe 93 | 94 | h, c = decoder.decode_step( 95 | torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim) 96 | 97 | scores = decoder.fc(h) # (s, vocab_size) 98 | if structure_weight is not None: 99 | scores = scores * structure_weight 100 | scores = scores / T 101 | scores = F.log_softmax(scores, dim=1) 102 | 103 | if step == 1: 104 | # scores's shape: (1, vocab_size) 105 | top_k_scores, top_k_words = (top_k_scores+scores)[0].topk(k, 0, True, True) 106 | else: 107 | # scores's shape: (s, vocab_size) 108 | if rank_method == 'mean': 109 | top_k_scores, top_k_words = ((top_k_scores*(step-1)+scores)/step).view(-1).topk(k, 0, True, True) 110 | elif rank_method == 'sum': 111 | top_k_scores, top_k_words = (top_k_scores+scores).view(-1).topk(k, 0, True, True) 112 | else: 113 | RuntimeError("Invalid rank method: ", rank_method) 114 | 115 | # Convert unrolled indices to actual indices of scores 116 | prev_word_inds = top_k_words // vocab_size # (s) 117 | next_word_inds = top_k_words % 
vocab_size # (s) 118 | 119 | # Add new words to sequences, alphas, and hidden_state 120 | seqs = torch.cat( 121 | [seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 122 | 123 | if step == 1: 124 | seqs_hidden_states = h.unsqueeze(1) 125 | else: 126 | seqs_hidden_states = torch.cat( 127 | [seqs_hidden_states[prev_word_inds], h[prev_word_inds].unsqueeze(1)], dim=1) # (s, step+1, decoder_structure_dim) 128 | # Which sequences are incomplete (didn't reach )? 129 | 130 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if 131 | next_word != word_map['']] 132 | # Break if things have been going on too long 133 | if step > max_seq_len: 134 | incomplete_inds = [] 135 | is_overflow = True 136 | 137 | complete_inds = list( 138 | set(range(len(next_word_inds))) - set(incomplete_inds)) 139 | 140 | # Set aside complete sequences 141 | if len(complete_inds) > 0: 142 | complete_seqs.extend(seqs[complete_inds].tolist()) 143 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 144 | complete_seqs_hiddens.extend( 145 | seqs_hidden_states[complete_inds]) 146 | k -= len(complete_inds) # reduce beam length accordingly 147 | 148 | if k == 0: 149 | break 150 | seqs = seqs[incomplete_inds] 151 | seqs_hidden_states = seqs_hidden_states[incomplete_inds] 152 | h = h[prev_word_inds[incomplete_inds]] 153 | 154 | c = c[prev_word_inds[incomplete_inds]] 155 | encoder_out_structure = encoder_out_structure[prev_word_inds[incomplete_inds]] 156 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 157 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) 158 | 159 | step += 1 160 | 161 | max_score = max(complete_seqs_scores) 162 | i = complete_seqs_scores.index(max_score) 163 | seq = complete_seqs[i] 164 | hidden_states = complete_seqs_hiddens[i] 165 | 166 | return seq, hidden_states, is_overflow, max_score.cpu().numpy() 167 | 168 | 169 | def cell_image_beam_search(encoder_out, decoder, word_map, hidden_state_structure, 170 | beam_size=3., max_seq_len=100, rank_method='sum', T=0.6): 171 | is_overflow = False 172 | k = beam_size 173 | vocab_size = len(word_map) 174 | encoder_dim = encoder_out.size(3) 175 | decoder_structure_dim = 256 176 | 177 | # Flatten encoding 178 | # (1, num_pixels, encoder_dim) 179 | encoder_out = encoder_out.view(1, -1, encoder_dim) 180 | num_pixels = encoder_out.size(1) 181 | 182 | # We'll treat the problem as having a batch size of k 183 | # (k, num_pixels, encoder_dim) 184 | encoder_out = encoder_out.expand(k, num_pixels, encoder_dim) 185 | 186 | # Tensor to store top k previous words at each step decode structure; construct just 187 | k_prev_words = torch.LongTensor( 188 | [[word_map['']]] * k).to(device) # (k, 1) 189 | 190 | # Tensor to store top k sequences; now they're just 191 | seqs = k_prev_words # (k, 1) 192 | 193 | # Tensor to store top k sequences' scores; now they're just 0 194 | top_k_scores = torch.zeros(k, 1).to(device) # (k, 1) 195 | 196 | complete_seqs = list() 197 | complete_seqs_scores = list() 198 | step = 1 199 | h, c = decoder.init_hidden_state(encoder_out) 200 | 201 | # s is a number less than or equal to k, because sequences are removed from this process once they hit 202 | while True: 203 | embeddings = decoder.embedding( 204 | k_prev_words).squeeze(1) # (s, embed_dim) 205 | 206 | s = list(encoder_out.size())[0] 207 | hidden_state_structure_s = hidden_state_structure.expand( 208 | s, decoder_structure_dim) 209 | # (s, encoder_dim), (s, num_pixels) 210 | awe, alpha = decoder.attention(encoder_out, h, 
hidden_state_structure_s) 211 | 212 | # gating scalar, (s, encoder_dim) 213 | gate = decoder.sigmoid(decoder.f_beta(h)) 214 | awe = gate * awe 215 | 216 | h, c = decoder.decode_step( 217 | torch.cat([embeddings, awe], dim=1), (h, c)) # (s, decoder_dim) 218 | 219 | scores = decoder.fc(h) # (s, vocab_size) 220 | scores = scores / T 221 | scores = F.log_softmax(scores, dim=1) 222 | 223 | if step == 1: 224 | # scores's shape: (1, vocab_size) 225 | top_k_scores, top_k_words = (top_k_scores+scores)[0].topk(k, 0, True, True) 226 | else: 227 | # scores's shape: (s, vocab_size) 228 | if rank_method == 'mean': 229 | top_k_scores, top_k_words = ((top_k_scores*(step-1)+scores)/step).view(-1).topk(k, 0, True, True) 230 | elif rank_method == 'sum': 231 | top_k_scores, top_k_words = (top_k_scores+scores).view(-1).topk(k, 0, True, True) 232 | else: 233 | RuntimeError("Invalid rank method: ", rank_method) 234 | 235 | # Convert unrolled indices to actual indices of scores 236 | prev_word_inds = top_k_words // vocab_size # (s) 237 | next_word_inds = top_k_words % vocab_size # (s) 238 | 239 | # Add new words to sequences, alphas, and hidden_state 240 | seqs = torch.cat( 241 | [seqs[prev_word_inds], next_word_inds.unsqueeze(1)], dim=1) # (s, step+1) 242 | 243 | # Which sequences are incomplete (didn't reach )? 244 | incomplete_inds = [ind for ind, next_word in enumerate(next_word_inds) if 245 | next_word != word_map['']] 246 | 247 | # Break if things have been going on too long 248 | if step > max_seq_len: 249 | incomplete_inds = [] 250 | is_overflow = True 251 | 252 | complete_inds = list( 253 | set(range(len(next_word_inds))) - set(incomplete_inds)) 254 | 255 | # Set aside complete sequences 256 | if len(complete_inds) > 0: 257 | complete_seqs.extend(seqs[complete_inds].tolist()) 258 | complete_seqs_scores.extend(top_k_scores[complete_inds]) 259 | k -= len(complete_inds) # reduce beam length accordingly 260 | 261 | if k == 0: 262 | break 263 | seqs = seqs[incomplete_inds] 264 | h = h[prev_word_inds[incomplete_inds]] 265 | 266 | c = c[prev_word_inds[incomplete_inds]] 267 | encoder_out = encoder_out[prev_word_inds[incomplete_inds]] 268 | top_k_scores = top_k_scores[incomplete_inds].unsqueeze(1) 269 | k_prev_words = next_word_inds[incomplete_inds].unsqueeze(1) 270 | 271 | step += 1 272 | 273 | max_score = max(complete_seqs_scores) 274 | i = complete_seqs_scores.index(max_score) 275 | seq = complete_seqs[i] 276 | return seq, is_overflow, max_score.cpu().numpy() 277 | 278 | 279 | if __name__ == '__main__': 280 | parser = argparse.ArgumentParser( 281 | description='Show, Attend, and Tell - Tutorial - Generate Caption') 282 | 283 | parser.add_argument("--start_idx", type=int, default=0, 284 | help="start indices in all test images.") 285 | parser.add_argument("--offset", type=int, default=1000, 286 | help="we test start_idx: start_idx+offset in all test images.") 287 | 288 | parser.add_argument('--img', '-i', default='img.png', help='path to image') 289 | parser.add_argument('--model', '-m', help='path to model') 290 | parser.add_argument('--word_map_structure', '-wms', 291 | help='path to word map structure JSON') 292 | parser.add_argument('--word_map_cell', '-wmc', 293 | help='path to word map cell JSON') 294 | parser.add_argument('--beam_size_structure', '-bs', default=3, 295 | type=int, help='beam size for beam search') 296 | parser.add_argument('--beam_size_cell', '-bc', default=3, 297 | type=int, help='beam size for beam search') 298 | parser.add_argument("--max_seq_len_structure", type=int, default=300, 299 
| help="Maximal number of tokens generated by structure decoder") 300 | parser.add_argument("--max_seq_len_cell", type=int, default=100, 301 | help="Maximal number of tokens generated by cell decoder") 302 | parser.add_argument('--dont_smooth', dest='smooth', 303 | action='store_false', help='do not smooth alpha overlay') 304 | parser.add_argument("--img_from_val", action="store_true", 305 | help="image from validation set.") 306 | parser.add_argument("--data_folder", type=str, default='output_w_none_399k_memory_effi', 307 | help="Directory for dataset.") 308 | parser.add_argument("--split", type=str, default='val', 309 | help="evaluate part.") 310 | parser.add_argument("--all", action="store_true", 311 | help="All test samples.") 312 | parser.add_argument("--n_samples", type=int, default=100, 313 | help="Number of test samples.") 314 | parser.add_argument("--print_freq", type=int, default=1000, 315 | help="Print result.") 316 | parser.add_argument("--not_save", action="store_true", 317 | help="Save final results") 318 | parser.add_argument("--T", type=float, default=0.6, 319 | help="Temperature.") 320 | 321 | # Model setting 322 | parser.add_argument("--backbone", type=str, default='resnet18', 323 | help="The backbone of encoder") 324 | parser.add_argument("--EDD_type", type=str, default='S2S2', 325 | help="The type of EDD, choice in S1S1, S2S2") 326 | parser.add_argument("--emb_dim_structure", type=int, default=16, 327 | help="Dimension of word embeddings for structure token") 328 | parser.add_argument("--emb_dim_cell", type=int, default=80, 329 | help="Dimension of word embeddings for cell token") 330 | parser.add_argument("--attention_dim", type=int, default=512, 331 | help="Dimension of attention linear layers") 332 | parser.add_argument("--decoder_dim_structure", type=int, default=256, 333 | help="Dimension of decoder RNN structure") 334 | parser.add_argument("--decoder_dim_cell", type=int, default=512, 335 | help="Dimension of decoder RNN cell") 336 | parser.add_argument("--fp16", action="store_true", 337 | help="Model with FP16.") 338 | parser.add_argument("--image_size", type=int, default=448, 339 | help="Different image's height and width for different backbone.") 340 | parser.add_argument("--rank_method", type=str, default='mean', 341 | help="The method of rank beam search, choosing in 'mean' and 'sum'.") 342 | 343 | args = parser.parse_args() 344 | teds = TEDS(n_jobs=8) 345 | split = args.split 346 | 347 | # Load image path 348 | with open(os.path.join(args.data_folder, split + '_IMAGE_PATHS.txt'), 'r') as f: 349 | img_paths = [] 350 | for line in f: 351 | img_paths.append(line.strip()) 352 | print("Split: %s, number of images: %d" % (split, len(img_paths))) 353 | 354 | if split == 'val': 355 | # Load encoded captions structure 356 | with open(os.path.join(args.data_folder, split + '_CAPTIONS_STRUCTURE' + '.json'), 'r') as j: 357 | captions_structure = json.load(j) 358 | 359 | # Load caption structure length (completely into memory) 360 | with open(os.path.join(args.data_folder, split + '_CAPLENS_STRUCTURE' + '.json'), 'r') as j: 361 | caplens_structure = json.load(j) 362 | 363 | # Load encoded captions cell 364 | with open(os.path.join(args.data_folder, split + '_CAPTIONS_CELL' + '.json'), 'r') as j: 365 | captions_cell = json.load(j) 366 | # Load caption cell length 367 | with open(os.path.join(args.data_folder, split + '_CAPLENS_CELL' + '.json'), 'r') as j: 368 | caplens_cell = json.load(j) 369 | with open(os.path.join(args.data_folder, split + 
"_NUMBER_CELLS_PER_IMAGE.json"), "r") as j: 370 | number_cell_per_images = json.load(j) 371 | 372 | with open(args.word_map_structure, 'r') as j: 373 | word_map_structure = json.load(j) 374 | with open(args.word_map_cell, "r") as j: 375 | word_map_cell = json.load(j) 376 | id2word_stucture = id_to_word(word_map_structure) 377 | id2word_cell = id_to_word(word_map_cell) 378 | 379 | # Load model 380 | criterion = torch.nn.CrossEntropyLoss(reduction='mean') 381 | 382 | if args.EDD_type == 'S1S1': 383 | encoded_image_size = args.image_size // 16 384 | last_conv_stride = 1 385 | elif args.EDD_type == 'S2S2': 386 | encoded_image_size = args.image_size // 32 387 | last_conv_stride = 2 388 | 389 | model = EDD(encoded_image_size=encoded_image_size, 390 | encoder_dim=512, 391 | pretrained=False, 392 | structure_attention_dim=args.attention_dim, 393 | structure_embed_dim=args.emb_dim_structure, 394 | structure_decoder_dim=args.decoder_dim_structure, 395 | structure_dropout=0., 396 | structure_vocab=word_map_structure, 397 | cell_attention_dim=args.attention_dim, 398 | cell_embed_dim=args.emb_dim_cell, 399 | cell_decoder_dim=args.decoder_dim_cell, 400 | cell_dropout=0., 401 | cell_vocab=word_map_cell, 402 | criterion_structure=criterion, 403 | criterion_cell=criterion, 404 | alpha_c=1., 405 | id2word_structure=id2word_stucture, 406 | id2word_cell=id2word_cell, 407 | last_conv_stride=last_conv_stride, 408 | lstm_bias=True, 409 | backbone=args.backbone) 410 | model = model.to(device) 411 | 412 | checkpoint = torch.load(args.model, map_location=str(device)) 413 | 414 | # try: 415 | # model.load_state_dict(checkpoint["model"]) 416 | # except: 417 | # reform_checkpoint = OrderedDict() 418 | # for k, v in checkpoint["model"].items(): 419 | # new_k = k 420 | # if 'resnet' in k: 421 | # new_k = k.replace('resnet', 'backbone_net') 422 | # reform_checkpoint[new_k] = v 423 | # model.load_state_dict(reform_checkpoint) 424 | model.load_state_dict(checkpoint["model"], strict=False) 425 | #structure_weight = pd.read_csv('structure_class_weight.csv').values[:, -1].astype(np.float32) 426 | #structure_weight = (torch.cuda.FloatTensor(structure_weight)-1)/20.+ 1. 
427 | structure_weight = None 428 | model.eval() 429 | 430 | encoder = model.encoder 431 | decoder_structure = model.decoder_structure 432 | decoder_cell = model.decoder_cell 433 | encoder.eval() 434 | decoder_cell.eval() 435 | decoder_structure.eval() 436 | 437 | pred_html_only_structures = [] 438 | gt_html_only_structures = [] 439 | pred_html_alls = [] 440 | gt_html_alls = [] 441 | skipped_idx = set() 442 | test_img_paths = [] 443 | 444 | max_score_structure_save = [] 445 | max_score_cell_mean_save = [] 446 | 447 | if args.all: 448 | n_samples = len(img_paths) 449 | else: 450 | n_samples = args.n_samples 451 | for img_idx, img_path in tqdm(enumerate(img_paths[args.start_idx:args.start_idx + args.offset])): 452 | img_index = img_idx + args.start_idx 453 | args.img = img_paths[img_index] 454 | test_img_paths.append(img_paths[img_index]) 455 | with torch.no_grad(): 456 | encoder_out_structure, encoder_out_cell = encoderImage(encoder, args.img, args.image_size) 457 | 458 | seq, hidden_states, is_overflow_structure, max_score_structure = structure_image_beam_search( 459 | encoder_out_structure, decoder_structure, word_map_structure, 460 | beam_size=args.beam_size_structure, 461 | max_seq_len=args.max_seq_len_structure, 462 | rank_method=args.rank_method, 463 | T=args.T, 464 | structure_weight=structure_weight) 465 | if is_overflow_structure: 466 | print("skip {0}, length of generated structure's token larger than {1}.".format( 467 | img_index, args.max_seq_len_structure)) 468 | skipped_idx.add(img_index) 469 | 470 | cells = [] 471 | max_score_cell_list = [] 472 | html = "" 473 | html_only_structure = "" 474 | is_overflow_cell = False 475 | for index, s in enumerate(seq[1:-1]): # ignore and 476 | html += id2word_stucture[s] 477 | html_only_structure += id2word_stucture[s] 478 | if id2word_stucture[s] == "" or id2word_stucture[s] == ">": 479 | hidden_state_structure = hidden_states[index+1] 480 | seq_cell, is_overflow_cell, max_score_cell = cell_image_beam_search( 481 | encoder_out_cell, decoder_cell, word_map_cell, hidden_state_structure, 482 | beam_size=args.beam_size_cell, 483 | max_seq_len=args.max_seq_len_cell, 484 | rank_method=args.rank_method, 485 | T=args.T) 486 | max_score_cell_list.append(max_score_cell) 487 | if is_overflow_cell: 488 | print("skip {0}, length of generated cell's token larger than {1}.".format( 489 | img_index, args.max_seq_len_cell)) 490 | skipped_idx.add(img_index) 491 | # break 492 | html_cell = convertId2wordSentence(id2word_cell, seq_cell) 493 | html += html_cell 494 | 495 | if args.split == 'val': 496 | one_captions_structure = captions_structure[img_index] 497 | one_captions_cell = captions_cell[img_index] 498 | gt_html = "" 499 | gt_html_only_structure = "" 500 | cell_index = 0 501 | for s in one_captions_structure[1:-1]: # ignore and 502 | gt_html += id2word_stucture[s] 503 | gt_html_only_structure += id2word_stucture[s] 504 | if id2word_stucture[s] == "" or id2word_stucture[s] == ">": 505 | seq_cell = one_captions_cell[cell_index] 506 | html_cell = convertId2wordSentence(id2word_cell, seq_cell) 507 | gt_html += html_cell 508 | cell_index += 1 509 | gt_html_only_structures.append(create_html(gt_html_only_structure)) 510 | gt_html_alls.append(create_html(gt_html)) 511 | 512 | pred_html_only_structures.append(create_html(html_only_structure)) 513 | pred_html_alls.append(create_html(html)) 514 | 515 | 516 | max_score_cell_mean = np.mean(max_score_cell_list) 517 | max_score_structure_save.append(max_score_structure) 518 | 
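# `max_score_structure` and `max_score_cell_mean` are the beam-search scores of the
# chosen hypotheses: with rank_method='sum' a hypothesis is scored by its summed token
# log-probabilities (longer sequences score lower), with rank_method='mean' by a
# running mean (roughly length-normalised). A toy comparison with made-up numbers,
# purely illustrative:
#
#   token_logprobs = torch.tensor([-0.1, -0.2, -0.3, -0.4])
#   token_logprobs.sum()    # tensor(-1.0000): 'sum' ranking, penalises length
#   token_logprobs.mean()   # tensor(-0.2500): 'mean' ranking, comparable across lengths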
max_score_cell_mean_save.append(max_score_cell_mean) 519 | if img_idx % args.print_freq == 0: 520 | print("Index: ", img_index) 521 | print("#" * 80) 522 | print("Pred. only structure: ") 523 | print(html_only_structure) 524 | if args.split == 'val': 525 | print("GT. only structure: ") 526 | print(gt_html_only_structure) 527 | print("ted score only structure: ", teds.evaluate(create_html(html_only_structure), 528 | create_html(gt_html_only_structure))) 529 | print("#"*80) 530 | print("Pred html all: ") 531 | print(html) 532 | if args.split == 'val': 533 | print("GT html all: ") 534 | print(gt_html) 535 | print("ted score all: ", teds.evaluate(create_html(html), 536 | create_html(gt_html))) 537 | sys.stdout.flush() 538 | if args.split == 'val': 539 | score_only_structures = teds.batch_evaluate_html(pred_html_only_structures, gt_html_only_structures) 540 | score_alls = teds.batch_evaluate_html(pred_html_alls, gt_html_alls) 541 | 542 | print("TEDS score only structure: ", np.mean(score_only_structures)) 543 | print("TEDS score all: ", np.mean(score_alls)) 544 | print("Skipped indices are: ", list(skipped_idx)) 545 | 546 | if not args.not_save: 547 | if args.split == 'val': 548 | df = pd.DataFrame({ 549 | "img_path": test_img_paths, 550 | "only_structure_teds_score": score_only_structures, 551 | "teds_score": score_alls, 552 | "pred_structure_html": pred_html_only_structures, 553 | "gt_structure_html": gt_html_only_structures, 554 | "pred_html": pred_html_alls, 555 | "gt_html": gt_html_alls, 556 | "structure_mean_log_prob": max_score_structure_save, 557 | "cell_mean_log_prob": max_score_cell_mean_save}) 558 | else: 559 | df = pd.DataFrame({ 560 | "img_path": test_img_paths, 561 | "pred_structure_html": pred_html_only_structures, 562 | "pred_html": pred_html_alls, 563 | "structure_mean_log_prob": max_score_structure_save, 564 | "cell_mean_log_prob": max_score_cell_mean_save}) 565 | df.to_csv("%s_%s_top_%d_%s_results_SBS_%d_CBS_%d_startIdx_%d_offset_%d.csv" % (args.backbone, 566 | args.EDD_type, 567 | n_samples, 568 | split, 569 | args.beam_size_structure, 570 | args.beam_size_cell, 571 | args.start_idx, 572 | args.offset), 573 | index=False) 574 | 575 | 576 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import sys 4 | import random 5 | import torch.backends.cudnn as cudnn 6 | import torch.optim 7 | import torch.utils.data 8 | from torch.utils.data.distributed import DistributedSampler 9 | import torchvision.transforms as transforms 10 | from torch import nn 11 | from torch.nn.utils.rnn import pack_padded_sequence 12 | from models import EDD 13 | from dataset import * 14 | from utils import * 15 | from metric.metric_score import TEDS 16 | import numpy as np 17 | import pandas as pd 18 | 19 | import apex 20 | from apex import amp 21 | from apex.parallel import DistributedDataParallel as DDP 22 | from distributed import allreduce_params_opt 23 | from collections import OrderedDict 24 | import argparse 25 | 26 | def main(): 27 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 28 | parser.add_argument("--local_rank", type=int, 29 | help="Local rank. 
Necessary for using the torch.distributed.launch utility.") 30 | parser.add_argument('--dist_url', default='env://', 31 | help='url used to set up distributed training') 32 | # Datasets 33 | parser.add_argument("--data_folder", type=str, default='output_w_none_399k', 34 | help="Directory for dataset.") 35 | parser.add_argument("--max_len_token_structure", type=int, default=300, 36 | help="Don't sample captions structure longer than this length.") 37 | parser.add_argument("--max_len_token_cell", type=int, default=100, 38 | help="Don't sample captions cell longer than this length.") 39 | parser.add_argument("--image_size", type=int, default=448, 40 | help="Different image's height and width for different backbone.") 41 | # Training 42 | parser.add_argument("--num_epochs", type=int, default=13, 43 | help="Number of training epochs.") 44 | parser.add_argument("--batch_size", type=int, default=2, 45 | help="Training batch size for one process.") 46 | parser.add_argument("--num_workers", type=int, default=4, 47 | help="Number workers per process(GPU) to loading data.") 48 | parser.add_argument("--learning_rate", type=float, default=1e-3, 49 | help="Learning rate.") 50 | parser.add_argument("--structure_dropout", type=float, default=0.5, 51 | help="Dropout ratio of structure module.") 52 | parser.add_argument("--cell_dropout", type=float, default=0.2, 53 | help="Dropout ratio of cell module.") 54 | parser.add_argument("--grad_clip", type=float, 55 | help="Clip gradients at an absolute value of.") 56 | parser.add_argument("--alpha_c", type=float, default=1.0, 57 | help="Regularization parameter for 'doubly stochastic attention', as in the paper.") 58 | parser.add_argument("--random_seed", type=int, default=123, 59 | help="Random seed.") 60 | parser.add_argument("--print_freq", type=int, default=20, 61 | help="Print training/validation stats every __ batches.") 62 | parser.add_argument("--model_dir", type=str, default='checkpoints', 63 | help="Directory for saving models.") 64 | parser.add_argument("--model_filename", type=str, 65 | default='dist', 66 | help="Model filename.") 67 | parser.add_argument("--stage", type=str, default='structure', 68 | help="Choice in 'structure' and 'cell' ") 69 | parser.add_argument("--hyper_loss", type=float, default=1.0, 70 | help="when stage is structure, hyper_loss is 1.0. 
" 71 | "When stage is cell, hyper_loss is 0.5, as in the paper.") 72 | parser.add_argument("--resume", action="store_true", 73 | help="Resume training from saved checkpoint.") 74 | parser.add_argument("--pretrained_model_path", type=str, 75 | default=None, 76 | help="Resume training from saved checkpoint.") 77 | parser.add_argument("--first_epoch", type=int, default=10, 78 | help="Number of epoch in learning rate 1e-3.") 79 | parser.add_argument("--second_epoch", type=int, default=3, 80 | help="Number of epoch in learning rate 1e-4.") 81 | parser.add_argument("--p_structure", type=float, default=1.0, 82 | help="probability of using gt token to predict next token.") 83 | parser.add_argument("--p_cell", type=float, default=1.0, 84 | help="probability of using gt token to predict next token.") 85 | 86 | # Validation 87 | parser.add_argument("--only_val", action="store_true", 88 | help="Only validation.") 89 | 90 | # Model setting 91 | parser.add_argument("--backbone", type=str, default='resnet18', 92 | help="The backbone of encoder") 93 | parser.add_argument("--EDD_type", type=str, default='S1S1', 94 | help="The type of EDD, choice in S1S1, S2S2") 95 | parser.add_argument("--emb_dim_structure", type=int, default=16, 96 | help="Dimension of word embeddings for structure token") 97 | parser.add_argument("--emb_dim_cell", type=int, default=80, 98 | help="Dimension of word embeddings for cell token") 99 | parser.add_argument("--attention_dim", type=int, default=512, 100 | help="Dimension of attention linear layers") 101 | parser.add_argument("--decoder_dim_structure", type=int, default=256, 102 | help="Dimension of decoder RNN structure") 103 | parser.add_argument("--decoder_dim_cell", type=int, default=512, 104 | help="Dimension of decoder RNN cell") 105 | parser.add_argument("--fp16", action="store_true", 106 | help="Model with FP16.") 107 | args = parser.parse_args() 108 | 109 | 110 | # Create structure table and cell table 111 | word_map_structure_file = os.path.join( 112 | args.data_folder, "WORDMAP_STRUCTURE.json") 113 | word_map_cell_file = os.path.join(args.data_folder, "WORDMAP_CELL.json") 114 | with open(word_map_structure_file, "r") as j: 115 | word_map_structure = json.load(j) 116 | with open(word_map_cell_file, "r") as j: 117 | word_map_cell = json.load(j) 118 | id2word_structure = id_to_word(word_map_structure) 119 | id2word_cell = id_to_word(word_map_cell) 120 | 121 | # Init. 
distribution training setting 122 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 123 | args.local_rank = int(os.environ["RANK"]) 124 | args.world_size = int(os.environ['WORLD_SIZE']) 125 | random_seed = args.random_seed + args.local_rank 126 | set_random_seeds(random_seed=random_seed) 127 | torch.cuda.set_device(args.local_rank % torch.cuda.device_count()) # each node has same number of GPUs 128 | print('| distributed init (rank {}): {}'.format( 129 | args.local_rank, args.dist_url), flush=True) 130 | torch.distributed.init_process_group(backend="nccl", init_method=args.dist_url, 131 | world_size=args.world_size, rank=args.local_rank) 132 | torch.distributed.barrier() 133 | 134 | # Create metrics 135 | teds = TEDS(n_jobs=64//torch.cuda.device_count()) 136 | 137 | # Create save path 138 | save_folder = os.path.join(args.model_dir, args.backbone+'_'+args.EDD_type) 139 | args.save_folder = save_folder 140 | if args.local_rank == 0: 141 | if os.path.exists(save_folder) is False: 142 | os.makedirs(save_folder) 143 | torch.distributed.barrier() 144 | 145 | # Device ID in each node 146 | device_id = torch.cuda.current_device() 147 | device = torch.device("cuda:%s"%device_id) 148 | 149 | # Init. training 150 | start_epoch = 0 151 | # keeps track of number of epochs since there's been an improvement in validation 152 | epochs_since_improvement = 0 153 | best_TED = 0. # TED score right now 154 | logger_file = os.path.join(save_folder, 'logger.txt') 155 | 156 | # structure_weight = pd.read_csv('structure_class_weight.csv').values[:, -1].astype(np.float32) 157 | # structure_weight[structure_weight>1.] = structure_weight[structure_weight>1.]*2 158 | # print(structure_weight) 159 | # criterion_structure = nn.CrossEntropyLoss(reduction='mean', weight=torch.FloatTensor(structure_weight)) 160 | criterion_structure = nn.CrossEntropyLoss(reduction='mean') 161 | criterion_cell = nn.CrossEntropyLoss(reduction='mean') 162 | 163 | if args.EDD_type == 'S1S1': 164 | encoded_image_size = args.image_size // 16 165 | last_conv_stride = 1 166 | elif args.EDD_type == 'S2S2': 167 | encoded_image_size = args.image_size // 32 168 | last_conv_stride = 2 169 | 170 | model = EDD(encoded_image_size=encoded_image_size, # feature's size of last Conv. 171 | encoder_dim=512, # feature's channel of last Conv. 
172 | pretrained=False, # pretrained backbone network 173 | structure_attention_dim=args.attention_dim, 174 | structure_embed_dim=args.emb_dim_structure, 175 | structure_decoder_dim=args.decoder_dim_structure, 176 | structure_dropout=args.structure_dropout, 177 | structure_vocab=word_map_structure, 178 | cell_attention_dim=args.attention_dim, 179 | cell_embed_dim=args.emb_dim_cell, 180 | cell_decoder_dim=args.decoder_dim_cell, 181 | cell_dropout=args.cell_dropout, 182 | cell_vocab=word_map_cell, 183 | criterion_structure=criterion_structure, 184 | criterion_cell=criterion_cell, 185 | alpha_c=args.alpha_c, 186 | id2word_structure=id2word_structure, 187 | id2word_cell=id2word_cell, 188 | last_conv_stride=last_conv_stride, 189 | lstm_bias=True, # https://github.com/pytorch/pytorch/issues/42605 190 | backbone=args.backbone) 191 | 192 | # Move to GPU, if available 193 | model = apex.parallel.convert_syncbn_model(model) 194 | model = model.to(device) 195 | 196 | optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()), 197 | lr=args.learning_rate) 198 | 199 | model, optimizer = amp.initialize(model, optimizer, opt_level="O0") 200 | # https://github.com/pytorch/pytorch/issues/24005 201 | model = DDP(model, delay_allreduce=True) 202 | model._disable_allreduce = True 203 | 204 | if is_main_process(): 205 | count = 0 206 | model_name_list = [] 207 | for name, param in model.named_parameters(): 208 | print(name, param.requires_grad, param.data.size()) 209 | count += np.prod(np.array(param.data.size())) 210 | model_name_list.append(name) 211 | print(count) 212 | 213 | if args.pretrained_model_path is not None: 214 | checkpoint = torch.load(args.pretrained_model_path, 215 | map_location='cpu') 216 | model.module.load_state_dict(checkpoint["model"], strict=False) 217 | if is_main_process(): 218 | print("Load pretrained model from: ", args.pretrained_model_path) 219 | 220 | if args.resume: 221 | optimizer.load_state_dict(checkpoint["optimizer"]) 222 | epochs_since_improvement = checkpoint['epochs_since_improvement'] 223 | best_TED = checkpoint['ted_score'] 224 | start_epoch = checkpoint['epoch'] + 1 225 | if is_main_process(): 226 | print("Start epoch: %d, best TED score: %.4f"%(start_epoch, best_TED), flush=True) 227 | 228 | # Custom dataloaders 229 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 230 | std=[0.229, 0.224, 0.225]) 231 | 232 | print("Local Rank: {}, Loading train_loader and val_loader:".format(args.local_rank)) 233 | 234 | val_set = CaptionDataset( 235 | args.data_folder, 236 | 'val', 237 | transform=transforms.Compose([normalize]), 238 | max_len_token_structure=args.max_len_token_structure, 239 | max_len_token_cell=args.max_len_token_cell, 240 | image_size=args.image_size) 241 | val_sampler = DistributedSampler(dataset=val_set) 242 | val_loader = torch.utils.data.DataLoader( 243 | dataset=val_set, 244 | batch_size=args.batch_size, 245 | sampler=val_sampler, 246 | shuffle=False, 247 | num_workers=args.num_workers, 248 | pin_memory=True) 249 | 250 | if args.only_val: 251 | print("Local Rank: {}, Only Validation ...".format(args.local_rank)) 252 | recent_ted_score = val(val_loader=val_loader, 253 | model=model, 254 | device=device, 255 | args=args, 256 | teds=teds, 257 | logger_file=logger_file) 258 | return recent_ted_score 259 | 260 | train_set = CaptionDataset( 261 | args.data_folder, 262 | ['train'], 263 | transform=transforms.Compose([normalize]), 264 | max_len_token_structure=args.max_len_token_structure, 265 | 
max_len_token_cell=args.max_len_token_cell, 266 | image_size=args.image_size) 267 | train_sampler = DistributedSampler(dataset=train_set) 268 | train_loader = torch.utils.data.DataLoader( 269 | dataset=train_set, 270 | batch_size=args.batch_size, 271 | sampler=train_sampler, 272 | num_workers=args.num_workers, 273 | pin_memory=True) 274 | print("Local Rank: {}, Done train_loader and val_loader:".format(args.local_rank)) 275 | 276 | 277 | # Train for each epoch 278 | for epoch in range(start_epoch, args.num_epochs): 279 | train_sampler.set_epoch(epoch) 280 | # Structure 281 | if args.first_epoch <= epoch < args.first_epoch + args.second_epoch: 282 | adjust_learning_rate(optimizer, 0.1*args.learning_rate) 283 | elif args.first_epoch + args.second_epoch <= epoch < args.first_epoch*4 + args.second_epoch: # cell 284 | args.stage = 'cell' 285 | args.hyper_loss = 0.5 286 | best_TED = 0. 287 | adjust_learning_rate(optimizer, args.learning_rate) 288 | elif args.first_epoch*4 + args.second_epoch <= epoch < args.first_epoch*6 + args.second_epoch: 289 | adjust_learning_rate(optimizer, 0.5*args.learning_rate) 290 | elif args.first_epoch * 6 + args.second_epoch <= epoch < args.first_epoch * 8 + args.second_epoch: 291 | adjust_learning_rate(optimizer, 0.1 * args.learning_rate) 292 | elif args.first_epoch * 8 + args.second_epoch <= epoch < args.first_epoch * 9 + args.second_epoch: 293 | adjust_learning_rate(optimizer, 0.05*args.learning_rate) 294 | elif epoch >= args.first_epoch*9 + args.second_epoch: 295 | adjust_learning_rate(optimizer, 0.01*args.learning_rate) 296 | else: 297 | pass 298 | 299 | if is_main_process(): 300 | print("Epoch: {}, Stage: {}, hyper_loss: {}, lr: {} ...".format( 301 | epoch, 302 | args.stage, 303 | args.hyper_loss, 304 | optimizer.state_dict()['param_groups'][0]['lr'])) 305 | 306 | print("Local Rank: {}, Epoch: {}, Training ...".format(args.local_rank, epoch)) 307 | train(train_loader=train_loader, 308 | model=model, 309 | optimizer=optimizer, 310 | epoch=epoch, 311 | device=device, 312 | args=args, 313 | logger_file=logger_file) 314 | 315 | print("Local Rank: {}, Epoch: {}, Validation ...".format(args.local_rank, epoch)) 316 | recent_ted_score = val(val_loader=val_loader, 317 | model=model, 318 | device=device, 319 | args=args, 320 | teds=teds, 321 | logger_file=logger_file) 322 | 323 | # Check if there was an improvement 324 | is_best = recent_ted_score > best_TED 325 | best_TED = max(recent_ted_score, best_TED) 326 | if not is_best: 327 | epochs_since_improvement += 1 328 | if is_main_process(): 329 | print("\nEpochs since last improvement: %d\n" % 330 | (epochs_since_improvement,)) 331 | else: 332 | epochs_since_improvement = 0 333 | 334 | # save checkpoint 335 | filename = os.path.join(save_folder, 336 | args.model_filename + '_' + 337 | args.stage+ '_epoch_'+ str(epoch) 338 | + "_score_"+str(recent_ted_score)[:6] + '.pth.tar') 339 | save_on_master(epoch, epochs_since_improvement, model, 340 | optimizer, recent_ted_score, is_best, filename) 341 | 342 | 343 | def train(train_loader, model, optimizer, epoch, device, args, logger_file): 344 | 345 | model.train() 346 | # model.module.encoder_wrapper.eval() 347 | world_size = args.world_size 348 | 349 | batch_time = AverageMeter() # forward prop. + back prop. 
time 350 | data_time = AverageMeter() # data loading time 351 | losses = AverageMeter() # loss (per word decoded) 352 | losses_structure = AverageMeter() 353 | losses_cell = AverageMeter() 354 | 355 | top5accs_structure = AverageMeter() # top5 accuracy 356 | top1accs_structure = AverageMeter() # top1 accuracy 357 | top5accs_cell = AverageMeter() # top5 accuracy 358 | top1accs_cell = AverageMeter() # top1 accuracy 359 | 360 | start = time.time() 361 | if is_main_process(): 362 | print("length of train_loader: {}".format(len(train_loader))) 363 | for i, (imgs, 364 | caption_structures, 365 | caplen_structures, 366 | caption_cells, 367 | caplen_cells, 368 | number_cell_per_images) in enumerate(train_loader): 369 | if epoch == 0: 370 | adjust_learning_rate(optimizer, min((i/500.), 1.)*args.learning_rate) 371 | 372 | data_time.update(time.time() - start) 373 | 374 | imgs = imgs.to(device) 375 | caption_structures = caption_structures.to(device) 376 | caplen_structures = caplen_structures.to(device) 377 | caption_cells = [caption_cell.to(device) for caption_cell in caption_cells] 378 | caplen_cells = [caplen_cell.to(device) for caplen_cell in caplen_cells] 379 | number_cell_per_images = number_cell_per_images.to(device) 380 | 381 | loss_structures, scores_structure, targets_structure, \ 382 | loss_cells, total_scores_cells, total_target_cells = \ 383 | model(imgs, 384 | caption_structures, 385 | caplen_structures, 386 | caption_cells, 387 | caplen_cells, 388 | number_cell_per_images) 389 | 390 | # Total loss 391 | if args.stage == 'cell': 392 | loss = args.hyper_loss * loss_structures + (1-args.hyper_loss) * loss_cells 393 | elif args.stage == 'structure': 394 | loss = args.hyper_loss * loss_structures + (1 - args.hyper_loss) * loss_cells.detach() 395 | 396 | # Back prop. 397 | iters_to_accumulate = 1 398 | with amp.scale_loss(loss/iters_to_accumulate, optimizer) as scaled_loss: 399 | scaled_loss.backward() 400 | if i% iters_to_accumulate == 0: 401 | allreduce_params_opt(optimizer) 402 | if args.grad_clip is not None: 403 | torch.nn.utils.clip_grad_value_(amp.master_params(optimizer), args.grad_clip) # actually optimizer 404 | # Update weights 405 | optimizer.step() 406 | optimizer.zero_grad() 407 | 408 | batch_time.update(time.time() - start) 409 | start = time.time() 410 | 411 | # Print status 412 | if i % args.print_freq == 0: 413 | # Keep track of metrics 414 | # all reduce LOSS 415 | torch.distributed.all_reduce(loss_structures) 416 | torch.distributed.all_reduce(loss_cells) 417 | torch.distributed.all_reduce(loss) 418 | 419 | losses_structure.update(loss_structures.item()/world_size, 1) 420 | losses_cell.update(loss_cells.item() / world_size, 1) 421 | losses.update(loss.item() / world_size, 1) 422 | 423 | # STRUCTURE ACC. 424 | top1_structure, top5_structure = accuracy(scores_structure, targets_structure, (1, 5)) 425 | targets_structure_size = torch.LongTensor([targets_structure.size(0)]).squeeze(0).to(device) 426 | torch.distributed.all_reduce(top1_structure) 427 | torch.distributed.all_reduce(top5_structure) 428 | torch.distributed.all_reduce(targets_structure_size) 429 | top5accs_structure.update(top5_structure.item()/world_size, targets_structure_size.item()) 430 | top1accs_structure.update(top1_structure.item()/world_size, targets_structure_size.item()) 431 | 432 | # CELL ACC. 
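# The cell-accuracy block below mirrors the structure block just above: top-1/top-5
# accuracies are computed per rank, summed with all_reduce, and divided by world_size,
# so the logged value is the unweighted mean of per-rank accuracies (it equals the true
# global accuracy only when every rank sees the same number of target tokens). The
# reduction pattern, with `local_top1` as an illustrative placeholder:
#
#   acc = torch.tensor([local_top1], device=device)   # per-rank accuracy in percent
#   torch.distributed.all_reduce(acc)                 # sum over ranks
#   acc = acc / world_size                            # unweighted mean across ranks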
433 | top1_cell, top5_cell = accuracy(total_scores_cells, total_target_cells, (1, 5)) 434 | total_target_cells_size = torch.LongTensor([total_target_cells.size(0)]).squeeze(0).to(device) 435 | torch.distributed.all_reduce(top1_cell) 436 | torch.distributed.all_reduce(top5_cell) 437 | torch.distributed.all_reduce(total_target_cells_size) 438 | top5accs_cell.update(top5_cell.item()/world_size, total_target_cells_size.item()) 439 | top1accs_cell.update(top1_cell.item()/world_size, total_target_cells_size.item()) 440 | if is_main_process(): 441 | print('Epoch: [{0}][{1}/{2}]\t' 442 | 'lr: {lr:.8f}' 443 | 'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 444 | 'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t' 445 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 446 | 'Loss Stru {loss_s.val:.4f} ({loss_s.avg:.4f})\t' 447 | 'Loss Cell {loss_c.val:.4f} ({loss_c.avg:.4f})\t' 448 | 'Top-5 Stru. Acc. {top5.val:.3f} ({top5.avg:.3f})\t' 449 | 'Top-1 Stru. Acc. {top1.val:.3f} ({top1.avg:.3f})\t' 450 | 'Top-5 Cell Acc. {top5_s.val:.3f} ({top5_s.avg:.3f})\t' 451 | 'Top-1 Cell Acc. {top1_s.val:.3f} ({top1_s.avg:.3f})'.format( 452 | epoch, i, len(train_loader), 453 | lr=optimizer.param_groups[0]['lr'], 454 | batch_time=batch_time, 455 | data_time=data_time, loss=losses, 456 | loss_s=losses_structure, loss_c=losses_cell, 457 | top5=top5accs_structure, top1=top1accs_structure, 458 | top5_s=top5accs_cell, top1_s=top1accs_cell), flush=True) 459 | 460 | # if i % 10000 == 0: 461 | # filename = os.path.join(args.save_folder, 462 | # args.model_filename + '_' + 463 | # args.stage + '_epoch_' + str(epoch) 464 | # + "_iters_" + str(i) + '.pth.tar') 465 | # save_on_master(epoch, 1, model, 466 | # optimizer, 0.0, False, filename) 467 | 468 | if is_main_process(): 469 | log_str = ('Epoch: [{0}]\n' 470 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 471 | 'Loss Stru {loss_s.val:.4f} ({loss_s.avg:.4f})\t' 472 | 'Loss Cell {loss_c.val:.4f} ({loss_c.avg:.4f})\t' 473 | 'Top-5 Stru. Acc. {top5.val:.3f} ({top5.avg:.3f})\t' 474 | 'Top-1 Stru. Acc. {top1.val:.3f} ({top1.avg:.3f})\t' 475 | 'Top-5 Cell Acc. {top5_s.val:.3f} ({top5_s.avg:.3f})\t' 476 | 'Top-1 Cell Acc. 
{top1_s.val:.3f} ({top1_s.avg:.3f})'.format( 477 | epoch, 478 | loss=losses, 479 | loss_s=losses_structure, loss_c=losses_cell, 480 | top5=top5accs_structure, top1=top1accs_structure, 481 | top5_s=top5accs_cell, top1_s=top1accs_cell)) 482 | with open(logger_file, 'a') as f: 483 | f.write(log_str+'\n') 484 | 485 | 486 | 487 | def val(val_loader, model, device, args, teds, logger_file): 488 | model.eval() 489 | world_size = args.world_size 490 | 491 | total_loss_structure = list() 492 | total_loss_cell = list() 493 | total_loss = list() 494 | 495 | top5accs_structure = AverageMeter() # top5 accuracy 496 | top1accs_structure = AverageMeter() # top1 accuracy 497 | top5accs_cell = AverageMeter() # top5 accuracy 498 | top1accs_cell = AverageMeter() # top1 accuracy 499 | 500 | html_predict_only_structures = list() 501 | html_true_only_structures = list() 502 | html_predict_only_cells = list() 503 | html_predict_alls = list() 504 | html_trues = list() 505 | 506 | with torch.no_grad(): 507 | for it, (imgs, 508 | caption_structures, 509 | caplen_structures, 510 | caption_cells, 511 | caplen_cells, 512 | number_cell_per_images) in enumerate(val_loader): 513 | imgs = imgs.to(device) 514 | caption_structures = caption_structures.to(device) 515 | caplen_structures = caplen_structures.to(device) 516 | caption_cells = [caption_cell.to(device) for caption_cell in caption_cells] 517 | caplen_cells = [caplen_cell.to(device) for caplen_cell in caplen_cells] 518 | number_cell_per_images = number_cell_per_images.to(device) 519 | 520 | loss_structures, loss_cells, \ 521 | batch_html_predict_only_structures, \ 522 | batch_html_true_only_structures, \ 523 | batch_html_predict_only_cells, \ 524 | batch_html_predict_alls, \ 525 | batch_html_trues, \ 526 | scores_structure, targets_structure, \ 527 | total_scores_cells, total_target_cells = model(imgs, 528 | caption_structures, 529 | caplen_structures, 530 | caption_cells, 531 | caplen_cells, 532 | number_cell_per_images) 533 | 534 | loss = args.hyper_loss * loss_structures + \ 535 | (1-args.hyper_loss) * loss_cells 536 | 537 | total_loss_structure.append(loss_structures.cpu()) 538 | total_loss_cell.append(loss_cells.cpu()) 539 | total_loss.append(loss.cpu()) 540 | 541 | html_predict_only_structures.extend(batch_html_predict_only_structures) 542 | html_true_only_structures.extend(batch_html_true_only_structures) 543 | html_predict_only_cells.extend(batch_html_predict_only_cells) 544 | html_predict_alls.extend(batch_html_predict_alls) 545 | html_trues.extend(batch_html_trues) 546 | 547 | top1_structure, top5_structure = accuracy(scores_structure, targets_structure, (1, 5)) 548 | targets_structure_size = torch.LongTensor([targets_structure.size(0)]).squeeze(0).to(device) 549 | torch.distributed.all_reduce(top1_structure) 550 | torch.distributed.all_reduce(top5_structure) 551 | torch.distributed.all_reduce(targets_structure_size) 552 | top5accs_structure.update(top5_structure.item() / world_size, targets_structure_size.item()) 553 | top1accs_structure.update(top1_structure.item() / world_size, targets_structure_size.item()) 554 | 555 | # CELL ACC. 
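# As in train(), the block below all-reduces the cell top-1/top-5 accuracies, except
# that validation does this for every batch rather than only at print_freq. The
# per-batch losses are collected on the CPU lists above and averaged after the loop
# with mean_loss from utils.py, which simply stacks and means the tensors, e.g. with
# toy values:
#
#   mean_loss([torch.tensor(0.9), torch.tensor(1.1)])   # -> tensor(1.)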
556 | top1_cell, top5_cell = accuracy(total_scores_cells, total_target_cells, (1, 5)) 557 | total_target_cells_size = torch.LongTensor([total_target_cells.size(0)]).squeeze(0).to(device) 558 | torch.distributed.all_reduce(top1_cell) 559 | torch.distributed.all_reduce(top5_cell) 560 | torch.distributed.all_reduce(total_target_cells_size) 561 | top5accs_cell.update(top5_cell.item() / world_size, total_target_cells_size.item()) 562 | top1accs_cell.update(top1_cell.item() / world_size, total_target_cells_size.item()) 563 | if is_main_process(): 564 | print('it [{0}/{1}]\t' 565 | 'Top-5 Stru. Acc. {top5.val:.3f} ({top5.avg:.3f})\t' 566 | 'Top-1 Stru. Acc. {top1.val:.3f} ({top1.avg:.3f})\t' 567 | 'Top-5 Cell Acc. {top5_s.val:.3f} ({top5_s.avg:.3f})\t' 568 | 'Top-1 Cell Acc. {top1_s.val:.3f} ({top1_s.avg:.3f})'.format( 569 | it, len(val_loader), 570 | top5=top5accs_structure, top1=top1accs_structure, 571 | top5_s=top5accs_cell, top1_s=top1accs_cell), flush=True) 572 | 573 | # Average loss in local rank 574 | loss_structures = mean_loss(total_loss_structure).to(device) 575 | loss_cells = mean_loss(total_loss_cell).to(device) 576 | loss = mean_loss(total_loss).to(device) 577 | 578 | # Average loss in all rank 579 | torch.distributed.all_reduce(loss_structures) 580 | torch.distributed.all_reduce(loss_cells) 581 | torch.distributed.all_reduce(loss) 582 | 583 | # Calculate val. set size 584 | pred_local_rank_size = torch.LongTensor([len(html_predict_alls)]).squeeze(0).to(device) 585 | torch.distributed.all_reduce(pred_local_rank_size) 586 | 587 | # Calculate average TEDS scores in local rank 588 | scores_only_structure = teds.batch_evaluate_html( 589 | html_predict_only_structures, html_true_only_structures, is_tqdm=is_main_process()) 590 | 591 | scores_only_cell = teds.batch_evaluate_html( 592 | html_predict_only_cells, html_trues, is_tqdm=is_main_process()) 593 | 594 | scores_all = teds.batch_evaluate_html( 595 | html_predict_alls, html_trues, is_tqdm=is_main_process()) 596 | 597 | torch.distributed.barrier() 598 | if is_main_process(): 599 | for ii in range(3): 600 | print("#" * 80) 601 | print("index: ", ii) 602 | print("html_predict_only_structure: \n", html_predict_only_structures[ii]) 603 | print("html_true_only_structure: \n", html_true_only_structures[ii]) 604 | print('TEDS score only structure:', scores_only_structure[ii]) 605 | print("-" * 80) 606 | print("html_predict_only_cell: \n", html_predict_only_cells[ii]) 607 | print("html_true: \n", html_trues[ii]) 608 | print('TEDS score only cell:', scores_only_cell[ii]) 609 | print("-" * 80) 610 | print("html_predict_all: \n", html_predict_alls[ii]) 611 | print("html_true: \n", html_trues[ii]) 612 | print('TEDS score:', scores_all[ii]) 613 | sys.stdout.flush() 614 | torch.distributed.barrier() 615 | 616 | ted_score_structure = torch.FloatTensor([np.mean(scores_only_structure)]).squeeze(0).to(device) 617 | ted_score_cell = torch.FloatTensor([np.mean(scores_only_cell)]).squeeze(0).to(device) 618 | ted_score = torch.FloatTensor([np.mean(scores_all)]).squeeze(0).to(device) 619 | 620 | # Calculate average TEDS scores in all rank 621 | torch.distributed.all_reduce(ted_score_structure) 622 | torch.distributed.all_reduce(ted_score_cell) 623 | torch.distributed.all_reduce(ted_score) 624 | if is_main_process(): 625 | print_lines = [ 626 | "Eval set size: {}".format(pred_local_rank_size.item()), 627 | "LOSS_STRUCTURE: {} \nLOSS_CELL: {} \nLOSS_DUAL_DECODER: {}".format( 628 | loss_structures.item() / world_size, loss_cells.item() / world_size, loss.item() / 
world_size), 629 | "TED_SCORE_STRUCTURE: {}".format(ted_score_structure.item() / world_size), 630 | "TED_SCORE_CELL: {}".format(ted_score_cell.item() / world_size), 631 | "TED_SCORE: {}".format(ted_score.item() / world_size) 632 | ] 633 | with open(logger_file, 'a') as f: 634 | for line in print_lines: 635 | print(line, flush=True) 636 | f.write(line + '\n') 637 | f.write('\n') 638 | if args.stage == 'structure': 639 | return ted_score_structure.item() / world_size 640 | else: 641 | return ted_score.item() / world_size 642 | 643 | 644 | if __name__ == "__main__": 645 | main() 646 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from torch import nn 5 | import backbones 6 | import timm 7 | import copy 8 | from torch.nn.utils.rnn import pack_padded_sequence 9 | from torch.utils.checkpoint import checkpoint 10 | from utils import * 11 | 12 | USE_CHECKPOINT=True 13 | USE_CHECKPOINT2=True 14 | USE_CHECKPOINT3=True 15 | FPN=False 16 | 17 | class EDD(nn.Module): 18 | def __init__(self, 19 | encoded_image_size=14, 20 | encoder_dim=512, # encoder_dim is abandoned 21 | pretrained=True, 22 | structure_attention_dim=512, 23 | structure_embed_dim=16, 24 | structure_decoder_dim=256, 25 | structure_dropout=0.5, 26 | structure_vocab=None, 27 | cell_attention_dim=512, 28 | cell_embed_dim=80, 29 | cell_decoder_dim=512, 30 | cell_dropout=0.2, 31 | cell_vocab=None, 32 | criterion_structure=None, 33 | criterion_cell=None, 34 | alpha_c=1.0, 35 | id2word_structure=None, 36 | id2word_cell=None, 37 | last_conv_stride=2, 38 | lstm_bias=True, 39 | backbone='resnet18'): 40 | super(EDD, self).__init__() 41 | self.dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True) 42 | self.encoder = Encoder( 43 | backbone=backbone, 44 | encoded_image_size=encoded_image_size, 45 | encoder_dim=encoder_dim, 46 | pretrained=pretrained, 47 | last_stride=last_conv_stride) 48 | self.encoder_wrapper = ModuleWrapperIgnores2ndArg(self.encoder) 49 | 50 | self.decoder_structure = \ 51 | DecoderStructureWithAttention( 52 | attention_dim=structure_attention_dim, 53 | embed_dim=structure_embed_dim, 54 | decoder_dim=structure_decoder_dim, 55 | vocab=structure_vocab, 56 | dropout=structure_dropout, 57 | encoder_dim=self.encoder.encoder_dim, 58 | lstm_bias=lstm_bias) 59 | 60 | self.decoder_cell = \ 61 | DecoderCellPerImageWithAttention( 62 | attention_dim=cell_attention_dim, 63 | embed_dim=cell_embed_dim, 64 | decoder_dim=cell_decoder_dim, 65 | vocab_size=len(cell_vocab), 66 | dropout=cell_dropout, 67 | decoder_structure_dim=structure_decoder_dim, 68 | encoder_dim=self.encoder.encoder_dim, 69 | lstm_bias=lstm_bias) 70 | self.criterion_structure = criterion_structure 71 | self.criterion_cell = criterion_cell 72 | self.alpha_c = alpha_c 73 | self.id2word_structure = id2word_structure 74 | self.id2word_cell = id2word_cell 75 | self.cell_vocab = cell_vocab 76 | self.structure_vocab = structure_vocab 77 | 78 | def custom(self, module): 79 | def custom_forward(*inputs): 80 | output_1, output_2 = module(inputs[0], inputs[1]) 81 | return output_1, output_2 82 | 83 | return custom_forward 84 | 85 | def forward(self, images, 86 | caption_structures, 87 | caplen_structures, 88 | caption_cells, 89 | caplen_cells, 90 | number_cell_per_images): 91 | if not USE_CHECKPOINT3: 92 | imgs_structure, imgs_cell = \ 93 | self.encoder_wrapper(images, self.dummy_tensor) 94 | else: 95 
| imgs_structure, imgs_cell = \ 96 | checkpoint(self.custom(self.encoder_wrapper), images, self.dummy_tensor) 97 | 98 | scores, \ 99 | caps_sorted, \ 100 | decode_lengths_structure, \ 101 | alphas_structure, \ 102 | hidden_states, \ 103 | sort_ind = \ 104 | self.decoder_structure( 105 | imgs_structure, 106 | caption_structures, 107 | caplen_structures) 108 | # Since we decoded starting with , the targets are all words after , up to 109 | targets = caps_sorted[:, 1:] 110 | 111 | # Remove timesteps that we didn't decode at, or are pads 112 | # pack_padded_sequence is an easy trick to do this 113 | scores_structure = \ 114 | pack_padded_sequence( 115 | scores, 116 | decode_lengths_structure, 117 | batch_first=True).data 118 | targets_structure = \ 119 | pack_padded_sequence( 120 | targets, 121 | decode_lengths_structure, 122 | batch_first=True).data 123 | 124 | loss_structures = \ 125 | self.criterion_structure( 126 | scores_structure, 127 | targets_structure) 128 | loss_structures += \ 129 | self.alpha_c * \ 130 | ((1. - alphas_structure.sum(dim=1)) ** 2).mean() 131 | 132 | if not self.training: 133 | _, pred_structure = torch.max(scores, dim=2) 134 | pred_structure = pred_structure.tolist() 135 | html_trues = list() 136 | html_predict_only_structures = list() 137 | html_true_only_structures = list() 138 | html_predict_only_cells = list() 139 | html_predict_alls = list() 140 | 141 | # decoder cell per image 142 | scores_cells_list = [] 143 | target_cells_list = [] 144 | loss_cells = [] 145 | for (ii, ind) in enumerate(sort_ind): 146 | img = imgs_cell[ind] 147 | hidden_state_structures = hidden_states[ii] 148 | hidden_state_structures = torch.stack(hidden_state_structures) 149 | number_cell_per_image = number_cell_per_images[ind][0] 150 | 151 | caption_cell = caption_cells[ind][:number_cell_per_image] 152 | caplen_cell = caplen_cells[ind][:number_cell_per_image] 153 | 154 | # Foward encoder image and decoder cell per image 155 | scores_cell, \ 156 | caps_sorted_cell, \ 157 | decode_lengths_cell, \ 158 | alphas_cell, \ 159 | sort_ind_ = \ 160 | self.decoder_cell( 161 | img, 162 | caption_cell, 163 | caplen_cell, 164 | hidden_state_structures, 165 | number_cell_per_image) 166 | 167 | if not self.training: 168 | html_predict_only_structure = "" 169 | html_true_only_structure = "" 170 | html_predict_only_cell = "" 171 | html_true = "" 172 | html_predict_all = "" 173 | 174 | _, pred_cells = torch.max(scores_cell, dim=2) 175 | pred_cells = pred_cells.tolist() 176 | ground_truth = list() 177 | 178 | # get cell content in per images when predict 179 | temp_preds = [''] * len(pred_cells) 180 | for j, p in enumerate(pred_cells): 181 | # because sort cell with descending, mapping pred_cell to sort_ind_ 182 | words = p[:decode_lengths_cell[j]] 183 | temp_preds[sort_ind_[j]] += convertId2wordSentence(self.id2word_cell, words) 184 | 185 | # get cell content in per images ground_truth 186 | for j in range(caption_cell.shape[0]): 187 | img_caps = caption_cell[j].tolist() 188 | img_captions = [w for w in img_caps if w not in { 189 | self.cell_vocab[''], self.cell_vocab['']}] # remove and pads 190 | ground_truth.append(convertId2wordSentence( 191 | self.id2word_cell, img_captions)) 192 | 193 | index_cell = 0 194 | cap_structure = caps_sorted[ii][:decode_lengths_structure[ii]].tolist() 195 | pred_structure_image = pred_structure[ii][:decode_lengths_structure[ii]] 196 | 197 | for (index, c) in enumerate(cap_structure): 198 | if c == self.structure_vocab[""] or c == self.structure_vocab[""]: 199 | continue 200 | 
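# The loop interleaves the two decoders' outputs: each structure token is copied into
# the HTML string, and whenever the token is one that closes a cell opening (the check
# a few lines below, e.g. '<td>' or '>'), the text predicted by the cell decoder for
# that cell is spliced in and index_cell advances. A toy illustration of the stitching
# with made-up tokens and a single cell:
#
#   structure = ['<thead>', '<tr>', '<td>', '</td>', '</tr>', '</thead>']
#   cells = ['Header A']
#   html, i = '', 0
#   for tok in structure:
#       html += tok
#       if tok == '<td>':        # cell finished -> insert its predicted content
#           html += cells[i]
#           i += 1
#   # html == '<thead><tr><td>Header A</td></tr></thead>'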
200 | html_predict_only_cell += self.id2word_structure[c] 201 | html_predict_only_structure += self.id2word_structure[pred_structure_image[index - 1]] 202 | html_true_only_structure += self.id2word_structure[c] 203 | html_true += self.id2word_structure[c] 204 | html_predict_all += self.id2word_structure[pred_structure_image[index - 1]] 205 | if c == self.structure_vocab["<td>"] or c == self.structure_vocab[">"]: 206 | html_predict_only_cell += temp_preds[index_cell] 207 | html_true += ground_truth[index_cell] 208 | html_predict_all += temp_preds[index_cell] 209 | index_cell += 1 210 | 211 | html_predict_only_structure_ = create_html(html_predict_only_structure) 212 | html_true_only_structure_ = create_html(html_true_only_structure) 213 | html_predict_only_cell_ = create_html(html_predict_only_cell) 214 | html_predict_all_ = create_html(html_predict_all) 215 | html_true_ = create_html(html_true) 216 | 217 | html_predict_only_structures.append(html_predict_only_structure_) 218 | html_true_only_structures.append(html_true_only_structure_) 219 | html_predict_only_cells.append(html_predict_only_cell_) 220 | html_predict_alls.append(html_predict_all_) 221 | html_trues.append(html_true_) 222 | 223 | target_cell = caps_sorted_cell[:, 1:] 224 | # Remove timesteps that we didn't decode at, or are pads 225 | # pack_padded_sequence is an easy trick to do this 226 | scores_cell = pack_padded_sequence( 227 | scores_cell, decode_lengths_cell, batch_first=True).data 228 | target_cell = pack_padded_sequence( 229 | target_cell, decode_lengths_cell, batch_first=True).data 230 | scores_cells_list.append(scores_cell) 231 | target_cells_list.append(target_cell) 232 | loss_cell = self.criterion_cell(scores_cell, target_cell) 233 | loss_cell += self.alpha_c * ((1. - alphas_cell.sum(dim=1)) ** 2).mean() 234 | loss_cells.append(loss_cell) 235 | 236 | scores_cell = torch.cat(scores_cells_list, dim=0) 237 | targets_cell = torch.cat(target_cells_list, dim=0) 238 | loss_cells = torch.stack(loss_cells) 239 | loss_cells = torch.mean(loss_cells) 240 | 241 | if self.training: 242 | return loss_structures, scores_structure, targets_structure, \ 243 | loss_cells, scores_cell, targets_cell 244 | else: 245 | return loss_structures, loss_cells, \ 246 | html_predict_only_structures, \ 247 | html_true_only_structures, \ 248 | html_predict_only_cells, \ 249 | html_predict_alls, \ 250 | html_trues, \ 251 | scores_structure, targets_structure, \ 252 | scores_cell, targets_cell 253 | 254 | 255 | class ModuleWrapperIgnores2ndArg(nn.Module): 256 | def __init__(self, module): 257 | super().__init__() 258 | self.module = module 259 | 260 | def forward(self, x, dummy_arg=None): 261 | assert dummy_arg is not None 262 | x = self.module(x) 263 | return x 264 | 265 | 266 | class Encoder(nn.Module): 267 | """ 268 | Encoder.
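Shares a CNN trunk and duplicates its last convolutional block, producing one feature map for the structure decoder and a separate one for the cell decoder.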
269 | """ 270 | 271 | def __init__(self, 272 | backbone='resnet18', 273 | encoded_image_size=28//2, 274 | encoder_dim=512, 275 | pretrained=True, 276 | last_stride=2): 277 | super(Encoder, self).__init__() 278 | self.enc_image_size = encoded_image_size 279 | self.backbone = backbone 280 | 281 | if 'resnet' in backbone or 'resnext' in backbone: 282 | backbone_net = getattr(backbones, backbone)(pretrained=pretrained, last_stride=last_stride) 283 | if backbone == 'resnet18' or backbone == 'resnet34': 284 | self.encoder_dim = encoder_dim 285 | elif backbone == 'resnet50': 286 | self.encoder_dim = encoder_dim * 4 287 | elif backbone == 'resnext101_32x8d': 288 | print(backbone) 289 | self.encoder_dim = encoder_dim 290 | self.downsample_conv_structure = nn.Conv2d(2048, self.encoder_dim, kernel_size=1, stride=1, bias=False) 291 | self.downsample_bn_structure = nn.BatchNorm2d(self.encoder_dim) 292 | self.downsample_conv_cell = nn.Conv2d(2048, self.encoder_dim, kernel_size=1, stride=1, bias=False) 293 | self.downsample_bn_cell = nn.BatchNorm2d(self.encoder_dim) 294 | self.relu = nn.ReLU(inplace=True) 295 | 296 | offset = -3 297 | backbone_net_all_layer = list(backbone_net.children()) 298 | modules = backbone_net_all_layer[:offset] 299 | self.last_conv_block_for_structure = backbone_net_all_layer[offset] 300 | self.last_conv_block_for_cell = copy.deepcopy(self.last_conv_block_for_structure) 301 | 302 | self.backbone_net = nn.Sequential(*modules) 303 | else: 304 | backbone_net = timm.create_model(backbone, 305 | pretrained=pretrained, 306 | features_only=True, 307 | out_indices=(4,), 308 | output_stride=last_stride*16) 309 | 310 | self.encoder_dim = backbone_net.feature_info.get(key='num_chs', idx=4) 311 | offset = -1 312 | backbone_net_all_layer = list(backbone_net.children()) 313 | modules = backbone_net_all_layer[:offset] 314 | modules.extend(backbone_net_all_layer[offset][:-2]) 315 | self.last_conv_block_for_structure = backbone_net_all_layer[offset][-2:] 316 | self.last_conv_block_for_cell = copy.deepcopy(self.last_conv_block_for_structure) 317 | self.backbone_net = nn.Sequential(*modules) 318 | 319 | def forward(self, images): 320 | """ 321 | Forward propagation. 
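Runs the shared trunk once, then the structure- and cell-specific last blocks, and permutes both outputs to (batch_size, encoded_image_size, encoded_image_size, encoder_dim).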
322 | 323 | :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size) 324 | :return: encoded images 325 | """ 326 | out = self.backbone_net( 327 | images) # (batch_size, encoder_dim, image_size/16, image_size/16), 328 | 329 | # (batch_size, encoder_dim, encoded_image_size, encoded_image_size) 330 | out_for_structure = self.last_conv_block_for_structure(out) 331 | if self.backbone == 'resnext101_32x8d': 332 | out_for_structure = self.downsample_conv_structure(out_for_structure) 333 | out_for_structure = self.downsample_bn_structure(out_for_structure) 334 | out_for_structure = self.relu(out_for_structure) 335 | out_for_structure = out_for_structure.permute(0, 2, 3, 1) 336 | 337 | out_for_cell = self.last_conv_block_for_cell(out) 338 | if self.backbone == 'resnext101_32x8d': 339 | out_for_cell = self.downsample_conv_cell(out_for_cell) 340 | out_for_cell = self.downsample_bn_cell(out_for_cell) 341 | out_for_cell = self.relu(out_for_cell) 342 | out_for_cell = out_for_cell.permute(0, 2, 3, 1) 343 | 344 | # (batch_size, encoded_image_size, encoded_image_size, encoder_dim) 345 | return out_for_structure, out_for_cell 346 | 347 | 348 | class Soft_Attention(nn.Module): 349 | 350 | def __init__(self, encoder_dim, decoder_dim, attention_dim, structure_decoder_dim=None): 351 | """ 352 | :param encoder_dim: feature size of encoded images 353 | :param decoder_dim: size of decoder's RNN 354 | :param attention_dim: size of the attention network 355 | """ 356 | super(Soft_Attention, self).__init__() 357 | # linear layer to transform encoded image 358 | self.encoder_att = nn.Linear(encoder_dim, attention_dim) 359 | # linear layer to transform decoder's output 360 | self.decoder_att = nn.Linear(decoder_dim, attention_dim) 361 | if structure_decoder_dim is not None: 362 | self.structure_decoder_att = nn.Linear(structure_decoder_dim, attention_dim) 363 | # linear layer to calculate values to be softmax-ed 364 | self.full_att = nn.Linear(attention_dim, 1) 365 | self.relu = nn.ReLU() 366 | self.softmax = nn.Softmax(dim=1) # softmax layer to calculate weights 367 | 368 | def forward(self, encoder_out, decoder_hidden, structure_hidden=None): 369 | """ 370 | Forward propagation. 371 | 372 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 373 | :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim) 374 | :return: attention weighted encoding, weights 375 | """ 376 | 377 | att1 = self.encoder_att( 378 | encoder_out) # (batch_size, num_pixels, attention_dim) 379 | 380 | att2 = self.decoder_att(decoder_hidden) # (batch_size, attention_dim) 381 | 382 | # attention is created by output encoder and previous decoder output 383 | # size (batch_size, num_pixels) 384 | if structure_hidden is not None: 385 | att3 = self.structure_decoder_att(structure_hidden) # (batch_size, attention_dim) 386 | att = self.full_att(self.relu(att1 + att3.unsqueeze(1)+ att2.unsqueeze(1))).squeeze(2) 387 | else: 388 | att = self.full_att(self.relu(att1 + att2.unsqueeze(1))).squeeze(2) 389 | alpha = self.softmax(att) # size is (batch_size, num_pixels) 390 | attention_weighted_encoding = ( 391 | encoder_out * alpha.unsqueeze(2)).sum(dim=1) # size is (batch_size, encoder_dim) 392 | 393 | return attention_weighted_encoding, alpha 394 | 395 | 396 | class DecoderStructureWithAttention(nn.Module): 397 | """ 398 | Decoder. 
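Decodes the structure token sequence with an LSTMCell and soft attention over the structure feature map; the hidden state at every cell-opening token (<td> or >) is collected so the cell decoder can condition its attention on it.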
399 | """ 400 | 401 | def __init__(self, attention_dim, embed_dim, decoder_dim, vocab, encoder_dim=512, dropout=0.5, lstm_bias=True): 402 | """ 403 | :param attention_dim: size of attention network 404 | :param embed_dim: embedding size 405 | :param decoder_dim: size of decoder's RNN 406 | :param vocab_size: size of vocabulary 407 | :param encoder_dim: feature size of encoded images 408 | :param dropout: dropout 409 | """ 410 | super(DecoderStructureWithAttention, self).__init__() 411 | 412 | self.encoder_dim = encoder_dim 413 | self.attention_dim = attention_dim 414 | self.embed_dim = embed_dim 415 | self.decoder_dim = decoder_dim 416 | self.vocab = vocab 417 | self.id2words = id_to_word(vocab) 418 | self.vocab_size = len(vocab) 419 | self.dropout = dropout 420 | 421 | self.attention = Soft_Attention( 422 | encoder_dim, decoder_dim, attention_dim) # attention network 423 | 424 | self.embedding = nn.Embedding( 425 | self.vocab_size, embed_dim) # embedding layer 426 | self.dropout = nn.Dropout(p=self.dropout) 427 | self.decode_step = nn.LSTMCell( 428 | embed_dim + encoder_dim, decoder_dim, bias=lstm_bias) # decoding LSTMCell 429 | # linear layer to find initial hidden state of LSTMCell 430 | self.init_h = nn.Linear(encoder_dim, decoder_dim) 431 | # linear layer to find initial cell state of LSTMCell 432 | self.init_c = nn.Linear(encoder_dim, decoder_dim) 433 | 434 | # linear layer to create a sigmoid-activated gate 435 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) 436 | self.sigmoid = nn.Sigmoid() 437 | # linear layer to find scores over vocabulary 438 | self.fc = nn.Linear(decoder_dim, self.vocab_size) 439 | self.init_weights() # initialize some layers with the uniform distribution 440 | 441 | def init_weights(self): 442 | """ 443 | Initializes some parameters with values from the uniform distribution, for easier convergence. 444 | """ 445 | self.embedding.weight.data.uniform_(-0.1, 0.1) 446 | self.fc.bias.data.fill_(0) 447 | self.fc.weight.data.uniform_(-0.1, 0.1) 448 | 449 | def load_pretrained_embeddings(self, embeddings): 450 | """ 451 | Loads embedding layer with pre-trained embeddings. 452 | 453 | :param embeddings: pre-trained embeddings 454 | """ 455 | self.embedding.weight = nn.Parameter(embeddings) 456 | 457 | def fine_tune_embeddings(self, fine_tune=True): 458 | """ 459 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings). 460 | 461 | :param fine_tune: Allow? 462 | """ 463 | for p in self.embedding.parameters(): 464 | p.requires_grad = fine_tune 465 | 466 | def init_hidden_state(self, encoder_out): 467 | """ 468 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 
469 | 470 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 471 | :return: hidden state, cell state 472 | """ 473 | mean_encoder_out = encoder_out.mean(dim=1) 474 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim) 475 | c = self.init_c(mean_encoder_out) 476 | return h, c 477 | 478 | def custom(self, module): 479 | def custom_forward(*inputs): 480 | output1, output2 = module(inputs[0], inputs[1]) 481 | return output1, output2 482 | 483 | return custom_forward 484 | 485 | def run_function(self, module): 486 | def custom_forward(*inputs): 487 | output, hidden = module( 488 | inputs[0], (inputs[1], inputs[2]) 489 | ) 490 | return output, hidden 491 | 492 | return custom_forward 493 | 494 | def custom_2(self, module): 495 | def custom_forward(inputs): 496 | output = module(inputs) 497 | return output 498 | 499 | return custom_forward 500 | 501 | def forward(self, encoder_out, encoded_captions, caption_lengths): 502 | """ 503 | Forward propagation. 504 | 505 | :param encoder_out: encoded images, a tensor of dimension (batch_size, enc_image_size, enc_image_size, encoder_dim) 506 | :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length) 507 | :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1) 508 | :return: scores for vocabulary, sorted encoded captions, decode lengths, weights, sort indices 509 | """ 510 | 511 | batch_size = encoder_out.size(0) 512 | encoder_dim = encoder_out.size(-1) 513 | vocab_size = self.vocab_size 514 | 515 | # Flatten image 516 | # (batch_size, num_pixels, encoder_dim) 517 | encoder_out = encoder_out.view(batch_size, -1, encoder_dim) # bs, 28*28, 512 518 | num_pixels = encoder_out.size(1) 519 | 520 | # Sort input data by decreasing lengths; why? 
apparent below 521 | caption_lengths, sort_ind = caption_lengths.squeeze( 522 | 1).sort(dim=0, descending=True) 523 | encoder_out = encoder_out[sort_ind] 524 | 525 | encoded_captions = encoded_captions[sort_ind] 526 | 527 | # Embedding 528 | # (batch_size, max_caption_length, embed_dim) 529 | embeddings = self.embedding(encoded_captions) 530 | 531 | # Initialize LSTM structure state 532 | h, c = self.init_hidden_state(encoder_out) # (batch_size, decoder_dim) 533 | 534 | # We won't decode at the <end> position, since we've finished generating as soon as we generate <end> 535 | # So, decoding lengths are actual lengths - 1 536 | decode_lengths = (caption_lengths - 1).tolist() 537 | 538 | # Create tensors to hold word prediction scores and alphas 539 | predictions = torch.zeros(batch_size, max( 540 | decode_lengths), vocab_size).to(encoder_out.device) 541 | alphas = torch.zeros(batch_size, max( 542 | decode_lengths), num_pixels).to(encoder_out.device) 543 | 544 | # create hidden_states used to generate the cells 545 | hidden_states = [[] for x in range(batch_size)] 546 | 547 | # At each time-step, decode by 548 | # attention-weighing the encoder's output based on the decoder's previous hidden state output 549 | # then generate a new word in the decoder with the previous word and the attention weighted encoding 550 | # using teacher forcing, save hidden state h_k+1 if ground truth t_k is <td> or > 551 | for t in range(max(decode_lengths)): 552 | batch_size_t = sum([l > t for l in decode_lengths]) 553 | if not USE_CHECKPOINT: 554 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 555 | h[:batch_size_t]) 556 | else: 557 | attention_weighted_encoding, alpha = checkpoint(self.custom(self.attention), 558 | encoder_out[:batch_size_t], 559 | h[:batch_size_t]) 560 | # gating scalar, (batch_size_t, encoder_dim) 561 | if not USE_CHECKPOINT2: 562 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) 563 | else: 564 | gate = checkpoint(self.custom_2(self.f_beta), h[:batch_size_t]) 565 | gate = checkpoint(self.custom_2(self.sigmoid), gate) 566 | attention_weighted_encoding = gate * attention_weighted_encoding 567 | 568 | # hidden h_t+1 and ground_truth token t_t 569 | if not USE_CHECKPOINT2: 570 | h, c = self.decode_step( 571 | torch.cat([embeddings[:batch_size_t, t, :], 572 | attention_weighted_encoding], dim=1), 573 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim) 574 | else: 575 | h, c = checkpoint(self.run_function(self.decode_step), 576 | torch.cat([embeddings[:batch_size_t, t, :], 577 | attention_weighted_encoding], 578 | dim=1), 579 | h[:batch_size_t], c[:batch_size_t]) 580 | 581 | # get and save hidden state h_k+1 when ground_truth token in t_k is <td> or > 582 | for i in range(batch_size_t): 583 | if self.vocab["<td>"] == encoded_captions[i][t].cpu().numpy() or \ 584 | self.vocab[">"] == encoded_captions[i][t].cpu().numpy(): 585 | hidden_states[i].append(h[i]) 586 | 587 | preds = self.fc(self.dropout(h)) # (batch_size_t, vocab_size) 588 | predictions[:batch_size_t, t, :] = preds 589 | alphas[:batch_size_t, t, :] = alpha 590 | 591 | return predictions, encoded_captions, decode_lengths, alphas, hidden_states, sort_ind 592 | 593 | 594 | class DecoderCellPerImageWithAttention(nn.Module): 595 | def __init__(self, attention_dim, embed_dim, decoder_dim, 596 | decoder_structure_dim, vocab_size, encoder_dim=512, dropout=0.5, lstm_bias=True): 597 | """ 598 | :param attention_dim: size of attention network 599 | :param embed_dim: embedding size 600 | :param decoder_dim: size of decoder's
RNN 601 | :param vocab_size: size of vocabulary 602 | :param encoder_dim: feature size of encoded images 603 | :param dropout: dropout 604 | """ 605 | super(DecoderCellPerImageWithAttention, self).__init__() 606 | 607 | self.encoder_dim = encoder_dim 608 | self.attention_dim = attention_dim 609 | self.embed_dim = embed_dim 610 | self.decoder_dim = decoder_dim 611 | self.vocab_size = vocab_size 612 | self.dropout = dropout 613 | 614 | self.attention = Soft_Attention( 615 | encoder_dim, decoder_dim, attention_dim, decoder_structure_dim) # attention network 616 | 617 | self.embedding = nn.Embedding(vocab_size, embed_dim) # embedding layer 618 | self.dropout = nn.Dropout(p=self.dropout) 619 | self.decode_step = nn.LSTMCell( 620 | embed_dim + encoder_dim, decoder_dim, bias=lstm_bias) # decoding LSTMCell 621 | # linear layer to find initial hidden state of LSTMCell 622 | self.init_h = nn.Linear(encoder_dim, decoder_dim) 623 | # linear layer to find initial cell state of LSTMCell 624 | self.init_c = nn.Linear(encoder_dim, decoder_dim) 625 | 626 | # linear layer to create a sigmoid-activated gate 627 | self.f_beta = nn.Linear(decoder_dim, encoder_dim) 628 | self.sigmoid = nn.Sigmoid() 629 | # linear layer to find scores over vocabulary 630 | self.fc = nn.Linear(decoder_dim, vocab_size) 631 | self.init_weights() # initialize some layers with the uniform distribution 632 | 633 | def init_weights(self): 634 | """ 635 | Initializes some parameters with values from the uniform distribution, for easier convergence. 636 | """ 637 | self.embedding.weight.data.uniform_(-0.1, 0.1) 638 | self.fc.bias.data.fill_(0) 639 | self.fc.weight.data.uniform_(-0.1, 0.1) 640 | 641 | def load_pretrained_embeddings(self, embeddings): 642 | """ 643 | Loads embedding layer with pre-trained embeddings. 644 | 645 | :param embeddings: pre-trained embeddings 646 | """ 647 | self.embedding.weight = nn.Parameter(embeddings) 648 | 649 | def fine_tune_embeddings(self, fine_tune=True): 650 | """ 651 | Allow fine-tuning of embedding layer? (Only makes sense to not-allow if using pre-trained embeddings). 652 | 653 | :param fine_tune: Allow? 654 | """ 655 | for p in self.embedding.parameters(): 656 | p.requires_grad = fine_tune 657 | 658 | def init_hidden_state(self, encoder_out): 659 | """ 660 | Creates the initial hidden and cell states for the decoder's LSTM based on the encoded images. 
661 | 662 | :param encoder_out: encoded images, a tensor of dimension (batch_size, num_pixels, encoder_dim) 663 | :return: hidden state, cell state 664 | """ 665 | mean_encoder_out = encoder_out.mean(dim=1) 666 | h = self.init_h(mean_encoder_out) # (batch_size, decoder_dim) 667 | c = self.init_c(mean_encoder_out) 668 | return h, c 669 | 670 | def custom(self, module): 671 | def custom_forward(*inputs): 672 | output1, output2 = module(inputs[0], inputs[1], inputs[2]) 673 | return output1, output2 674 | 675 | return custom_forward 676 | 677 | def run_function(self, module): 678 | def custom_forward(*inputs): 679 | output, hidden = module( 680 | inputs[0], (inputs[1], inputs[2]) 681 | ) 682 | return output, hidden 683 | 684 | return custom_forward 685 | 686 | def custom_2(self, module): 687 | def custom_forward(inputs): 688 | output = module(inputs) 689 | return output 690 | 691 | return custom_forward 692 | 693 | def forward(self, encoder_out, encoded_captions, caption_lengths, hidden_state_structures, batch_size): 694 | 695 | encoder_dim = encoder_out.size(-1) 696 | vocab_size = self.vocab_size 697 | 698 | # Flatten image 699 | # expand encoder_out (batch_size, num_pixels, encoder_dim) 700 | encoder_out = encoder_out.view(-1, encoder_dim) 701 | encoder_out = encoder_out.squeeze(1).repeat(batch_size, 1, 1) 702 | 703 | num_pixels = encoder_out.size(1) 704 | 705 | # Sort input data by decreasing lengths; why? apparent below 706 | caption_lengths, sort_ind = caption_lengths.unsqueeze( 707 | 1).squeeze( 708 | 1).sort(dim=0, descending=True) 709 | # caption_lengths, sort_ind = caption_lengths.sort(descending=True) 710 | encoder_out = encoder_out[sort_ind] 711 | encoded_captions = encoded_captions[sort_ind] 712 | # hidden_state_structures size (batch_size, decoder_dim) 713 | hidden_state_structures = hidden_state_structures[sort_ind] 714 | # Embedding 715 | # (batch_size, max_caption_length, embed_dim) 716 | embeddings = self.embedding(encoded_captions) 717 | 718 | # Initialize LSTM cell state 719 | # size is (batch_size, decoder_dim) 720 | h, c = self.init_hidden_state(encoder_out) 721 | # We won't decode at the position, since we've finished generating as soon as we generate 722 | # So, decoding lengths are actual lengths - 1 723 | decode_lengths = (caption_lengths - 1).tolist() 724 | 725 | # Create tensors to hold word predicion scores and alphas 726 | predictions = torch.zeros(batch_size, max( 727 | decode_lengths), vocab_size).to(encoder_out.device) 728 | alphas = torch.zeros(batch_size, max( 729 | decode_lengths), num_pixels).to(encoder_out.device) 730 | 731 | # decode with time step 732 | # attention-weighing the encoder's output based on the decoder's previous hidden state output 733 | for t in range(max(decode_lengths)): 734 | batch_size_t = sum([l > t for l in decode_lengths]) 735 | if not USE_CHECKPOINT: 736 | attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t], 737 | h[:batch_size_t], 738 | hidden_state_structures[:batch_size_t]) 739 | else: 740 | attention_weighted_encoding, alpha = checkpoint(self.custom(self.attention), 741 | encoder_out[:batch_size_t], 742 | h[:batch_size_t], 743 | hidden_state_structures[:batch_size_t]) 744 | # gating scalar, (batch_size_t, encoder_dim) 745 | if not USE_CHECKPOINT2: 746 | gate = self.sigmoid(self.f_beta(h[:batch_size_t])) 747 | else: 748 | gate = checkpoint(self.custom_2(self.f_beta), h[:batch_size_t]) 749 | gate = checkpoint(self.custom_2(self.sigmoid), gate) 750 | attention_weighted_encoding = gate * 
attention_weighted_encoding 751 | 752 | # concat hidden state structure + attention_weighted_encoding 753 | # attention_weighted_encoding = torch.cat( 754 | # (attention_weighted_encoding, hidden_state_structures[:batch_size_t]), dim=1) 755 | # hidden h_t+1 and ground_truth token t_t 756 | if not USE_CHECKPOINT2: 757 | h, c = self.decode_step( 758 | torch.cat([embeddings[:batch_size_t, t, :], 759 | attention_weighted_encoding], dim=1), 760 | (h[:batch_size_t], c[:batch_size_t])) # (batch_size_t, decoder_dim) 761 | else: 762 | h, c = checkpoint(self.run_function(self.decode_step), 763 | torch.cat([embeddings[:batch_size_t, t, :], 764 | attention_weighted_encoding], 765 | dim=1), 766 | h[:batch_size_t], c[:batch_size_t]) 767 | 768 | preds = self.fc(self.dropout(h)) # (batch_size_t, vocab_size) 769 | predictions[:batch_size_t, t, :] = preds 770 | alphas[:batch_size_t, t, :] = alpha 771 | 772 | return predictions, encoded_captions, decode_lengths, alphas, sort_ind 773 | --------------------------------------------------------------------------------
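
Note on the gradient-checkpointing pattern used above: `models.py` runs the encoder (and, behind the `USE_CHECKPOINT`/`USE_CHECKPOINT2` flags, the attention and LSTM steps) through `torch.utils.checkpoint`. Because the input images do not require gradients, the encoder is wrapped in `ModuleWrapperIgnores2ndArg` and fed an extra `dummy_tensor` with `requires_grad=True` so the checkpointed segment stays connected to autograd. Below is a minimal, self-contained sketch of that pattern only; it assumes nothing beyond the public PyTorch API (`torch==1.6.0` as pinned in `requirements.txt`), and `toy_encoder` is an illustrative stand-in rather than the repository's `Encoder`.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ModuleWrapperIgnores2ndArg(nn.Module):
    """Forward x through the wrapped module; the dummy tensor exists only to satisfy autograd."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x, dummy_arg=None):
        assert dummy_arg is not None
        return self.module(x)


# Illustrative stand-in for the backbone-based Encoder in models.py.
toy_encoder = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
wrapper = ModuleWrapperIgnores2ndArg(toy_encoder)

images = torch.randn(2, 3, 64, 64)             # raw inputs do not require grad
dummy = torch.ones(1, requires_grad=True)      # keeps the checkpointed segment in the autograd graph

features = checkpoint(wrapper, images, dummy)  # activations are recomputed during backward
features.mean().backward()                     # gradients still reach toy_encoder's parameters
```

Without the dummy tensor, `checkpoint` would see only inputs that do not require gradients, warn accordingly, and the wrapped encoder's weights would receive no gradient; passing `self.dummy_tensor` alongside the images avoids that while still discarding intermediate activations to save memory.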