├── models ├── __init__.py ├── encoder.py ├── sar.py ├── backbone.py └── decoder.py ├── utils ├── __init__.py ├── attention_map.py └── dataproc.py ├── dataset ├── __init__.py └── dataset.py ├── .gitignore ├── misc ├── iiit_0.jpg ├── iiit_0_0.png ├── iiit_0_1.png ├── syn90k_0.jpg ├── syn90k_0_0.png ├── syn90k_0_1.png ├── syn90k_0_2.png ├── syn90k_0_3.png ├── syn90k_0_4.png ├── syn90k_0_5.png ├── syn90k_0_6.png ├── syn90k_0_7.png ├── svt_results.png ├── synthtext_0.jpg ├── iiit5k_results.png ├── syn90k_results.png ├── synthtext_0_0.png ├── synthtext_0_1.png ├── synthtext_0_2.png └── synthtext_results.png ├── requirements.txt ├── LICENSE ├── README.md ├── inference.py └── train.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .DS_Store 3 | logs 4 | .ipynb_checkpoints 5 | -------------------------------------------------------------------------------- /misc/iiit_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/iiit_0.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | editdistance 2 | opencv-python 3 | scipy 4 | torch 5 | torchvision 6 | -------------------------------------------------------------------------------- /misc/iiit_0_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/iiit_0_0.png -------------------------------------------------------------------------------- /misc/iiit_0_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/iiit_0_1.png -------------------------------------------------------------------------------- /misc/syn90k_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/syn90k_0.jpg -------------------------------------------------------------------------------- /misc/syn90k_0_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/syn90k_0_0.png -------------------------------------------------------------------------------- /misc/syn90k_0_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/syn90k_0_1.png -------------------------------------------------------------------------------- /misc/syn90k_0_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/syn90k_0_2.png 
-------------------------------------------------------------------------------- /misc/syn90k_0_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/syn90k_0_3.png -------------------------------------------------------------------------------- /misc/syn90k_0_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/syn90k_0_4.png -------------------------------------------------------------------------------- /misc/syn90k_0_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/syn90k_0_5.png -------------------------------------------------------------------------------- /misc/syn90k_0_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/syn90k_0_6.png -------------------------------------------------------------------------------- /misc/syn90k_0_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/syn90k_0_7.png -------------------------------------------------------------------------------- /misc/svt_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/svt_results.png -------------------------------------------------------------------------------- /misc/synthtext_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/synthtext_0.jpg -------------------------------------------------------------------------------- /misc/iiit5k_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/iiit5k_results.png -------------------------------------------------------------------------------- /misc/syn90k_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/syn90k_results.png -------------------------------------------------------------------------------- /misc/synthtext_0_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/synthtext_0_0.png -------------------------------------------------------------------------------- /misc/synthtext_0_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/synthtext_0_1.png -------------------------------------------------------------------------------- /misc/synthtext_0_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/synthtext_0_2.png -------------------------------------------------------------------------------- /misc/synthtext_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liuch37/sar-pytorch/HEAD/misc/synthtext_results.png -------------------------------------------------------------------------------- 
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Chun-Hao Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/models/encoder.py:
--------------------------------------------------------------------------------
'''
This code constructs the encoder for SAR - a two-layer LSTM over the vertically max-pooled feature map.
'''
import torch
import torch.nn as nn

__all__ = ['encoder']

class encoder(nn.Module):
    def __init__(self, H, C, hidden_units=512, layers=2, keep_prob=1.0, device='cpu'):
        super(encoder, self).__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=(H,1), stride=1)
        # nn.LSTM expects a drop probability, so convert the keep probability
        self.lstm = nn.LSTM(input_size=C, hidden_size=hidden_units, num_layers=layers, batch_first=True, dropout=1.0-keep_prob)
        self.layers = layers
        self.hidden_units = hidden_units
        self.device = device

    def forward(self, x):
        self.lstm.flatten_parameters()
        # x is feature map in [batch, C, H, W]
        # Initialize hidden state with zeros
        h_0 = torch.zeros(self.layers*1, x.size(0), self.hidden_units).to(self.device)
        # Initialize cell state
        c_0 = torch.zeros(self.layers*1, x.size(0), self.hidden_units).to(self.device)
        x = self.maxpool(x) # [batch, C, 1, W]
        x = torch.squeeze(x) # [batch, C, W]
        if len(x.size()) == 2: # [C, W]
            x = x.unsqueeze(0) # [batch, C, W]
        x = x.permute(0,2,1) # [batch, W, C]
        _, (h, _) = self.lstm(x, (h_0, c_0)) # h with shape [layers*1, batch, hidden_units]

        return h[-1] # shape [batch, hidden_units]

# unit test
if __name__ == '__main__':

    batch_size = 32
    Height = 48
    Width = 160
    Channel = 3
    input_feature = torch.randn(batch_size,Channel,Height,Width)
    print("Input feature size is:",input_feature.shape)

    encoder_model = encoder(Height, Channel, hidden_units=512, layers=2, keep_prob=1.0)
    output_encoder = encoder_model(input_feature)

    print("Output feature of encoder size is:",output_encoder.shape) # (batch, hidden_units)
--------------------------------------------------------------------------------
/utils/attention_map.py:
--------------------------------------------------------------------------------
'''
Post-processing code for visualizing attention maps.
3 | ''' 4 | import torch 5 | import numpy as np 6 | import cv2 7 | import matplotlib.pyplot as plt 8 | 9 | def attention_map(predict_word, x, attention_weight): 10 | ''' 11 | Input: 12 | predict_word: string of predicted word 13 | x: tensor of original image [C, H, W], channel in BGR order, normalized to [-1, 1] 14 | attention_weight: tensor of attention weights [seq_len, 1, feature_H, feature_W] in [0, 1] 15 | Output: 16 | heatmaps: a list of heatmap [H, W, C=3] 17 | overlaps: a list of overlapped image [H, W, C=3] 18 | ''' 19 | T = len(predict_word) 20 | x = x.permute(1,2,0).detach().cpu().numpy() # [H, W, C] 21 | x = (((x+1)/2)*255).astype(np.uint8) # normalized to [0,255] 22 | H, W, C = x.shape 23 | heatmaps = [] 24 | overlaps = [] 25 | for t in range(T): 26 | att_map = attention_weight[t,:,:,:].permute(1,2,0).detach().cpu().numpy() # [feature_H, feature_W, 1] 27 | att_map = cv2.resize(att_map, (W,H)) # [H, W] 28 | att_map = (att_map*255).astype(np.uint8) 29 | heatmap = cv2.applyColorMap(att_map, cv2.COLORMAP_JET) # [H, W, C] 30 | overlap = cv2.addWeighted(heatmap, 0.6, x, 0.4, 0) 31 | heatmaps.append(heatmap) 32 | overlaps.append(overlap) 33 | 34 | return heatmaps, overlaps 35 | 36 | # unit test 37 | if __name__ == '__main__': 38 | 39 | img_path = '../svt/img/00_16.jpg' 40 | 41 | predict_word = "hello" 42 | 43 | x = cv2.imread(img_path) 44 | 45 | x = (x-127.5)/127.5 # normalization 46 | 47 | x = torch.from_numpy(x) 48 | 49 | x = x.permute(2, 0, 1) # [C, H, W] 50 | 51 | attention_weight = torch.rand((40, 1, 384, 512)) 52 | 53 | attention_weight[:,:,250:300,150:200] = 1.0 54 | 55 | attention_weight[:,:,0:50,0:50] = 0.0 56 | 57 | attention_weight[:,:,300:350,450:500] = 0.5 58 | 59 | heatmaps, overlaps = attention_map(predict_word, x, attention_weight) 60 | 61 | heatmap_single = cv2.cvtColor(heatmaps[-1], cv2.COLOR_BGR2RGB) 62 | overlap_single = cv2.cvtColor(overlaps[-1], cv2.COLOR_BGR2RGB) 63 | 64 | plt.figure(0) 65 | plt.imshow(heatmap_single) 66 | 67 | plt.figure(1) 68 | plt.imshow(overlap_single) 69 | 70 | plt.show() -------------------------------------------------------------------------------- /models/sar.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is to construct the complete SAR model by combining backbone+encoder+decoder as one integrated model. 
'''
import torch
import torch.nn as nn
from .backbone import backbone
from .encoder import encoder
from .decoder import decoder

__all__ = ['sar']

class sar(nn.Module):
    def __init__(self, channel, feature_height, feature_width, embedding_dim, output_classes, hidden_units=512, layers=2, keep_prob=1.0, seq_len=40, device='cpu'):
        super(sar, self).__init__()
        '''
        channel: channel of input image
        feature_height: feature height of backbone feature map
        feature_width: feature width of backbone feature map
        embedding_dim: embedding dimension for a word
        output_classes: number of output classes for the one hot encoding of a word
        hidden_units: hidden units for both LSTM encoder and decoder
        layers: layers for both LSTM encoder and decoder, should be set to 2
        keep_prob: dropout keep probability for the LSTM encoder
        seq_len: decoding sequence length
        '''
        self.backbone = backbone(channel)
        self.encoder_model = encoder(feature_height, 512, hidden_units, layers, keep_prob, device)
        self.decoder_model = decoder(output_classes, feature_height, feature_width, 512, hidden_units, seq_len, device)
        self.embedding_dim = embedding_dim
        self.output_classes = output_classes
        self.hidden_units = hidden_units
        self.layers = layers
        self.keep_prob = keep_prob
        self.seq_len = seq_len
        self.device = device

    def forward(self,x,y):
        '''
        x: input images [batch, channel, height, width]
        y: output labels [batch, seq_len, output_classes]
        '''
        V = self.backbone(x) # (batch, feature_depth, feature_height, feature_width)
        hw = self.encoder_model(V) # (batch, hidden_units)
        outputs, attention_weights = self.decoder_model(hw, y, V) # [batch, seq_len, output_classes], [batch, seq_len, 1, feature_height, feature_width]

        return outputs, attention_weights, V, hw

# unit test
if __name__ == '__main__':
    '''
    Need to change the imports to do the unit test:
    from backbone import backbone
    from encoder import encoder
    from decoder import decoder
    '''
    torch.manual_seed(0)

    batch_size = 2
    Height = 12
    Width = 24
    Channel = 3
    output_classes = 94
    embedding_dim = 512
    hidden_units = 512
    layers = 2
    keep_prob = 1.0
    seq_len = 40

    feature_height = Height // 4
    feature_width = Width // 8

    y = torch.randn(batch_size, seq_len, output_classes)
    x = torch.randn(batch_size, Channel, Height, Width)
    print("Input image size is:", x.shape)
    print("Input label size is:", y.shape)

    model = sar(Channel, feature_height, feature_width, embedding_dim, output_classes, hidden_units, layers, keep_prob, seq_len)

    predict1, att_weights1, V1, hw1 = model.train()(x,y)
    print("Prediction size is:", predict1.shape)
    print("Attention weight size is:", att_weights1.shape)

    predict2, att_weights2, V2, hw2 = model.train()(x,y)
    print("Prediction size is:", predict2.shape)
    print("Attention weight size is:", att_weights2.shape)

    print("Difference:", torch.sum(predict1-predict2))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Show, Attend and Read - A PyTorch Implementation

Implementation of Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition (AAAI 2019), with PyTorch >= v1.4.0.
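
A minimal usage sketch (not part of the original scripts; it assumes the repository root is on `PYTHONPATH`) of building the integrated model and running one forward pass on dummy tensors, mirroring the unit test in `models/sar.py` and the default input shapes in `train.py`:

```python
import torch
from models.sar import sar
from dataset.dataset import dictionary_generator

# Vocabulary and shapes follow the defaults used in train.py.
voc, char2id, id2char = dictionary_generator()
batch, channel, height, width, seq_len = 2, 3, 48, 64, 40

model = sar(channel, height // 4, width // 8, 512, len(voc),
            hidden_units=512, layers=2, keep_prob=1.0, seq_len=seq_len)

x = torch.randn(batch, channel, height, width)   # dummy images
y = torch.zeros(batch, seq_len, len(voc))        # dummy one-hot labels
y[:, :, char2id['PAD']] = 1.0

outputs, attention_weights, _, _ = model(x, y)
print(outputs.shape)            # [batch, seq_len, output_classes]
print(attention_weights.shape)  # [batch, seq_len, 1, height // 4, width // 8]
```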

## Task

- [x] Backbone model
- [x] Encoder model
- [x] Decoder model
- [x] Integrated model
- [x] Data processing
- [x] Training pipeline
- [x] Inference pipeline

## Supported Datasets

- [x] Street View Text: http://vision.ucsd.edu/~kai/svt/
- [x] IIIT5K: https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset
- [x] Syn90k: https://www.robots.ox.ac.uk/~vgg/data/text/
- [x] SynthText: https://www.robots.ox.ac.uk/~vgg/data/scenetext/

## Command

### Training

``
python train.py --batch 32 --epoch 5000 --dataset ./svt --dataset_type svt --gpu True
``

### Inference

``
python inference.py --batch 32 --input input_folder --model model_path --gpu True
``

## Results

### SVT
![Statistics for SVT training](https://github.com/liuch37/sar-pytorch/blob/master/misc/svt_results.png)

### IIIT5K
![Statistics for IIIT5K training](https://github.com/liuch37/sar-pytorch/blob/master/misc/iiit5k_results.png)

Input:

![Input image](https://github.com/liuch37/sar-pytorch/blob/master/misc/iiit_0.jpg)

Output attention map per character:

![Attention map for char 0](https://github.com/liuch37/sar-pytorch/blob/master/misc/iiit_0_0.png)
![Attention map for char 1](https://github.com/liuch37/sar-pytorch/blob/master/misc/iiit_0_1.png)

### Syn90K (10k for training/3k for testing)
![Statistics for Syn90K training](https://github.com/liuch37/sar-pytorch/blob/master/misc/syn90k_results.png)

Input:

![Input image](https://github.com/liuch37/sar-pytorch/blob/master/misc/syn90k_0.jpg)

Output attention map per character:

![Attention map for char 0](https://github.com/liuch37/sar-pytorch/blob/master/misc/syn90k_0_0.png)
![Attention map for char 1](https://github.com/liuch37/sar-pytorch/blob/master/misc/syn90k_0_1.png)
![Attention map for char 2](https://github.com/liuch37/sar-pytorch/blob/master/misc/syn90k_0_2.png)
![Attention map for char 3](https://github.com/liuch37/sar-pytorch/blob/master/misc/syn90k_0_3.png)
![Attention map for char 4](https://github.com/liuch37/sar-pytorch/blob/master/misc/syn90k_0_4.png)
![Attention map for char 5](https://github.com/liuch37/sar-pytorch/blob/master/misc/syn90k_0_5.png)
![Attention map for char 6](https://github.com/liuch37/sar-pytorch/blob/master/misc/syn90k_0_6.png)
![Attention map for char 7](https://github.com/liuch37/sar-pytorch/blob/master/misc/syn90k_0_7.png)

### SynthText (80k for training/20k for testing)
![Statistics for SynthText training](https://github.com/liuch37/sar-pytorch/blob/master/misc/synthtext_results.png)

Input:

![Input image](https://github.com/liuch37/sar-pytorch/blob/master/misc/synthtext_0.jpg)

Output attention map per character:

![Attention map for char 0](https://github.com/liuch37/sar-pytorch/blob/master/misc/synthtext_0_0.png)
![Attention map for char 1](https://github.com/liuch37/sar-pytorch/blob/master/misc/synthtext_0_1.png)
![Attention map for char 2](https://github.com/liuch37/sar-pytorch/blob/master/misc/synthtext_0_2.png)

## Source

[1] Original paper: https://arxiv.org/abs/1811.00751

[2] Official code by the authors in Torch: https://github.com/wangpengnorman/SAR-Strong-Baseline-for-Text-Recognition

[3] A TensorFlow implementation: https://github.com/Pay20Y/SAR_TF
--------------------------------------------------------------------------------
/utils/dataproc.py:
--------------------------------------------------------------------------------
'''
This code provides the necessary utility functions.
'''
import string
import editdistance
import numpy as np

def end_cut(indices, char2id, id2char):
    '''
    indices: numpy array or list of character indices
    char2id: char to id conversion
    id2char: id to char conversion
    '''
    cut_indices = []
    for id in indices:
        if id != char2id['END']:
            if id != char2id['UNK'] and id != char2id['PAD']:
                cut_indices.append(id2char[id])
        else:
            break
    return ''.join(cut_indices)

def performance_evaluate(pred_choice, target, voc, char2id, id2char, metrics_type):
    '''
    pred_choice: predicted numpy array of [batch_size, seq_len] with index in output_classes
    target: true numpy array of [batch_size, seq_len] with index in output_classes
    voc: vocabulary list
    char2id: char to id conversion
    id2char: id to char conversion
    metrics_type: evaluation metric name
    '''
    batch_size = target.shape[0]
    predicts = []
    labels = []
    for batch in range(batch_size):
        predict_indices = pred_choice[batch]
        target_indices = target[batch]

        predicts.append(end_cut(predict_indices, char2id, id2char))
        labels.append(end_cut(target_indices, char2id, id2char))

    if metrics_type == 'accuracy':
        acc_list = [(pred == tar) for pred, tar in zip(predicts, labels)]
        accuracy = 1.0 * sum(acc_list) / len(acc_list)

        return accuracy, acc_list, predicts, labels
    elif metrics_type == 'editdistance':
        ed_list = [editdistance.eval(pred, targ) for pred, targ in zip(predicts, labels)]
        eds = 1.0 * sum(ed_list) / len(ed_list)

        return eds, ed_list, predicts, labels

    return -1

# unit test
if __name__ == '__main__':
    import sys
    sys.path.append("..")

    from dataset.dataset import dictionary_generator

    batch_size = 2
    seq_len = 40
    voc, char2id, id2char = dictionary_generator()
    print("Vocabulary size is:", len(voc))

    pred_choice = np.random.randint(0,len(voc),(batch_size, seq_len)) # [batch_size, seq_len]
    target = np.array([[47, 44, 57, 44, 49, 42, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
                        95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
                        95, 95, 95, 94],
                       [54, 55, 36, 49, 39, 36, 53, 39, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
                        95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95, 95,
                        95, 95, 95, 94]])

    word = end_cut(pred_choice[0], char2id, id2char)
    print("First decode word is:", word)
    word = end_cut(pred_choice[1], char2id, id2char)
    print("Second decode word is:", word)

    word = end_cut(target[0], char2id, id2char)
    print("First decode word is:", word)
    word = end_cut(target[1], char2id, id2char)
    print("Second decode word is:", word)

    metric, metric_list, predicts, labels = performance_evaluate(pred_choice, target, voc, char2id, id2char, 'accuracy')
    print("Accuracy:", metric)
    print("Accuracy list:", metric_list)
    print("Predicted words:", predicts)
    print("Labeled words:", labels)
    metric, metric_list, predicts, labels = performance_evaluate(pred_choice, target, voc, char2id,
id2char, 'editdistance') 91 | print("Edit distance:", metric) 92 | print("Edit distance list:", metric_list) 93 | print("Predicted words:", predicts) 94 | print("Labeled words:", labels) -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | ''' 2 | THis is the main inference code. 3 | ''' 4 | import os 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # set GPU id at the very begining 6 | import argparse 7 | import random 8 | import math 9 | import torch 10 | import torch.nn.parallel 11 | import torch.optim as optim 12 | import torch.utils.data 13 | import torch.nn.functional as F 14 | from torch.multiprocessing import freeze_support 15 | import cv2 16 | import pdb 17 | # internal package 18 | from dataset import dataset 19 | from dataset.dataset import dictionary_generator 20 | from models.sar import sar 21 | from utils.dataproc import end_cut 22 | from utils.attention_map import attention_map 23 | 24 | # main function: 25 | if __name__ == '__main__': 26 | freeze_support() 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--batch', type=int, default=32, help='batch size') 29 | parser.add_argument( 30 | '--worker', type=int, default=4, help='number of data loading workers') 31 | parser.add_argument('--input', type=str, default='', help='input folder') 32 | parser.add_argument('--output', type=str, default='predict.txt', help='output file name') 33 | parser.add_argument('--model', type=str, default='', help='model path') 34 | parser.add_argument('--gpu', type=bool, default=False, help="GPU being used or not") 35 | 36 | opt = parser.parse_args() 37 | print(opt) 38 | 39 | # turn on GPU for models: 40 | if opt.gpu == False: 41 | device = torch.device("cpu") 42 | print("CPU being used!") 43 | else: 44 | if torch.cuda.is_available() == True and opt.gpu == True: 45 | device = torch.device("cuda") 46 | print("GPU being used!") 47 | else: 48 | device = torch.device("cpu") 49 | print("CPU being used!") 50 | 51 | # set training parameters 52 | Height = 48 53 | Width = 64 54 | feature_height = Height // 4 55 | feature_width = Width // 8 56 | Channel = 3 57 | voc, char2id, id2char = dictionary_generator() 58 | output_classes = len(voc) 59 | embedding_dim = 512 60 | hidden_units = 512 61 | layers = 2 62 | keep_prob = 1.0 63 | seq_len = 40 64 | batch_size = opt.batch 65 | output_path = opt.output 66 | trained_model_path = opt.model 67 | input_path = opt.input 68 | worker = opt.worker 69 | 70 | # load test data 71 | test_dataset = dataset.test_dataset_builder(Height, Width, input_path) 72 | 73 | # make dataloader 74 | test_dataloader = torch.utils.data.DataLoader( 75 | test_dataset, 76 | batch_size=batch_size, 77 | shuffle=False, 78 | num_workers=int(worker)) 79 | 80 | # load model 81 | print("Create model......") 82 | model = sar(Channel, feature_height, feature_width, embedding_dim, output_classes, hidden_units, layers, keep_prob, seq_len, device) 83 | 84 | if torch.cuda.is_available() == True and opt.gpu == True: 85 | model.load_state_dict(torch.load(trained_model_path, map_location=lambda storage, loc: storage), strict=False) 86 | model = torch.nn.DataParallel(model).to(device) 87 | else: 88 | model.load_state_dict(torch.load(trained_model_path, map_location=lambda storage, loc: storage), strict=False) 89 | model = model.to(device) 90 | 91 | if input_path == '': 92 | print("Error: Empty --input!") 93 | exit(1) 94 | 95 | if os.path.isfile(output_path): 96 | 
os.remove(output_path) 97 | 98 | # run inference 99 | print("Inference starts......") 100 | for i, data in enumerate(test_dataloader): 101 | print("processing for batch index:", i) 102 | x = data[0] # [batch_size, Channel, Height, Width] 103 | image_name = data[1] # [batch_size, image_name] 104 | x = x.to(device) 105 | model = model.eval() 106 | predict, att_weights, _, _ = model(x, 0) 107 | batch_size_current = predict.shape[0] 108 | pred_choice = predict.max(2)[1] # [batch_size, seq_len] 109 | with open(output_path, "a") as f: 110 | for idx in range(batch_size_current): 111 | # prediction evaluation 112 | predict_word = end_cut(pred_choice[idx].detach().cpu().numpy(), char2id, id2char) 113 | # generate attention heatmap 114 | heatmaps, overlayed_images = attention_map(predict_word, x[idx], att_weights[idx,:,:,:,:]) 115 | ''' 116 | for i, img in enumerate(overlayed_images): 117 | cv2.imwrite('./attmap/'+image_name[idx][:-4]+'_'+str(i)+'.png', img) 118 | ''' 119 | # write to output path 120 | f.write("{} {}\n".format(image_name[idx], predict_word)) 121 | print("Inference done!") -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is to construct backbone network for SAR - with 13 layers of customized ResNet. 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | 7 | __all__ = ['basicblock','backbone'] 8 | 9 | class basicblock(nn.Module): 10 | def __init__(self, depth_in, output_dim, kernel_size, stride): 11 | super(basicblock, self).__init__() 12 | self.identity = nn.Identity() 13 | self.conv_res = nn.Conv2d(depth_in, output_dim, kernel_size=1, stride=1) 14 | self.batchnorm_res = nn.BatchNorm2d(output_dim) 15 | self.conv1 = nn.Conv2d(depth_in, output_dim, kernel_size=kernel_size, stride=stride, padding=1) 16 | self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=kernel_size, stride=stride, padding=1) 17 | self.batchnorm1 = nn.BatchNorm2d(output_dim) 18 | self.batchnorm2 = nn.BatchNorm2d(output_dim) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | self.relu2 = nn.ReLU(inplace=True) 21 | self.depth_in = depth_in 22 | self.output_dim = output_dim 23 | 24 | def forward(self, x): 25 | # create shortcut path 26 | if self.depth_in == self.output_dim: 27 | residual = self.identity(x) 28 | else: 29 | residual = self.conv_res(x) 30 | residual = self.batchnorm_res(residual) 31 | out = self.conv1(x) 32 | out = self.batchnorm1(out) 33 | out = self.relu1(out) 34 | out = self.conv2(out) 35 | out = self.batchnorm2(out) 36 | 37 | out += residual 38 | out = self.relu2(out) 39 | 40 | return out 41 | 42 | class backbone(nn.Module): 43 | 44 | def __init__(self, input_dim): 45 | super(backbone, self).__init__() 46 | self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, stride=1, padding=1) 47 | self.batchnorm1 = nn.BatchNorm2d(64) 48 | self.relu1 = nn.ReLU(inplace=True) 49 | self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 50 | self.batchnorm2 = nn.BatchNorm2d(128) 51 | self.relu2 = nn.ReLU(inplace=True) 52 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2) 53 | # Block 1 starts 54 | self.basicblock1 = basicblock(128, 256, kernel_size=3, stride=1) 55 | # Block 1 ends 56 | self.conv3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 57 | self.batchnorm3 = nn.BatchNorm2d(256) 58 | self.relu3 = nn.ReLU(inplace=True) 59 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2) 60 | # Block 2 starts 61 | self.basicblock2 = 
basicblock(256, 256, kernel_size=3, stride=1) 62 | self.basicblock3 = basicblock(256, 256, kernel_size=3, stride=1) 63 | # Block 2 ends 64 | self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 65 | self.batchnorm4 = nn.BatchNorm2d(256) 66 | self.relu4 = nn.ReLU(inplace=True) 67 | self.maxpool3 = nn.MaxPool2d(kernel_size=(1,2), stride=(1,2)) 68 | # Block 5 starts 69 | self.basicblock4 = basicblock(256, 512, kernel_size=3, stride=1) 70 | self.basicblock5 = basicblock(512, 512, kernel_size=3, stride=1) 71 | self.basicblock6 = basicblock(512, 512, kernel_size=3, stride=1) 72 | self.basicblock7 = basicblock(512, 512, kernel_size=3, stride=1) 73 | self.basicblock8 = basicblock(512, 512, kernel_size=3, stride=1) 74 | # Block 5 ends 75 | self.conv5 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 76 | self.batchnorm5 = nn.BatchNorm2d(512) 77 | self.relu5 = nn.ReLU(inplace=True) 78 | # Block 3 starts 79 | self.basicblock9 = basicblock(512, 512, kernel_size=3, stride=1) 80 | self.basicblock10 = basicblock(512, 512, kernel_size=3, stride=1) 81 | self.basicblock11 = basicblock(512, 512, kernel_size=3, stride=1) 82 | # Block 3 ends 83 | self.conv6 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 84 | self.batchnorm6 = nn.BatchNorm2d(512) 85 | self.relu6 = nn.ReLU(inplace=True) 86 | 87 | def forward(self, x): 88 | x = self.conv1(x) 89 | x = self.batchnorm1(x) 90 | x = self.relu1(x) 91 | x = self.conv2(x) 92 | x = self.batchnorm2(x) 93 | x = self.relu2(x) 94 | x = self.maxpool1(x) 95 | x = self.basicblock1(x) 96 | x = self.conv3(x) 97 | x = self.batchnorm3(x) 98 | x = self.relu3(x) 99 | x = self.maxpool2(x) 100 | x = self.basicblock2(x) 101 | x = self.basicblock3(x) 102 | x = self.conv4(x) 103 | x = self.batchnorm4(x) 104 | x = self.relu4(x) 105 | x = self.maxpool3(x) 106 | x = self.basicblock4(x) 107 | x = self.basicblock5(x) 108 | x = self.basicblock6(x) 109 | x = self.basicblock7(x) 110 | x = self.basicblock8(x) 111 | x = self.conv5(x) 112 | x = self.batchnorm5(x) 113 | x = self.relu5(x) 114 | x = self.basicblock9(x) 115 | x = self.basicblock10(x) 116 | x = self.basicblock11(x) 117 | x = self.conv6(x) 118 | x = self.batchnorm6(x) 119 | x = self.relu6(x) 120 | 121 | return x 122 | 123 | # unit test 124 | if __name__ == '__main__': 125 | 126 | batch_size = 32 127 | Height = 48 128 | Width = 160 129 | Channel = 3 130 | 131 | input_images = torch.randn(batch_size,Channel,Height,Width) 132 | model = backbone(Channel) 133 | output_features = model(input_images) 134 | 135 | print("Input size is:",input_images.shape) 136 | print("Output feature map size is:",output_features.shape) -------------------------------------------------------------------------------- /models/decoder.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is to construct decoder for SAR - two layer LSTMs combined with feature map with attention mechanism 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | 7 | __all__ = ['word_embedding','attention','decoder'] 8 | 9 | class word_embedding(nn.Module): 10 | def __init__(self, output_classes, embedding_dim): 11 | super(word_embedding, self).__init__() 12 | ''' 13 | output_classes: number of output classes for the one hot encoding of a word 14 | embedding_dim: embedding dimension for a word 15 | ''' 16 | self.linear = nn.Linear(output_classes, embedding_dim) # linear transformation 17 | 18 | def forward(self,x): 19 | x = self.linear(x) 20 | 21 | return x 22 | 23 | class attention(nn.Module): 24 | def 
__init__(self, hidden_units, H, W, D): 25 | super(attention, self).__init__() 26 | ''' 27 | hidden_units: hidden units of decoder 28 | H: height of feature map 29 | W: width of feature map 30 | D: depth of feature map 31 | ''' 32 | self.conv1 = nn.Conv2d(hidden_units, D, kernel_size=1, stride=1) 33 | self.conv2 = nn.Conv2d(D, D, kernel_size=3, stride=1, padding=1) 34 | self.conv3 = nn.Conv2d(D, 1, kernel_size=1, stride=1) 35 | self.dropout = nn.Dropout(0.5) 36 | self.softmax = nn.Softmax(dim=-1) 37 | self.H = H 38 | self.W = W 39 | self.D = D 40 | 41 | def forward(self, h, feature_map): 42 | ''' 43 | h: hidden state from decoder output, with size [batch, hidden_units] 44 | feature_map: feature map from backbone network, with size [batch, channel, H, W] 45 | ''' 46 | # reshape hidden state [batch, hidden_units] to [batch, hidden_units, 1, 1] 47 | h = h.unsqueeze(2) 48 | h = h.unsqueeze(3) 49 | h = self.conv1(h) # [batch, D, 1, 1] 50 | h = h.repeat(1, 1, self.H, self.W) # tiling to [batch, D, H, W] 51 | feature_map_origin = feature_map 52 | feature_map = self.conv2(feature_map) # [batch, D, H, W] 53 | combine = self.conv3(self.dropout(torch.tanh(feature_map + h))) # [batch, 1, H, W] 54 | combine_flat = combine.view(combine.size(0), -1) # resize to [batch, H*W] 55 | attention_weights = self.softmax(combine_flat) # [batch, H*W] 56 | attention_weights = attention_weights.view(combine.size()) # [batch, 1, H, W] 57 | glimpse = feature_map_origin * attention_weights.repeat(1, self.D, 1, 1) # [batch, D, H, W] 58 | glimpse = torch.sum(glimpse, dim=(2,3)) # [batch, D] 59 | 60 | return glimpse, attention_weights 61 | 62 | class decoder(nn.Module): 63 | def __init__(self, output_classes, H, W, D=512, hidden_units=512, seq_len=40, device='cpu'): 64 | super(decoder, self).__init__() 65 | ''' 66 | output_classes: number of output classes for the one hot encoding of a word 67 | H: feature map height 68 | W: feature map width 69 | D: glimpse depth 70 | hidden_units: hidden units of encoder/decoder for LSTM 71 | seq_len: output sequence length T 72 | ''' 73 | self.linear1 = nn.Linear(output_classes, hidden_units) 74 | self.lstmcell1 = [nn.LSTMCell(hidden_units, hidden_units) for i in range(seq_len+1)] 75 | self.lstmcell2 = [nn.LSTMCell(hidden_units, hidden_units) for i in range(seq_len+1)] 76 | self.attention = attention(hidden_units, H, W, D) 77 | self.linear2 = nn.Linear(hidden_units+D, output_classes) 78 | self.softmax = nn.LogSoftmax(dim=1) 79 | self.seq_len = seq_len 80 | self.START_TOKEN = output_classes - 3 # Same as END TOKEN 81 | self.output_classes = output_classes 82 | self.hidden_units = hidden_units 83 | self.device = device 84 | 85 | self.lstmcell1 = torch.nn.ModuleList(self.lstmcell1) 86 | self.lstmcell2 = torch.nn.ModuleList(self.lstmcell2) 87 | 88 | def forward(self,hw,y,V): 89 | ''' 90 | hw: embedded feature from encoder [batch, hidden_units] 91 | y: ground truth label one hot encoder [batch, seq, output_classes] 92 | V: feature map for backbone network [batch, D, H, W] 93 | ''' 94 | outputs = [] 95 | attention_weights = [] 96 | batch_size = hw.shape[0] 97 | y_onehot = torch.zeros(batch_size, self.output_classes).to(self.device) 98 | for t in range(self.seq_len + 1): 99 | if t == 0: 100 | inputs_y = hw # size [batch, hidden_units] 101 | # LSTM layer 1 initialization: 102 | hx_1 = torch.zeros(batch_size, self.hidden_units).to(self.device) # initial h0_1 103 | cx_1 = torch.zeros(batch_size, self.hidden_units).to(self.device) # initial c0_1 104 | # LSTM layer 2 initialization: 105 | hx_2 = 
torch.zeros(batch_size, self.hidden_units).to(self.device) # initial h0_2 106 | cx_2 = torch.zeros(batch_size, self.hidden_units).to(self.device) # initial c0_2 107 | elif t == 1: 108 | y_onehot.zero_() 109 | y_onehot[:,self.START_TOKEN] = 1.0 110 | inputs_y = y_onehot 111 | inputs_y = self.linear1(inputs_y) # [batch, hidden_units] 112 | else: 113 | if self.training: 114 | inputs_y = y[:,t-2,:] # [batch, output_classes] 115 | else: 116 | # greedy search for now - beam search to be implemented! 117 | index = torch.argmax(outputs[t-1], dim=-1) # [batch] 118 | index = index.unsqueeze(1) # [batch, 1] 119 | y_onehot.zero_() 120 | inputs_y = y_onehot.scatter_(1, index, 1) # [batch, output_classes] 121 | 122 | inputs_y = self.linear1(inputs_y) # [batch, hidden_units_encoder] 123 | 124 | # LSTM cells combined with attention and fusion layer 125 | hx_1, cx_1 = self.lstmcell1[t](inputs_y, (hx_1,cx_1)) 126 | hx_2, cx_2 = self.lstmcell2[t](hx_1, (hx_2,cx_2)) 127 | glimpse, att_weights = self.attention(hx_2, V) # [batch, D], [batch, 1, H, W] 128 | combine = torch.cat((hx_2,glimpse), dim=1) # [batch, hidden_units_decoder+D] 129 | out = self.linear2(combine) # [batch, output_classes] 130 | out = self.softmax(out) # [batch, output_classes] 131 | outputs.append(out) 132 | attention_weights.append(att_weights) 133 | 134 | outputs = outputs[1:] # [seq_len, batch, output_classes] 135 | attention_weights = attention_weights[1:] # [seq_len, batch, 1, H, W] 136 | outputs = torch.stack(outputs) # [seq_len, batch, output_classes] 137 | outputs = outputs.permute(1,0,2) # [batch, seq_len, output_classes] 138 | attention_weights = torch.stack(attention_weights) # [seq_len, batch, 1, H, W] 139 | attention_weights = attention_weights.permute(1,0,2,3,4) # [batch, seq_len, 1, H, W] 140 | 141 | return outputs, attention_weights 142 | 143 | # unit test 144 | if __name__ == '__main__': 145 | 146 | batch_size = 2 147 | Height = 48 148 | Width = 160 149 | Channel = 512 150 | output_classes = 94 151 | embedding_dim = 512 152 | hidden_units = 512 153 | layers_decoder = 2 154 | seq_len = 40 155 | 156 | one_hot_embedding = torch.randn(batch_size, output_classes) 157 | one_hot_embedding[one_hot_embedding>0] = torch.ones(1) 158 | one_hot_embedding[one_hot_embedding<0] = torch.zeros(1) 159 | print("Word embedding size is:", one_hot_embedding.shape) 160 | 161 | embedding_model = word_embedding(output_classes, embedding_dim) 162 | embedding_transform = embedding_model(one_hot_embedding) 163 | print("Embedding transform size is:", embedding_transform.shape) 164 | 165 | hw = torch.randn(batch_size, hidden_units) 166 | feature_map = torch.randn(batch_size,Channel,Height,Width) 167 | print("Feature map size is:", feature_map.shape) 168 | 169 | attention_model = attention(hidden_units, Height, Width, Channel) 170 | glimpse, attention_weights = attention_model(hw, feature_map) 171 | print("Glimpse size is:", glimpse.shape) 172 | print("Attention weight size is:", attention_weights.shape) 173 | 174 | label = torch.randn(batch_size, seq_len, output_classes) 175 | decoder_model = decoder(output_classes, Height, Width, Channel, hidden_units, seq_len) 176 | outputs, attention_weights = decoder_model(hw, label, feature_map) 177 | print("Output size is:", outputs.shape) 178 | print("Attention_weights size is:", attention_weights.shape) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | THis is the main training 
code. 3 | ''' 4 | import os 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # set GPU id at the very begining 6 | import argparse 7 | import random 8 | import math 9 | import torch 10 | import torch.nn.parallel 11 | import torch.optim as optim 12 | import torch.utils.data 13 | import torch.nn.functional as F 14 | from torch.multiprocessing import freeze_support 15 | import pdb 16 | # internal package 17 | from dataset import dataset 18 | from dataset.dataset import dictionary_generator 19 | from models.sar import sar 20 | from utils.dataproc import performance_evaluate 21 | 22 | # main function: 23 | if __name__ == '__main__': 24 | freeze_support() 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | '--batch', type=int, default=32, help='input batch size') 28 | parser.add_argument( 29 | '--worker', type=int, default=4, help='number of data loading workers') 30 | parser.add_argument( 31 | '--epoch', type=int, default=250, help='number of epochs') 32 | parser.add_argument('--output', type=str, default='str', help='output folder name') 33 | parser.add_argument('--model', type=str, default='', help='model path') 34 | parser.add_argument('--dataset', type=str, required=True, help="dataset path") 35 | parser.add_argument('--dataset_type', type=str, default='svt', help="dataset type - svt|iiit5k|syn90k|synthtext") 36 | parser.add_argument('--gpu', type=bool, default=False, help="GPU being used or not") 37 | parser.add_argument('--metric', type=str, default='accuracy', help="evaluation metric - accuracy|editdistance") 38 | 39 | opt = parser.parse_args() 40 | print(opt) 41 | 42 | opt.manualSeed = random.randint(1, 10000) # fix seed 43 | print("Random Seed:", opt.manualSeed) 44 | random.seed(opt.manualSeed) 45 | torch.manual_seed(opt.manualSeed) 46 | 47 | # turn on GPU for models: 48 | if opt.gpu == False: 49 | device = torch.device("cpu") 50 | print("CPU being used!") 51 | else: 52 | if torch.cuda.is_available() == True and opt.gpu == True: 53 | device = torch.device("cuda") 54 | print("GPU being used!") 55 | else: 56 | device = torch.device("cpu") 57 | print("CPU being used!") 58 | 59 | # set training parameters 60 | batch_size = opt.batch 61 | Height = 48 62 | Width = 64 63 | feature_height = Height // 4 64 | feature_width = Width // 8 65 | Channel = 3 66 | voc, char2id, id2char = dictionary_generator() 67 | output_classes = len(voc) 68 | embedding_dim = 512 69 | hidden_units = 512 70 | layers = 2 71 | keep_prob = 1.0 72 | seq_len = 40 73 | epochs = opt.epoch 74 | worker = opt.worker 75 | dataset_path = opt.dataset 76 | dataset_type = opt.dataset_type 77 | output_path = opt.output 78 | trained_model_path = opt.model 79 | eval_metric = opt.metric 80 | 81 | # create dataset 82 | print("Create dataset......") 83 | if dataset_type == 'svt': # street view text dataset 84 | img_path = os.path.join(dataset_path, 'img') 85 | train_xml_path = os.path.join(dataset_path, 'train.xml') 86 | test_xml_path = os.path.join(dataset_path, 'test.xml') 87 | train_dataset = dataset.svt_dataset_builder(Height, Width, seq_len, img_path, train_xml_path) 88 | test_dataset = dataset.svt_dataset_builder(Height, Width, seq_len, img_path, test_xml_path) 89 | elif dataset_type == 'iiit5k': # IIIT5k dataset 90 | train_img_path = os.path.join(dataset_path, 'train') 91 | test_img_path = os.path.join(dataset_path, 'test') 92 | train_annotation_path = os.path.join(dataset_path, 'traindata.mat') 93 | test_annotation_path = os.path.join(dataset_path, 'testdata.mat') 94 | train_dataset = 
dataset.iiit5k_dataset_builder(Height, Width, seq_len, train_img_path, train_annotation_path) 95 | test_dataset = dataset.iiit5k_dataset_builder(Height, Width, seq_len, test_img_path, test_annotation_path) 96 | elif dataset_type == 'syn90k': # Syn90K dataset 97 | train_img_path = os.path.join(dataset_path, 'train') 98 | test_img_path = os.path.join(dataset_path, 'test') 99 | train_dataset = dataset.syn90k_dataset_builder(Height, Width, seq_len, train_img_path) 100 | test_dataset = dataset.syn90k_dataset_builder(Height, Width, seq_len, test_img_path) 101 | elif dataset_type == 'synthtext': # SynthText dataset 102 | train_img_path = os.path.join(dataset_path, 'train') 103 | test_img_path = os.path.join(dataset_path, 'test') 104 | annotation_path = os.path.join(dataset_path, 'gt.mat') 105 | train_dataset = dataset.synthtext_dataset_builder(Height, Width, seq_len, train_img_path, annotation_path) 106 | test_dataset = dataset.synthtext_dataset_builder(Height, Width, seq_len, test_img_path, annotation_path) 107 | else: 108 | print("Not supported yet!") 109 | exit(1) 110 | 111 | # make dataloader 112 | train_dataloader = torch.utils.data.DataLoader( 113 | train_dataset, 114 | batch_size=batch_size, 115 | shuffle=True, 116 | num_workers=int(worker)) 117 | 118 | test_dataloader = torch.utils.data.DataLoader( 119 | test_dataset, 120 | batch_size=batch_size, 121 | shuffle=True, 122 | num_workers=int(worker)) 123 | 124 | print("Length of train dataset is:", len(train_dataset)) 125 | print("Length of test dataset is:", len(test_dataset)) 126 | print("Number of output classes is:", train_dataset.output_classes) 127 | 128 | # make model output folder 129 | try: 130 | os.makedirs(output_path) 131 | except OSError: 132 | pass 133 | 134 | # create model 135 | print("Create model......") 136 | model = sar(Channel, feature_height, feature_width, embedding_dim, output_classes, hidden_units, layers, keep_prob, seq_len, device) 137 | 138 | if trained_model_path != '': 139 | if torch.cuda.is_available() == True and opt.gpu == True: 140 | model.load_state_dict(torch.load(trained_model_path, map_location=lambda storage, loc: storage), strict=False) 141 | model = torch.nn.DataParallel(model).to(device) 142 | else: 143 | model.load_state_dict(torch.load(trained_model_path, map_location=lambda storage, loc: storage), strict=False) 144 | else: 145 | if torch.cuda.is_available() == True and opt.gpu == True: 146 | model = torch.nn.DataParallel(model).to(device) 147 | else: 148 | model = model.to(device) 149 | 150 | optimizer = optim.Adam(model.parameters(), lr=0.001) 151 | lmbda = lambda epoch: 0.9**(epoch // 300) if epoch < 13200 else 10**(-2) 152 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lmbda) 153 | 154 | num_batch = math.ceil(len(train_dataset) / batch_size) 155 | 156 | # train, evaluate, and save model 157 | print("Training starts......") 158 | if eval_metric == 'accuracy': 159 | best_acc = float('-inf') 160 | elif eval_metric == 'editdistance': 161 | best_acc = float('inf') 162 | else: 163 | print("Wrong --metric argument, set it to default") 164 | eval_metric = 'accuracy' 165 | best_acc = float('-inf') 166 | 167 | for epoch in range(epochs): 168 | M_list = [] 169 | for i, data in enumerate(train_dataloader): 170 | x = data[0] # [batch_size, Channel, Height, Width] 171 | y = data[1] # [batch_size, seq_len, output_classes] 172 | x, y = x.to(device), y.to(device) 173 | #print(x.shape, y.shape) 174 | optimizer.zero_grad() 175 | model = model.train() 176 | predict, _, _, _ = model(x, y) 177 | 
target = y.max(2)[1] # [batch_size, seq_len] 178 | #print("Prediction size is:", predict.shape) 179 | #print("Attention weight size is:", att_weights.shape) 180 | predict_reshape = predict.permute(0,2,1) # [batch_size, output_classes, seq_len] 181 | loss = F.nll_loss(predict_reshape, target) 182 | loss.backward() 183 | optimizer.step() 184 | # prediction evaluation 185 | pred_choice = predict.max(2)[1] # [batch_size, seq_len] 186 | metric, metric_list, predict_words, labeled_words = performance_evaluate(pred_choice.detach().cpu().numpy(), target.detach().cpu().numpy(), voc, char2id, id2char, eval_metric) 187 | M_list += metric_list 188 | print('[Epoch %d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), metric)) 189 | #print("predict prob:", predict[0][0]) 190 | #print("predict words:", predict_words[0]) 191 | #print("labeled words:", labeled_words[0]) 192 | train_acc = float(sum(M_list)/len(M_list)) 193 | print("Epoch {} average train accuracy: {}".format(epoch, train_acc)) 194 | 195 | scheduler.step() 196 | 197 | # Validation 198 | print("Testing......") 199 | with torch.set_grad_enabled(False): 200 | M_list = [] 201 | for i, data in enumerate(test_dataloader): 202 | x = data[0] # [batch_size, Channel, Height, Width] 203 | y = data[1] # [batch_size, seq_len, output_classes] 204 | x, y = x.to(device), y.to(device) 205 | model = model.eval() 206 | predict, _, _, _ = model(x, y) 207 | # prediction evaluation 208 | pred_choice = predict.max(2)[1] # [batch_size, seq_len] 209 | target = y.max(2)[1] # [batch_size, seq_len] 210 | metric, metric_list, predict_words, labeled_words = performance_evaluate(pred_choice.detach().cpu().numpy(), target.detach().cpu().numpy(), voc, char2id, id2char, eval_metric) 211 | M_list += metric_list 212 | test_acc = float(sum(M_list)/len(M_list)) 213 | #print("Test predict words:", predict_words[0]) 214 | #print("Test labeled words:", labeled_words[0]) 215 | print("Epoch {} average test accuracy: {}".format(epoch, test_acc)) 216 | with open(os.path.join(output_path,'statistics.txt'), 'a') as f: 217 | f.write("{} {}\n".format(train_acc, test_acc)) 218 | if eval_metric == 'accuracy': 219 | if test_acc >= best_acc: 220 | print("Save current best model with accuracy:", test_acc) 221 | best_acc = test_acc 222 | if torch.cuda.is_available() == True and opt.gpu == True: 223 | torch.save(model.module.state_dict(), '%s/model_best.pth' % (output_path)) 224 | else: 225 | torch.save(model.state_dict(), '%s/model_best.pth' % (output_path)) 226 | elif eval_metric == 'editdistance': 227 | if test_acc <= best_acc: 228 | print("Save current best model with accuracy:", test_acc) 229 | best_acc = test_acc 230 | if torch.cuda.is_available() == True and opt.gpu == True: 231 | torch.save(model.module.state_dict(), '%s/model_best.pth' % (output_path)) 232 | else: 233 | torch.save(model.state_dict(), '%s/model_best.pth' % (output_path)) 234 | print("Best test accuracy is:", best_acc) -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | This code is to build various dataset for SAR 3 | ''' 4 | import string 5 | import cv2 6 | import torch.utils.data as data 7 | import os 8 | import torch 9 | import numpy as np 10 | import xml.etree.ElementTree as ET 11 | from scipy.io import loadmat 12 | import pdb 13 | 14 | def dictionary_generator(END='END', PADDING='PAD', UNKNOWN='UNK'): 15 | ''' 16 | END: end of sentence token 17 | 
PADDING: padding token 18 | UNKNOWN: unknown character token 19 | ''' 20 | voc = list(string.printable[:-6]) # characters including 9 digits + 26 lower cases + 26 upper cases + 33 punctuations 21 | 22 | # update the voc with 3 specifical chars 23 | voc.append(END) 24 | voc.append(PADDING) 25 | voc.append(UNKNOWN) 26 | 27 | char2id = dict(zip(voc, range(len(voc)))) 28 | id2char = dict(zip(range(len(voc)), voc)) 29 | 30 | return voc, char2id, id2char 31 | 32 | def end_cut(indices, char2id, id2char): 33 | ''' 34 | indices: numpy array or list of character indices 35 | charid: char to id conversion 36 | id2char: id to char conversion 37 | ''' 38 | cut_indices = [] 39 | for id in indices: 40 | if id != char2id['END']: 41 | if id != char2id['UNK'] and id != char2id['PAD']: 42 | cut_indices.append(id2char[id]) 43 | else: 44 | break 45 | return ''.join(cut_indices) 46 | 47 | def svt_xml_extractor(label_path): 48 | ''' 49 | This code is to extract xml labels from SVT dataset 50 | Input: 51 | label_path: xml label path file 52 | Output: 53 | dict_img: [image_name, bounding box, labels, lexicon] 54 | ''' 55 | # create element tree object 56 | tree = ET.parse(label_path) 57 | 58 | # get root element 59 | root = tree.getroot() 60 | 61 | # create empty list for news items 62 | dict_img = [] 63 | 64 | # iterate news items 65 | for item in root.findall('image'): 66 | name = item.find('imageName').text.split('/')[-1] 67 | lexicon = item.find('lex').text.split(',') 68 | rec = item.find('taggedRectangles') 69 | for r in rec.findall('taggedRectangle'): 70 | x = int(r.get('x')) 71 | y = int(r.get('y')) 72 | w = int(r.get('width')) 73 | h = int(r.get('height')) 74 | bdb = (x,y,w,h) 75 | labels = r.find('tag').text 76 | dict_img.append([name, bdb,labels,lexicon]) 77 | 78 | return dict_img 79 | 80 | def iiit5k_mat_extractor(label_path): 81 | ''' 82 | This code is to extract mat labels from IIIT5k dataset 83 | Input: 84 | label_path: mat label path file 85 | Output: 86 | dict_img: [image_name, labels, small_lexicon, medium_lexicon] 87 | ''' 88 | # create empty list for news items 89 | dict_img = [] 90 | 91 | mat_contents = loadmat(label_path) 92 | 93 | if 'traindata' in mat_contents: 94 | key = 'traindata' 95 | else: 96 | key = 'testdata' 97 | for i in range(len(mat_contents[key][0])): 98 | name = mat_contents[key][0][i][0][0] 99 | label = mat_contents[key][0][i][1][0] 100 | #small_lexi = [item[0] for item in mat_contents[key][0][i][2][0]] 101 | #medium_lexi = [item[0] for item in mat_contents[key][0][i][3][0]] 102 | dict_img.append([name, label]) 103 | 104 | return dict_img 105 | 106 | def synthtext_mat_extractor(label_path): 107 | ''' 108 | This code is to extract mat labels from SynthText dataset 109 | Input: 110 | label_path: mat label path file 111 | Output: 112 | dict_img: [image_name, bounding box, labels] 113 | ''' 114 | # create empty list for news items 115 | dict_img = [] 116 | 117 | mat_contents = loadmat(label_path) 118 | 119 | for i in range(len(mat_contents['imnames'][0])): 120 | image_name = mat_contents['imnames'][0][i][0].split('/')[-1] 121 | word_bdbs = [] 122 | if len(mat_contents['wordBB'][0][i].shape) == 3: 123 | for bdb_idx in range(mat_contents['wordBB'][0][i].shape[-1]): 124 | word_bdb = mat_contents['wordBB'][0][i][:,:,bdb_idx] # shape (2,4), i.e., 4 points for (x,y) 125 | xmin = min(word_bdb[0,:]) 126 | ymin = min(word_bdb[1,:]) 127 | xmax = max(word_bdb[0,:]) 128 | ymax = max(word_bdb[1,:]) 129 | word_bdbs.append((int(xmin),int(ymin),int(xmax),int(ymax))) 130 | else: 131 | word_bdb = 
mat_contents['wordBB'][0][i][:,:] # shape (2,4), i.e., 4 points for (x,y) 132 | xmin = min(word_bdb[0,:]) 133 | ymin = min(word_bdb[1,:]) 134 | xmax = max(word_bdb[0,:]) 135 | ymax = max(word_bdb[1,:]) 136 | word_bdbs.append((int(xmin),int(ymin),int(xmax),int(ymax))) 137 | 138 | if len(mat_contents['charBB'][0][i].shape) == 3: 139 | for bdb_idx in range(mat_contents['charBB'][0][i].shape[-1]): 140 | char_bdb = mat_contents['charBB'][0][i][:,:,bdb_idx] # shape (2,4), i.e., 4 points for (x,y) 141 | else: 142 | char_bdb = mat_contents['charBB'][0][i][:,:] # shape (2,4), i.e., 4 points for (x,y) 143 | 144 | labels = [] 145 | for label_idx in range(mat_contents['txt'][0][i].shape[0]): 146 | label_total = mat_contents['txt'][0][i][label_idx] 147 | L = label_total.split('\n') 148 | for l in L: 149 | l = l.strip() 150 | l = list(l.split(' ')) 151 | l = [item for item in l if item != ''] 152 | labels += l 153 | if len(word_bdbs) != len(labels): 154 | print("Wrong parsing for labels in SynthText dataset!") 155 | exit(1) 156 | for word_bdb, label in zip(word_bdbs, labels): 157 | dict_img.append([image_name, word_bdb, label]) 158 | 159 | return dict_img 160 | 161 | class svt_dataset_builder(data.Dataset): 162 | def __init__(self, height, width, seq_len, total_img_path, xml_path): 163 | ''' 164 | height: input height to model 165 | width: input width to model 166 | total_img_path: path with all images 167 | xml_path: xml labeling file 168 | seq_len: sequence length 169 | ''' 170 | # parse xml file and create fully ready dataset 171 | self.total_img_path = total_img_path 172 | self.height = height 173 | self.width = width 174 | self.seq_len = seq_len 175 | self.dictionary = svt_xml_extractor(xml_path) 176 | self.total_img_name = os.listdir(total_img_path) 177 | self.dataset = [] 178 | self.voc, self.char2id, _ = dictionary_generator() 179 | self.output_classes = len(self.voc) 180 | for items in self.dictionary: 181 | if items[0] in self.total_img_name: 182 | self.dataset.append([items[0],items[1],items[2]]) 183 | 184 | def __getitem__(self, index): 185 | img_name, bdb, label = self.dataset[index] 186 | IMG = cv2.imread(os.path.join(self.total_img_path,img_name)) 187 | x, y, w, h = bdb 188 | (H, W, _) = IMG.shape 189 | x = max(0, x) 190 | x = min(W-1, x) 191 | y = max(0, y) 192 | y = min(H-1, y) 193 | w = max(0, w) 194 | w = min(W, w) 195 | h = max(0, h) 196 | h = min(H-1, h) 197 | # image processing: 198 | IMG = IMG[y:y+h,x:x+w,:] # crop 199 | IMG = cv2.resize(IMG, (self.width, self.height)) # resize 200 | IMG = (IMG - 127.5)/127.5 # normalization to [-1,1] 201 | IMG = torch.FloatTensor(IMG) # convert to tensor [H, W, C] 202 | IMG = IMG.permute(2,0,1) # [C, H, W] 203 | y_true = np.ones(self.seq_len)*self.char2id['PAD'] # initialize y_true with 'PAD', size [seq_len] 204 | # label processing 205 | for i, c in enumerate(label): 206 | index = self.char2id[c] 207 | y_true[i] = index 208 | y_true[-1] = self.char2id['END'] # always put 'END' in the end 209 | y_true = y_true.astype(int) # must to integer index for one-hot encoding 210 | # convert to one-hot encoding 211 | y_onehot = np.eye(self.output_classes)[y_true] # [seq_len, output_classes] 212 | 213 | return IMG, torch.FloatTensor(y_onehot) 214 | 215 | def __len__(self): 216 | return len(self.dataset) 217 | 218 | class iiit5k_dataset_builder(data.Dataset): 219 | def __init__(self, height, width, seq_len, total_img_path, annotation_path): 220 | ''' 221 | height: input height to model 222 | width: input width to model 223 | total_img_path: path with all 
images 224 | annotation_path: mat labeling file 225 | seq_len: sequence length 226 | ''' 227 | self.total_img_path = total_img_path 228 | self.height = height 229 | self.width = width 230 | self.seq_len = seq_len 231 | self.dictionary = iiit5k_mat_extractor(annotation_path) 232 | self.total_img_name = os.listdir(total_img_path) 233 | self.dataset = [] 234 | self.voc, self.char2id, _ = dictionary_generator() 235 | self.output_classes = len(self.voc) 236 | 237 | for items in self.dictionary: 238 | if items[0].split('/')[-1] in self.total_img_name: 239 | self.dataset.append([items[0].split('/')[-1],items[1]]) 240 | 241 | def __getitem__(self, index): 242 | img_name, label = self.dataset[index] 243 | IMG = cv2.imread(os.path.join(self.total_img_path,img_name)) 244 | IMG = cv2.resize(IMG, (self.width, self.height)) # resize 245 | IMG = (IMG - 127.5)/127.5 # normalization to [-1,1] 246 | IMG = torch.FloatTensor(IMG) # convert to tensor [H, W, C] 247 | IMG = IMG.permute(2,0,1) # [C, H, W] 248 | y_true = np.ones(self.seq_len)*self.char2id['PAD'] # initialize y_true with 'PAD', size [seq_len] 249 | # label processing 250 | for i, c in enumerate(label): 251 | index = self.char2id[c] 252 | y_true[i] = index 253 | y_true[-1] = self.char2id['END'] # always put 'END' in the end 254 | y_true = y_true.astype(int) # must to integer index for one-hot encoding 255 | # convert to one-hot encoding 256 | y_onehot = np.eye(self.output_classes)[y_true] # [seq_len, output_classes] 257 | 258 | return IMG, torch.FloatTensor(y_onehot) 259 | 260 | def __len__(self): 261 | return len(self.dataset) 262 | 263 | class syn90k_dataset_builder(data.Dataset): 264 | def __init__(self, height, width, seq_len, total_img_path): 265 | ''' 266 | height: input height to model 267 | width: input width to model 268 | total_img_path: path with all images 269 | seq_len: sequence length 270 | ''' 271 | self.total_img_path = total_img_path 272 | self.height = height 273 | self.width = width 274 | self.seq_len = seq_len 275 | self.total_img_name = os.listdir(total_img_path) 276 | self.dataset = [] 277 | self.voc, self.char2id, _ = dictionary_generator() 278 | self.output_classes = len(self.voc) 279 | 280 | for img_name in self.total_img_name: 281 | _, label, _ = img_name.split('_') 282 | self.dataset.append([img_name, label]) 283 | 284 | def __getitem__(self, index): 285 | img_name, label = self.dataset[index] 286 | IMG = cv2.imread(os.path.join(self.total_img_path,img_name)) 287 | IMG = cv2.resize(IMG, (self.width, self.height)) # resize 288 | IMG = (IMG - 127.5)/127.5 # normalization to [-1,1] 289 | IMG = torch.FloatTensor(IMG) # convert to tensor [H, W, C] 290 | IMG = IMG.permute(2,0,1) # [C, H, W] 291 | y_true = np.ones(self.seq_len)*self.char2id['PAD'] # initialize y_true with 'PAD', size [seq_len] 292 | # label processing 293 | for i, c in enumerate(label): 294 | index = self.char2id[c] 295 | y_true[i] = index 296 | y_true[-1] = self.char2id['END'] # always put 'END' in the end 297 | y_true = y_true.astype(int) # must to integer index for one-hot encoding 298 | # convert to one-hot encoding 299 | y_onehot = np.eye(self.output_classes)[y_true] # [seq_len, output_classes] 300 | 301 | return IMG, torch.FloatTensor(y_onehot) 302 | 303 | def __len__(self): 304 | return len(self.dataset) 305 | 306 | class synthtext_dataset_builder(data.Dataset): 307 | def __init__(self, height, width, seq_len, total_img_path, annotation_path): 308 | ''' 309 | height: input height to model 310 | width: input width to model 311 | total_img_path: path with 
all images 312 | annotation_path: mat labeling file 313 | seq_len: sequence length 314 | ''' 315 | self.total_img_path = total_img_path 316 | self.height = height 317 | self.width = width 318 | self.seq_len = seq_len 319 | self.dictionary = synthtext_mat_extractor(annotation_path) 320 | self.total_img_name = os.listdir(total_img_path) 321 | self.dataset = [] 322 | self.voc, self.char2id, _ = dictionary_generator() 323 | self.output_classes = len(self.voc) 324 | 325 | for items in self.dictionary: 326 | if items[0] in self.total_img_name: 327 | self.dataset.append([items[0],items[1],items[2]]) 328 | 329 | def __getitem__(self, index): 330 | img_name, bdb, label = self.dataset[index] 331 | IMG = cv2.imread(os.path.join(self.total_img_path,img_name)) 332 | xmin, ymin, xmax, ymax = bdb 333 | (H, W, _) = IMG.shape 334 | xmin = max(0, xmin) 335 | xmin = min(W-1, xmin) 336 | ymin = max(0, ymin) 337 | ymin = min(H-1, ymin) 338 | xmax = max(0, xmax) 339 | xmax = min(W-1, xmax) 340 | ymax = max(0, ymax) 341 | ymax = min(H-1, ymax) 342 | # image processing: 343 | IMG = IMG[ymin:ymax+1,xmin:xmax+1,:] # crop 344 | IMG = cv2.resize(IMG, (self.width, self.height)) # resize 345 | IMG = (IMG - 127.5)/127.5 # normalization to [-1,1] 346 | IMG = torch.FloatTensor(IMG) # convert to tensor [H, W, C] 347 | IMG = IMG.permute(2,0,1) # [C, H, W] 348 | y_true = np.ones(self.seq_len)*self.char2id['PAD'] # initialize y_true with 'PAD', size [seq_len] 349 | # label processing 350 | for i, c in enumerate(label): 351 | index = self.char2id[c] 352 | y_true[i] = index 353 | y_true[-1] = self.char2id['END'] # always put 'END' in the end 354 | y_true = y_true.astype(int) # must to integer index for one-hot encoding 355 | # convert to one-hot encoding 356 | y_onehot = np.eye(self.output_classes)[y_true] # [seq_len, output_classes] 357 | 358 | return IMG, torch.FloatTensor(y_onehot) 359 | 360 | def __len__(self): 361 | return len(self.dataset) 362 | 363 | class test_dataset_builder(data.Dataset): 364 | def __init__(self, height, width, img_path): 365 | ''' 366 | height: input height to model 367 | width: input width to model 368 | img_path: path with images 369 | ''' 370 | self.height = height 371 | self.width = width 372 | self.img_path = img_path 373 | self.dataset = [image_name for image_name in os.listdir(self.img_path)] 374 | 375 | def __getitem__(self, index): 376 | IMG = cv2.imread(os.path.join(self.img_path, self.dataset[index])) 377 | # image processing: 378 | IMG = cv2.resize(IMG, (self.width, self.height)) # resize 379 | IMG = (IMG - 127.5)/127.5 # normalization to [-1,1] 380 | IMG = torch.FloatTensor(IMG) # convert to tensor [H, W, C] 381 | IMG = IMG.permute(2,0,1) # [C, H, W] 382 | 383 | return IMG, self.dataset[index] 384 | 385 | def __len__(self): 386 | return len(self.dataset) 387 | 388 | # unit test 389 | if __name__ == '__main__': 390 | 391 | img_path = '../svt/img/' 392 | train_xml_path = '../svt/train.xml' 393 | test_xml_path = '../svt/test.xml' 394 | 395 | img_path_iiit = '../IIIT5K/train/' 396 | annotation_path_iiit = '../IIIT5K/traindata.mat' 397 | 398 | img_path_syn90k = '../Syn90k/train/' 399 | 400 | img_path_synthtext = '../SynthText/train/' 401 | annotation_path_synthtext = '../SynthText/gt.mat' 402 | 403 | height = 48 # input height pixel 404 | width = 64 # input width pixel 405 | seq_len = 40 # sequence length 406 | 407 | voc, char2id, id2char = dictionary_generator() 408 | 409 | train_dict = svt_xml_extractor(train_xml_path) 410 | print("Dictionary for training set is:", train_dict) 411 | 412 | 
train_dataset = svt_dataset_builder(height, width, seq_len, img_path, train_xml_path) 413 | 414 | for i, item in enumerate(train_dataset): 415 | print(item[0].shape,item[1].shape) 416 | 417 | test_dataset = svt_dataset_builder(height, width, seq_len, img_path, test_xml_path) 418 | 419 | train_dict_iiit = iiit5k_mat_extractor(annotation_path_iiit) 420 | print("Dictionary for training set is:", train_dict_iiit) 421 | 422 | train_dataset_iiit5k = iiit5k_dataset_builder(height, width, seq_len, img_path_iiit, annotation_path_iiit) 423 | 424 | train_dataset_syn90k = syn90k_dataset_builder(height, width, seq_len, img_path_syn90k) 425 | 426 | train_dict_synthtext = synthtext_mat_extractor(annotation_path_synthtext) 427 | #print("Dictionary for training set is:", train_dict_synthtext) 428 | 429 | train_dataset_synthtext = synthtext_dataset_builder(height, width, seq_len, img_path_synthtext, annotation_path_synthtext) 430 | 431 | for i, item in enumerate(train_dataset): 432 | print(item[0].shape,item[1].shape) 433 | IMG = item[0].permute(1,2,0) 434 | IMG = IMG.detach().numpy() 435 | IMG = (IMG*127.5+127.5).astype(np.uint8) 436 | cv2.imwrite('../test/svt_'+str(i)+'.jpg', IMG) 437 | 438 | for i, item in enumerate(train_dataset_iiit5k): 439 | print(item[0].shape,item[1].shape) 440 | IMG = item[0].permute(1,2,0) 441 | IMG = IMG.detach().numpy() 442 | IMG = (IMG*127.5+127.5).astype(np.uint8) 443 | cv2.imwrite('../test/iiit_'+str(i)+'.jpg', IMG) 444 | 445 | for i, item in enumerate(train_dataset_syn90k): 446 | print(item[0].shape,item[1].shape) 447 | IMG = item[0].permute(1,2,0) 448 | IMG = IMG.detach().numpy() 449 | IMG = (IMG*127.5+127.5).astype(np.uint8) 450 | cv2.imwrite('../test/syn90k_'+str(i)+'.jpg', IMG) 451 | 452 | for i, item in enumerate(train_dataset_synthtext): 453 | print(item[0].shape,item[1].shape) 454 | IMG = item[0].permute(1,2,0) 455 | IMG = IMG.detach().numpy() 456 | IMG = (IMG*127.5+127.5).astype(np.uint8) 457 | target = item[1].max(1)[1] # [seq_len] 458 | label = end_cut(target.detach().cpu().numpy(), char2id, id2char) 459 | print(label) 460 | cv2.imwrite('../test/synthtext_'+str(i)+'.jpg', IMG) --------------------------------------------------------------------------------
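The vocabulary helpers in dataset.py above (`dictionary_generator` and `end_cut`) are also what you would use to turn decoder output back into text. Below is a minimal sketch of that round trip; it is not taken from train.py or inference.py, the import path assumes this file lives at dataset/dataset.py as in the tree above, and the random logits tensor merely stands in for a real decoder output of shape [seq_len, output_classes].

```python
# Sketch: decode per-step class scores into a string with end_cut.
import torch
from dataset.dataset import dictionary_generator, end_cut

voc, char2id, id2char = dictionary_generator()
seq_len = 40

logits = torch.randn(seq_len, len(voc))        # stand-in for real decoder output
indices = logits.argmax(dim=1).cpu().numpy()   # greedy per-step prediction
text = end_cut(indices, char2id, id2char)      # stops at 'END', skips 'PAD'/'UNK'
print(text)
```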
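Each dataset builder returns an image tensor of shape [3, height, width] and a one-hot target of shape [seq_len, output_classes], so it drops straight into a torch DataLoader. The sketch below shows one plausible way to batch it; it is an assumption about typical usage rather than a copy of train.py, the SVT paths are placeholders, and the argmax simply recovers integer class indices from the one-hot targets for a cross-entropy style loss.

```python
# Sketch: batch svt_dataset_builder and recover integer targets from one-hot labels.
from torch.utils.data import DataLoader
from dataset.dataset import svt_dataset_builder

# placeholder paths -- point these at a local SVT copy
train_set = svt_dataset_builder(height=48, width=64, seq_len=40,
                                total_img_path='../svt/img/',
                                xml_path='../svt/train.xml')
loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=2)

for imgs, y_onehot in loader:
    # imgs: [B, 3, 48, 64]; y_onehot: [B, seq_len, output_classes]
    targets = y_onehot.argmax(dim=2)   # [B, seq_len] integer labels for a CE-style loss
    print(imgs.shape, targets.shape)
    break
```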
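requirements.txt lists editdistance, which is the usual way to score a recognizer once predictions have been decoded with `end_cut`. The snippet below is a hypothetical example with made-up (prediction, ground_truth) pairs, showing word accuracy and normalized edit distance; it is not part of this repository's evaluation code.

```python
# Sketch: score decoded predictions against ground truth strings.
import editdistance

# hypothetical (prediction, ground_truth) pairs, e.g. produced via end_cut
pairs = [('street', 'street'), ('housse', 'house'), ('caffe', 'cafe')]

word_acc = sum(p == g for p, g in pairs) / len(pairs)
norm_ed = sum(editdistance.eval(p, g) / max(len(g), 1) for p, g in pairs) / len(pairs)

print('word accuracy:', word_acc)
print('mean normalized edit distance:', norm_ed)
```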