├── demo_image
│   ├── __init__.py
│   ├── demo_1.png
│   ├── demo_10.jpg
│   ├── demo_2.jpg
│   ├── demo_3.png
│   ├── demo_4.png
│   ├── demo_5.png
│   ├── demo_6.png
│   ├── demo_7.png
│   ├── demo_8.jpg
│   └── demo_9.jpg
├── modules
│   ├── __init__.py
│   ├── sequence_modeling.py
│   ├── prediction.py
│   ├── transformation.py
│   ├── feature_extraction.py
│   └── transformer_util.py
├── README.md
├── .gitignore
├── create_lmdb_dataset.py
├── model.py
├── demo.py
├── utils.py
├── LICENSE
├── dataset.py
├── test.py
└── train.py
/demo_image/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /demo_image/demo_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saberSabersaber/transformer_OCR/HEAD/demo_image/demo_1.png -------------------------------------------------------------------------------- /demo_image/demo_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saberSabersaber/transformer_OCR/HEAD/demo_image/demo_10.jpg -------------------------------------------------------------------------------- /demo_image/demo_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saberSabersaber/transformer_OCR/HEAD/demo_image/demo_2.jpg -------------------------------------------------------------------------------- /demo_image/demo_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saberSabersaber/transformer_OCR/HEAD/demo_image/demo_3.png -------------------------------------------------------------------------------- /demo_image/demo_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saberSabersaber/transformer_OCR/HEAD/demo_image/demo_4.png -------------------------------------------------------------------------------- /demo_image/demo_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saberSabersaber/transformer_OCR/HEAD/demo_image/demo_5.png -------------------------------------------------------------------------------- /demo_image/demo_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saberSabersaber/transformer_OCR/HEAD/demo_image/demo_6.png -------------------------------------------------------------------------------- /demo_image/demo_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saberSabersaber/transformer_OCR/HEAD/demo_image/demo_7.png -------------------------------------------------------------------------------- /demo_image/demo_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/saberSabersaber/transformer_OCR/HEAD/demo_image/demo_8.jpg -------------------------------------------------------------------------------- /demo_image/demo_9.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/saberSabersaber/transformer_OCR/HEAD/demo_image/demo_9.jpg -------------------------------------------------------------------------------- /modules/sequence_modeling.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BidirectionalLSTM(nn.Module): 5 | 6 | def __init__(self, input_size, hidden_size, output_size): 7 | super(BidirectionalLSTM, self).__init__() 8 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 9 | self.linear = nn.Linear(hidden_size * 2, output_size) 10 | 11 | def forward(self, input): 12 | """ 13 | input : visual feature [batch_size x T x input_size] 14 | output : contextual feature [batch_size x T x output_size] 15 | """ 16 | self.rnn.flatten_parameters() 17 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 18 | output = self.linear(recurrent) # batch_size x T x output_size 19 | return output 20 | -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # transformer_OCR
2 | OCR text recognition with a Transformer. The overall framework follows deep-text-recognition-benchmark (https://github.com/clovaai/deep-text-recognition-benchmark); the training and test data can be downloaded from that project. The Transformer part follows pytorch-seq2seq (https://github.com/bentrevett/pytorch-seq2seq).
3 | The Transformer uses the demo hyper-parameter configuration from pytorch-seq2seq, and the backbone is initialized from a deep-text-recognition-benchmark model; the average accuracy reaches 0.85.
4 | 
5 | # Dependencies
6 | PyTorch 1.3.1, CUDA 10.1, Python 3.6 and Ubuntu 16.04; lmdb, pillow, torchvision, nltk, natsort
7 | 
8 | # Pretrained model
9 | Link: https://pan.baidu.com/s/1RzWpU_0-OQcezTKuMQqmUA
10 | Extraction code: olze
11 | 
12 | # Training and testing
13 | CUDA_VISIBLE_DEVICES=0 python3 train.py \
14 | --train_data data_lmdb_release/training --valid_data data_lmdb_release/validation \
15 | --select_data MJ-ST --batch_ratio 0.5-0.5 \
16 | --Transformation TPS --FeatureExtraction ResNet --SequenceModeling None --Prediction Transformer
17 | 
18 | CUDA_VISIBLE_DEVICES=0 python3 test.py \
19 | --eval_data data_lmdb_release/evaluation --benchmark_all_eval \
20 | --Transformation TPS --FeatureExtraction ResNet --SequenceModeling None --Prediction Transformer \
21 | --saved_model TPS-ResNet-None-Transformer.pth
22 | 
23 | # Results
24 | accuracy: IIIT5k_3000: 87.567 SVT: 87.172 IC03_860: 95.465 IC03_867: 94.810 IC13_857: 93.816 IC13_1015: 92.414 IC15_1811: 77.361 IC15_2077: 74.506 SVTP: 78.915 CUTE80: 73.519 total_accuracy: 85.039 averaged_infer_time: 28.099 # parameters: 58.723
25 | 
26 | # todo
27 | Since the Transformer only uses the demo hyper-parameter configuration from pytorch-seq2seq, there is still room to tune that configuration; the learning-rate schedule, optimizer, etc. also need further experiments.
28 | 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /create_lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | """ a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """ 2 | 3 | import fire 4 | import os 5 | import lmdb 6 | import cv2 7 | 8 | import numpy as np 9 | 10 | 11 | def checkImageIsValid(imageBin): 12 | if imageBin is None: 13 | return False 14 | imageBuf = np.frombuffer(imageBin, dtype=np.uint8) 15 | img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) 16 | imgH, imgW = img.shape[0], img.shape[1] 17 | if imgH * imgW == 0: 18 | return False 19 | return True 20 | 21 | 22 | def writeCache(env, cache): 23 | with env.begin(write=True) as txn: 24 | for k, v in cache.items(): 25 | txn.put(k, v) 26 | 27 | 28 | def createDataset(inputPath, gtFile, outputPath, checkValid=True): 29 | """ 30 | Create LMDB dataset for training and evaluation. 
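Each line of gtFile holds one sample in the form '<imagePath>\t<label>'; imagePath is resolved relative to inputPath (see the tab split and os.path.join below).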
31 | ARGS: 32 | inputPath : input folder path where starts imagePath 33 | outputPath : LMDB output path 34 | gtFile : list of image path and label 35 | checkValid : if true, check the validity of every image 36 | """ 37 | os.makedirs(outputPath, exist_ok=True) 38 | env = lmdb.open(outputPath, map_size=1099511627776) 39 | cache = {} 40 | cnt = 1 41 | 42 | with open(gtFile, 'r', encoding='utf-8') as data: 43 | datalist = data.readlines() 44 | 45 | nSamples = len(datalist) 46 | for i in range(nSamples): 47 | imagePath, label = datalist[i].strip('\n').split('\t') 48 | imagePath = os.path.join(inputPath, imagePath) 49 | 50 | # # only use alphanumeric data 51 | # if re.search('[^a-zA-Z0-9]', label): 52 | # continue 53 | 54 | if not os.path.exists(imagePath): 55 | print('%s does not exist' % imagePath) 56 | continue 57 | with open(imagePath, 'rb') as f: 58 | imageBin = f.read() 59 | if checkValid: 60 | try: 61 | if not checkImageIsValid(imageBin): 62 | print('%s is not a valid image' % imagePath) 63 | continue 64 | except: 65 | print('error occured', i) 66 | with open(outputPath + '/error_image_log.txt', 'a') as log: 67 | log.write('%s-th image data occured error\n' % str(i)) 68 | continue 69 | 70 | imageKey = 'image-%09d'.encode() % cnt 71 | labelKey = 'label-%09d'.encode() % cnt 72 | cache[imageKey] = imageBin 73 | cache[labelKey] = label.encode() 74 | 75 | if cnt % 1000 == 0: 76 | writeCache(env, cache) 77 | cache = {} 78 | print('Written %d / %d' % (cnt, nSamples)) 79 | cnt += 1 80 | nSamples = cnt-1 81 | cache['num-samples'.encode()] = str(nSamples).encode() 82 | writeCache(env, cache) 83 | print('Created dataset with %d samples' % nSamples) 84 | 85 | 86 | if __name__ == '__main__': 87 | fire.Fire(createDataset) 88 | -------------------------------------------------------------------------------- /modules/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 5 | 6 | 7 | class Attention(nn.Module): 8 | 9 | def __init__(self, input_size, hidden_size, num_classes): 10 | super(Attention, self).__init__() 11 | self.attention_cell = AttentionCell(input_size, hidden_size, num_classes) 12 | self.hidden_size = hidden_size 13 | self.num_classes = num_classes 14 | self.generator = nn.Linear(hidden_size, num_classes) 15 | 16 | def _char_to_onehot(self, input_char, onehot_dim=38): 17 | input_char = input_char.unsqueeze(1) 18 | batch_size = input_char.size(0) 19 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) 20 | one_hot = one_hot.scatter_(1, input_char, 1) 21 | return one_hot 22 | 23 | def forward(self, batch_H, text, is_train=True, batch_max_length=25): 24 | """ 25 | input: 26 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels] 27 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. 28 | output: probability distribution at each step [batch_size x num_steps x num_classes] 29 | """ 30 | batch_size = batch_H.size(0) 31 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. 
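# Two decoding modes follow: with is_train=True the ground-truth character for each
# step is fed to the attention cell (teacher forcing); at inference the arg-max of the
# previous step's output is fed back in instead, starting from the [GO] token (greedy decoding).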
32 | 33 | output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device) 34 | hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 35 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device)) 36 | 37 | if is_train: 38 | for i in range(num_steps): 39 | # one-hot vectors for a i-th char. in a batch 40 | char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes) 41 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) 42 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 43 | output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) 44 | probs = self.generator(output_hiddens) 45 | 46 | else: 47 | targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token 48 | probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device) 49 | 50 | for i in range(num_steps): 51 | char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) 52 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 53 | probs_step = self.generator(hidden[0]) 54 | probs[:, i, :] = probs_step 55 | _, next_input = probs_step.max(1) 56 | targets = next_input 57 | 58 | return probs # batch_size x num_steps x num_classes 59 | 60 | 61 | class AttentionCell(nn.Module): 62 | 63 | def __init__(self, input_size, hidden_size, num_embeddings): 64 | super(AttentionCell, self).__init__() 65 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 66 | self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias 67 | self.score = nn.Linear(hidden_size, 1, bias=False) 68 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 69 | self.hidden_size = hidden_size 70 | 71 | def forward(self, prev_hidden, batch_H, char_onehots): 72 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 73 | batch_H_proj = self.i2h(batch_H) 74 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 75 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 76 | 77 | alpha = F.softmax(e, dim=1) 78 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel 79 | concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) 80 | cur_hidden = self.rnn(concat_context, prev_hidden) 81 | return cur_hidden, alpha 82 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from modules.transformation import TPS_SpatialTransformerNetwork 21 | from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor 22 | from modules.sequence_modeling import BidirectionalLSTM 23 | from modules.prediction import Attention 24 | from modules.transformer_util import Seq2Seq, Encoder, Decoder 25 | 26 | 27 | class Model(nn.Module): 28 | 29 | def __init__(self, opt): 30 | super(Model, self).__init__() 31 | self.opt = opt 32 | self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, 33 | 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction} 34 | 35 | """ Transformation """ 36 | if opt.Transformation == 'TPS': 37 | self.Transformation = TPS_SpatialTransformerNetwork( 38 | F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel) 39 | else: 40 | print('No Transformation module specified') 41 | 42 | """ FeatureExtraction """ 43 | if opt.FeatureExtraction == 'VGG': 44 | self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel) 45 | elif opt.FeatureExtraction == 'RCNN': 46 | self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel) 47 | elif opt.FeatureExtraction == 'ResNet': 48 | self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel) 49 | else: 50 | raise Exception('No FeatureExtraction module specified') 51 | self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512 52 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 53 | 54 | """ Sequence modeling""" 55 | if opt.SequenceModeling == 'BiLSTM': 56 | self.SequenceModeling = nn.Sequential( 57 | BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size), 58 | BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size)) 59 | self.SequenceModeling_output = opt.hidden_size 60 | else: 61 | print('No SequenceModeling module specified') 62 | self.SequenceModeling_output = self.FeatureExtraction_output 63 | 64 | """ Prediction """ 65 | if opt.Prediction == 'CTC': 66 | self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) 67 | elif opt.Prediction == 'Attn': 68 | self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class) 69 | elif opt.Prediction == 'Transformer': 70 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 71 | OUTPUT_DIM = opt.num_class 72 | HID_DIM = self.SequenceModeling_output 73 | ENC_LAYERS = 3 74 | DEC_LAYERS = 3 75 | ENC_HEADS = 8 76 | DEC_HEADS = 8 77 | ENC_PF_DIM = 512 78 | DEC_PF_DIM = 512 79 | ENC_DROPOUT = 0.1 80 | DEC_DROPOUT = 0.1 81 | 82 | enc = Encoder(HID_DIM, 83 | ENC_LAYERS, 84 | ENC_HEADS, 85 | ENC_PF_DIM, 86 | ENC_DROPOUT, 87 | device) 88 | 89 | dec = Decoder(OUTPUT_DIM, 90 | HID_DIM, 91 | DEC_LAYERS, 92 | DEC_HEADS, 93 | DEC_PF_DIM, 94 | DEC_DROPOUT, 95 | device) 96 | 97 | 98 | TRG_PAD_IDX = 2#TRG.vocab.stoi[TRG.pad_token] 99 | 100 | self.Prediction = Seq2Seq(enc, dec, TRG_PAD_IDX, device).to(device) 101 | self.Prediction.apply(self.initialize_weights); 102 | print("use transformer") 103 | else: 104 | raise Exception('Prediction is neither CTC or Attn') 105 | 106 | def initialize_weights(self, m): 107 | if hasattr(m, 'weight') and m.weight.dim() > 1: 108 | nn.init.xavier_uniform_(m.weight.data) 109 | def forward(self, input, text, is_train=True): 110 | """ Transformation stage """ 111 
| if not self.stages['Trans'] == "None": 112 | input = self.Transformation(input) 113 | 114 | """ Feature extraction stage """ 115 | visual_feature = self.FeatureExtraction(input) 116 | visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h] 117 | visual_feature = visual_feature.squeeze(3) 118 | 119 | """ Sequence modeling stage """ 120 | if self.stages['Seq'] == 'BiLSTM': 121 | contextual_feature = self.SequenceModeling(visual_feature) 122 | else: 123 | contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM 124 | 125 | """ Prediction stage """ 126 | if self.stages['Pred'] == 'CTC': 127 | prediction = self.Prediction(contextual_feature.contiguous()) 128 | elif self.stages['Pred'] == 'Transformer': 129 | prediction = self.Prediction(contextual_feature.contiguous(), text, is_train) 130 | else: 131 | prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length) 132 | 133 | return prediction 134 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import string 3 | import argparse 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import torch.utils.data 8 | import torch.nn.functional as F 9 | 10 | from utils import CTCLabelConverter, AttnLabelConverter, TransformerLabelConverter 11 | from dataset import RawDataset, AlignCollate 12 | from model import Model 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | 15 | 16 | def demo(opt): 17 | """ model configuration """ 18 | if 'CTC' in opt.Prediction: 19 | converter = CTCLabelConverter(opt.character) 20 | elif "Transformer" in opt.Prediction: 21 | converter = TransformerLabelConverter(opt.character) 22 | else: 23 | converter = AttnLabelConverter(opt.character) 24 | opt.num_class = len(converter.character) 25 | 26 | if opt.rgb: 27 | opt.input_channel = 3 28 | model = Model(opt) 29 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 30 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 31 | opt.SequenceModeling, opt.Prediction) 32 | model = torch.nn.DataParallel(model).to(device) 33 | 34 | # load model 35 | print('loading pretrained model from %s' % opt.saved_model) 36 | model.load_state_dict(torch.load(opt.saved_model, map_location=device)) 37 | 38 | # prepare data. 
two demo images from https://github.com/bgshih/crnn#run-demo 39 | AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 40 | demo_data = RawDataset(root=opt.image_folder, opt=opt) # use RawDataset 41 | demo_loader = torch.utils.data.DataLoader( 42 | demo_data, batch_size=opt.batch_size, 43 | shuffle=False, 44 | num_workers=int(opt.workers), 45 | collate_fn=AlignCollate_demo, pin_memory=True) 46 | 47 | # predict 48 | model.eval() 49 | with torch.no_grad(): 50 | for image_tensors, image_path_list in demo_loader: 51 | batch_size = image_tensors.size(0) 52 | image = image_tensors.to(device) 53 | # For max length prediction 54 | length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) 55 | text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) 56 | 57 | if 'CTC' in opt.Prediction: 58 | preds = model(image, text_for_pred) 59 | 60 | # Select max probabilty (greedy decoding) then decode index to character 61 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 62 | _, preds_index = preds.max(2) 63 | # preds_index = preds_index.view(-1) 64 | preds_str = converter.decode(preds_index, preds_size) 65 | 66 | else: 67 | preds = model(image, text_for_pred, is_train=False) 68 | 69 | # select max probabilty (greedy decoding) then decode index to character 70 | _, preds_index = preds.max(2) 71 | preds_str = converter.decode(preds_index, length_for_pred) 72 | 73 | 74 | log = open(f'./log_demo_result.txt', 'a') 75 | dashed_line = '-' * 80 76 | head = f'{"image_path":25s}\t{"predicted_labels":25s}\tconfidence score' 77 | 78 | print(f'{dashed_line}\n{head}\n{dashed_line}') 79 | log.write(f'{dashed_line}\n{head}\n{dashed_line}\n') 80 | 81 | preds_prob = F.softmax(preds, dim=2) 82 | preds_max_prob, _ = preds_prob.max(dim=2) 83 | for img_name, pred, pred_max_prob in zip(image_path_list, preds_str, preds_max_prob): 84 | if 'Attn' in opt.Prediction or "Transformer" in opt.Prediction: 85 | pred_EOS = pred.find('[s]') 86 | pred = pred[:pred_EOS] # prune after "end of sentence" token ([s]) 87 | pred_max_prob = pred_max_prob[:pred_EOS] 88 | 89 | # calculate confidence score (= multiply of pred_max_prob) 90 | confidence_score = pred_max_prob.cumprod(dim=0)[-1] 91 | 92 | print(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}') 93 | log.write(f'{img_name:25s}\t{pred:25s}\t{confidence_score:0.4f}\n') 94 | 95 | log.close() 96 | 97 | if __name__ == '__main__': 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument('--image_folder', required=True, help='path to image_folder which contains text images') 100 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 101 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 102 | parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluation") 103 | """ Data processing """ 104 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 105 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 106 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 107 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 108 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') 109 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 110 | 
parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') 111 | """ Model Architecture """ 112 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS') 113 | parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet') 114 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') 115 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn') 116 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 117 | parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') 118 | parser.add_argument('--output_channel', type=int, default=512, 119 | help='the number of output channel of Feature extractor') 120 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 121 | 122 | opt = parser.parse_args() 123 | 124 | """ vocab / character number configuration """ 125 | if opt.sensitive: 126 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 127 | 128 | cudnn.benchmark = True 129 | cudnn.deterministic = True 130 | opt.num_gpu = torch.cuda.device_count() 131 | 132 | demo(opt) 133 | -------------------------------------------------------------------------------- /modules/transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | 7 | 8 | class TPS_SpatialTransformerNetwork(nn.Module): 9 | """ Rectification Network of RARE, namely TPS based STN """ 10 | 11 | def __init__(self, F, I_size, I_r_size, I_channel_num=1): 12 | """ Based on RARE TPS 13 | input: 14 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 15 | I_size : (height, width) of the input image I 16 | I_r_size : (height, width) of the rectified image I_r 17 | I_channel_num : the number of channels of the input image I 18 | output: 19 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 20 | """ 21 | super(TPS_SpatialTransformerNetwork, self).__init__() 22 | self.F = F 23 | self.I_size = I_size 24 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 25 | self.I_channel_num = I_channel_num 26 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 27 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 28 | 29 | def forward(self, batch_I): 30 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 31 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 32 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) 33 | 34 | if torch.__version__ > "1.2.0": 35 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) 36 | else: 37 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') 38 | 39 | return batch_I_r 40 | 41 | 42 | class LocalizationNetwork(nn.Module): 43 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 44 | 45 | def __init__(self, F, I_channel_num): 46 | 
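# The conv stack below pools the input image down to a single 512-d feature vector,
# and two fully-connected layers regress the F fiducial points (2*F coordinates);
# localization_fc2 is zero-initialized with a bias matching RARE Fig. 6(a), so the
# predicted control points start from that layout.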
super(LocalizationNetwork, self).__init__() 47 | self.F = F 48 | self.I_channel_num = I_channel_num 49 | self.conv = nn.Sequential( 50 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, 51 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True), 52 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 53 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), 54 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 55 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), 56 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 57 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), 58 | nn.AdaptiveAvgPool2d(1) # batch_size x 512 59 | ) 60 | 61 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 62 | self.localization_fc2 = nn.Linear(256, self.F * 2) 63 | 64 | # Init fc2 in LocalizationNetwork 65 | self.localization_fc2.weight.data.fill_(0) 66 | """ see RARE paper Fig. 6 (a) """ 67 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 68 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 69 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 70 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 71 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 72 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 73 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) 74 | 75 | def forward(self, batch_I): 76 | """ 77 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 78 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 79 | """ 80 | batch_size = batch_I.size(0) 81 | features = self.conv(batch_I).view(batch_size, -1) 82 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) 83 | return batch_C_prime 84 | 85 | 86 | class GridGenerator(nn.Module): 87 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 88 | 89 | def __init__(self, F, I_r_size): 90 | """ Generate P_hat and inv_delta_C for later """ 91 | super(GridGenerator, self).__init__() 92 | self.eps = 1e-6 93 | self.I_r_height, self.I_r_width = I_r_size 94 | self.F = F 95 | self.C = self._build_C(self.F) # F x 2 96 | self.P = self._build_P(self.I_r_width, self.I_r_height) 97 | ## for multi-gpu, you need register buffer 98 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 99 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 100 | ## for fine-tuning with different image width, you may use below instead of self.register_buffer 101 | #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3 102 | #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3 103 | 104 | def _build_C(self, F): 105 | """ Return coordinates of fiducial points in I_r; C """ 106 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 107 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 108 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 109 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 110 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 111 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 112 | return C # F x 2 113 | 114 | def 
_build_inv_delta_C(self, F, C): 115 | """ Return inv_delta_C which is needed to calculate T """ 116 | hat_C = np.zeros((F, F), dtype=float) # F x F 117 | for i in range(0, F): 118 | for j in range(i, F): 119 | r = np.linalg.norm(C[i] - C[j]) 120 | hat_C[i, j] = r 121 | hat_C[j, i] = r 122 | np.fill_diagonal(hat_C, 1) 123 | hat_C = (hat_C ** 2) * np.log(hat_C) 124 | # print(C.shape, hat_C.shape) 125 | delta_C = np.concatenate( # F+3 x F+3 126 | [ 127 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 128 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 129 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 130 | ], 131 | axis=0 132 | ) 133 | inv_delta_C = np.linalg.inv(delta_C) 134 | return inv_delta_C # F+3 x F+3 135 | 136 | def _build_P(self, I_r_width, I_r_height): 137 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width 138 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height 139 | P = np.stack( # self.I_r_width x self.I_r_height x 2 140 | np.meshgrid(I_r_grid_x, I_r_grid_y), 141 | axis=2 142 | ) 143 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 144 | 145 | def _build_P_hat(self, F, C, P): 146 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 147 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 148 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 149 | P_diff = P_tile - C_tile # n x F x 2 150 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 151 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 152 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 153 | return P_hat # n x F+3 154 | 155 | def build_P_prime(self, batch_C_prime): 156 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 157 | batch_size = batch_C_prime.size(0) 158 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 159 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 160 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 161 | batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2 162 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 163 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 164 | return batch_P_prime # batch_size x n x 2 165 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 3 | 4 | 5 | class CTCLabelConverter(object): 6 | """ Convert between text-label and text-index """ 7 | 8 | def __init__(self, character): 9 | # character (str): set of the possible characters. 10 | dict_character = list(character) 11 | 12 | self.dict = {} 13 | for i, char in enumerate(dict_character): 14 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 15 | self.dict[char] = i + 1 16 | 17 | self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) 18 | 19 | def encode(self, text, batch_max_length=25): 20 | """convert text-label into text-index. 21 | input: 22 | text: text labels of each image. [batch_size] 23 | batch_max_length: max length of text label in the batch. 25 by default 24 | 25 | output: 26 | text: text index for CTCLoss. 
[batch_size, batch_max_length] 27 | length: length of each text. [batch_size] 28 | """ 29 | length = [len(s) for s in text] 30 | 31 | # The index used for padding (=0) would not affect the CTC loss calculation. 32 | batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0) 33 | for i, t in enumerate(text): 34 | text = list(t) 35 | text = [self.dict[char] for char in text] 36 | batch_text[i][:len(text)] = torch.LongTensor(text) 37 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 38 | 39 | def decode(self, text_index, length): 40 | """ convert text-index into text-label. """ 41 | texts = [] 42 | for index, l in enumerate(length): 43 | t = text_index[index, :] 44 | 45 | char_list = [] 46 | for i in range(l): 47 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. 48 | char_list.append(self.character[t[i]]) 49 | text = ''.join(char_list) 50 | 51 | texts.append(text) 52 | return texts 53 | 54 | 55 | class CTCLabelConverterForBaiduWarpctc(object): 56 | """ Convert between text-label and text-index for baidu warpctc """ 57 | 58 | def __init__(self, character): 59 | # character (str): set of the possible characters. 60 | dict_character = list(character) 61 | 62 | self.dict = {} 63 | for i, char in enumerate(dict_character): 64 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 65 | self.dict[char] = i + 1 66 | 67 | self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) 68 | 69 | def encode(self, text, batch_max_length=25): 70 | """convert text-label into text-index. 71 | input: 72 | text: text labels of each image. [batch_size] 73 | output: 74 | text: concatenated text index for CTCLoss. 75 | [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] 76 | length: length of each text. [batch_size] 77 | """ 78 | length = [len(s) for s in text] 79 | text = ''.join(text) 80 | text = [self.dict[char] for char in text] 81 | 82 | return (torch.IntTensor(text), torch.IntTensor(length)) 83 | 84 | def decode(self, text_index, length): 85 | """ convert text-index into text-label. """ 86 | texts = [] 87 | index = 0 88 | for l in length: 89 | t = text_index[index:index + l] 90 | 91 | char_list = [] 92 | for i in range(l): 93 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. 94 | char_list.append(self.character[t[i]]) 95 | text = ''.join(char_list) 96 | 97 | texts.append(text) 98 | index += l 99 | return texts 100 | 101 | 102 | class AttnLabelConverter(object): 103 | """ Convert between text-label and text-index """ 104 | 105 | def __init__(self, character): 106 | # character (str): set of the possible characters. 107 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. 108 | list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] 109 | list_character = list(character) 110 | self.character = list_token + list_character 111 | 112 | self.dict = {} 113 | for i, char in enumerate(self.character): 114 | # print(i, char) 115 | self.dict[char] = i 116 | 117 | def encode(self, text, batch_max_length=25): 118 | """ convert text-label into text-index. 119 | input: 120 | text: text labels of each image. [batch_size] 121 | batch_max_length: max length of text label in the batch. 25 by default 122 | 123 | output: 124 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. 
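e.g. with [GO]=0 and [s]=1, the label 'ab' is encoded as [0, idx('a'), idx('b'), 1, 0, ..., 0] (padded with [GO] to length batch_max_length+2), and its reported length is 3 ('ab' plus the [s] token).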
125 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. 126 | length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] 127 | """ 128 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. 129 | # batch_max_length = max(length) # this is not allowed for multi-gpu setting 130 | batch_max_length += 1 131 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. 132 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0) 133 | for i, t in enumerate(text): 134 | text = list(t) 135 | text.append('[s]') 136 | text = [self.dict[char] for char in text] 137 | batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token 138 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 139 | 140 | def decode(self, text_index, length): 141 | """ convert text-index into text-label. """ 142 | texts = [] 143 | for index, l in enumerate(length): 144 | text = ''.join([self.character[i] for i in text_index[index, :]]) 145 | texts.append(text) 146 | return texts 147 | 148 | class TransformerLabelConverter(object): 149 | """ Convert between text-label and text-index """ 150 | 151 | def __init__(self, character): 152 | # character (str): set of the possible characters. 153 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. 154 | list_token = ['[GO]', '[s]', '[PAD]'] # ['[s]','[UNK]','[PAD]','[GO]'] 155 | #list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] 156 | list_character = list(character) 157 | self.character = list_token + list_character 158 | 159 | self.dict = {} 160 | for i, char in enumerate(self.character): 161 | # print(i, char) 162 | self.dict[char] = i 163 | 164 | def encode(self, text, batch_max_length=25): 165 | """ convert text-label into text-index. 166 | input: 167 | text: text labels of each image. [batch_size] 168 | batch_max_length: max length of text label in the batch. 25 by default 169 | 170 | output: 171 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. 172 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. 173 | length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] 174 | """ 175 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. 176 | # batch_max_length = max(length) # this is not allowed for multi-gpu setting 177 | batch_max_length += 1 178 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. 179 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(self.dict['[PAD]']) 180 | #batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(self.dict['[GO]']) 181 | for i, t in enumerate(text): 182 | text = list(t) 183 | text.append('[s]') 184 | text = [self.dict[char] for char in text] 185 | batch_text[i][0] = 0 186 | batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token 187 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 188 | 189 | def decode(self, text_index, length): 190 | """ convert text-index into text-label. 
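Note that the decoded string can still contain '[s]' and '[PAD]' tokens; callers such as demo.py cut the prediction at the first '[s]'.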
""" 191 | texts = [] 192 | for index, l in enumerate(length): 193 | text = ''.join([self.character[i] for i in text_index[index, :]]) 194 | texts.append(text) 195 | return texts 196 | 197 | class Averager(object): 198 | """Compute average for torch.Tensor, used for loss average.""" 199 | 200 | def __init__(self): 201 | self.reset() 202 | 203 | def add(self, v): 204 | count = v.data.numel() 205 | v = v.data.sum() 206 | self.n_count += count 207 | self.sum += v 208 | 209 | def reset(self): 210 | self.n_count = 0 211 | self.sum = 0 212 | 213 | def val(self): 214 | res = 0 215 | if self.n_count != 0: 216 | res = self.sum / float(self.n_count) 217 | return res 218 | -------------------------------------------------------------------------------- /modules/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class VGG_FeatureExtractor(nn.Module): 6 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ 7 | 8 | def __init__(self, input_channel, output_channel=512): 9 | super(VGG_FeatureExtractor, self).__init__() 10 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 11 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 12 | self.ConvNet = nn.Sequential( 13 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 14 | nn.MaxPool2d(2, 2), # 64x16x50 15 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), 16 | nn.MaxPool2d(2, 2), # 128x8x25 17 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25 18 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), 19 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 20 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), 21 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 22 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), 23 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), 24 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 25 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 26 | 27 | def forward(self, input): 28 | return self.ConvNet(input) 29 | 30 | 31 | class RCNN_FeatureExtractor(nn.Module): 32 | """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ 33 | 34 | def __init__(self, input_channel, output_channel=512): 35 | super(RCNN_FeatureExtractor, self).__init__() 36 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 37 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 38 | self.ConvNet = nn.Sequential( 39 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 40 | nn.MaxPool2d(2, 2), # 64 x 16 x 50 41 | GRCL(self.output_channel[0], self.output_channel[0], num_iteration=5, kernel_size=3, pad=1), 42 | nn.MaxPool2d(2, 2), # 64 x 8 x 25 43 | GRCL(self.output_channel[0], self.output_channel[1], num_iteration=5, kernel_size=3, pad=1), 44 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26 45 | GRCL(self.output_channel[1], self.output_channel[2], num_iteration=5, kernel_size=3, pad=1), 46 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27 47 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False), 48 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26 49 | 50 | def 
forward(self, input): 51 | return self.ConvNet(input) 52 | 53 | 54 | class ResNet_FeatureExtractor(nn.Module): 55 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 56 | 57 | def __init__(self, input_channel, output_channel=512): 58 | super(ResNet_FeatureExtractor, self).__init__() 59 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) 60 | 61 | def forward(self, input): 62 | return self.ConvNet(input) 63 | 64 | 65 | # For Gated RCNN 66 | class GRCL(nn.Module): 67 | 68 | def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad): 69 | super(GRCL, self).__init__() 70 | self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False) 71 | self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False) 72 | self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False) 73 | self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False) 74 | 75 | self.BN_x_init = nn.BatchNorm2d(output_channel) 76 | 77 | self.num_iteration = num_iteration 78 | self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)] 79 | self.GRCL = nn.Sequential(*self.GRCL) 80 | 81 | def forward(self, input): 82 | """ The input of GRCL is consistant over time t, which is denoted by u(0) 83 | thus wgf_u / wf_u is also consistant over time t. 84 | """ 85 | wgf_u = self.wgf_u(input) 86 | wf_u = self.wf_u(input) 87 | x = F.relu(self.BN_x_init(wf_u)) 88 | 89 | for i in range(self.num_iteration): 90 | x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x)) 91 | 92 | return x 93 | 94 | 95 | class GRCL_unit(nn.Module): 96 | 97 | def __init__(self, output_channel): 98 | super(GRCL_unit, self).__init__() 99 | self.BN_gfu = nn.BatchNorm2d(output_channel) 100 | self.BN_grx = nn.BatchNorm2d(output_channel) 101 | self.BN_fu = nn.BatchNorm2d(output_channel) 102 | self.BN_rx = nn.BatchNorm2d(output_channel) 103 | self.BN_Gx = nn.BatchNorm2d(output_channel) 104 | 105 | def forward(self, wgf_u, wgr_x, wf_u, wr_x): 106 | G_first_term = self.BN_gfu(wgf_u) 107 | G_second_term = self.BN_grx(wgr_x) 108 | G = F.sigmoid(G_first_term + G_second_term) 109 | 110 | x_first_term = self.BN_fu(wf_u) 111 | x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G) 112 | x = F.relu(x_first_term + x_second_term) 113 | 114 | return x 115 | 116 | 117 | class BasicBlock(nn.Module): 118 | expansion = 1 119 | 120 | def __init__(self, inplanes, planes, stride=1, downsample=None): 121 | super(BasicBlock, self).__init__() 122 | self.conv1 = self._conv3x3(inplanes, planes) 123 | self.bn1 = nn.BatchNorm2d(planes) 124 | self.conv2 = self._conv3x3(planes, planes) 125 | self.bn2 = nn.BatchNorm2d(planes) 126 | self.relu = nn.ReLU(inplace=True) 127 | self.downsample = downsample 128 | self.stride = stride 129 | 130 | def _conv3x3(self, in_planes, out_planes, stride=1): 131 | "3x3 convolution with padding" 132 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 133 | padding=1, bias=False) 134 | 135 | def forward(self, x): 136 | residual = x 137 | 138 | out = self.conv1(x) 139 | out = self.bn1(out) 140 | out = self.relu(out) 141 | 142 | out = self.conv2(out) 143 | out = self.bn2(out) 144 | 145 | if self.downsample is not None: 146 | residual = self.downsample(x) 147 | out += residual 148 | out = self.relu(out) 149 | 150 | return out 151 | 152 | 153 | class ResNet(nn.Module): 154 | 155 | def __init__(self, input_channel, output_channel, 
block, layers): 156 | super(ResNet, self).__init__() 157 | 158 | self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] 159 | 160 | self.inplanes = int(output_channel / 8) 161 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), 162 | kernel_size=3, stride=1, padding=1, bias=False) 163 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 164 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, 165 | kernel_size=3, stride=1, padding=1, bias=False) 166 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 167 | self.relu = nn.ReLU(inplace=True) 168 | 169 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 170 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 171 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ 172 | 0], kernel_size=3, stride=1, padding=1, bias=False) 173 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 174 | 175 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 176 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) 177 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ 178 | 1], kernel_size=3, stride=1, padding=1, bias=False) 179 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 180 | 181 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 182 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) 183 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ 184 | 2], kernel_size=3, stride=1, padding=1, bias=False) 185 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 186 | 187 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) 188 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 189 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) 190 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 191 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 192 | 3], kernel_size=2, stride=1, padding=0, bias=False) 193 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 194 | 195 | def _make_layer(self, block, planes, blocks, stride=1): 196 | downsample = None 197 | if stride != 1 or self.inplanes != planes * block.expansion: 198 | downsample = nn.Sequential( 199 | nn.Conv2d(self.inplanes, planes * block.expansion, 200 | kernel_size=1, stride=stride, bias=False), 201 | nn.BatchNorm2d(planes * block.expansion), 202 | ) 203 | 204 | layers = [] 205 | layers.append(block(self.inplanes, planes, stride, downsample)) 206 | self.inplanes = planes * block.expansion 207 | for i in range(1, blocks): 208 | layers.append(block(self.inplanes, planes)) 209 | 210 | return nn.Sequential(*layers) 211 | 212 | def forward(self, x): 213 | x = self.conv0_1(x) 214 | x = self.bn0_1(x) 215 | x = self.relu(x) 216 | x = self.conv0_2(x) 217 | x = self.bn0_2(x) 218 | x = self.relu(x) 219 | 220 | x = self.maxpool1(x) 221 | x = self.layer1(x) 222 | x = self.conv1(x) 223 | x = self.bn1(x) 224 | x = self.relu(x) 225 | 226 | x = self.maxpool2(x) 227 | x = self.layer2(x) 228 | x = self.conv2(x) 229 | x = self.bn2(x) 230 | x = self.relu(x) 231 | 232 | x = self.maxpool3(x) 233 | x = self.layer3(x) 234 | x = self.conv3(x) 235 | x = self.bn3(x) 236 | x = self.relu(x) 237 | 238 | x = self.layer4(x) 239 | x = self.conv4_1(x) 240 | x = 
self.bn4_1(x) 241 | x = self.relu(x) 242 | x = self.conv4_2(x) 243 | x = self.bn4_2(x) 244 | x = self.relu(x) 245 | 246 | return x 247 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import six 5 | import math 6 | import lmdb 7 | import torch 8 | 9 | from natsort import natsorted 10 | from PIL import Image 11 | import numpy as np 12 | from torch.utils.data import Dataset, ConcatDataset, Subset 13 | from torch._utils import _accumulate 14 | import torchvision.transforms as transforms 15 | 16 | 17 | class Batch_Balanced_Dataset(object): 18 | 19 | def __init__(self, opt): 20 | """ 21 | Modulate the data ratio in the batch. 22 | For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5", 23 | the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST. 24 | """ 25 | log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') 26 | dashed_line = '-' * 80 27 | print(dashed_line) 28 | log.write(dashed_line + '\n') 29 | print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}') 30 | log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n') 31 | assert len(opt.select_data) == len(opt.batch_ratio) 32 | 33 | _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 34 | self.data_loader_list = [] 35 | self.dataloader_iter_list = [] 36 | batch_size_list = [] 37 | Total_batch_size = 0 38 | for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio): 39 | _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1) 40 | print(dashed_line) 41 | log.write(dashed_line + '\n') 42 | _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d]) 43 | total_number_dataset = len(_dataset) 44 | log.write(_dataset_log) 45 | 46 | """ 47 | The total number of data can be modified with opt.total_data_usage_ratio. 48 | ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage. 49 | See 4.2 section in our paper. 
50 | """ 51 | number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio)) 52 | dataset_split = [number_dataset, total_number_dataset - number_dataset] 53 | indices = range(total_number_dataset) 54 | _dataset, _ = [Subset(_dataset, indices[offset - length:offset]) 55 | for offset, length in zip(_accumulate(dataset_split), dataset_split)] 56 | selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n' 57 | selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}' 58 | print(selected_d_log) 59 | log.write(selected_d_log + '\n') 60 | batch_size_list.append(str(_batch_size)) 61 | Total_batch_size += _batch_size 62 | 63 | _data_loader = torch.utils.data.DataLoader( 64 | _dataset, batch_size=_batch_size, 65 | shuffle=True, 66 | num_workers=int(opt.workers), 67 | collate_fn=_AlignCollate, pin_memory=True) 68 | self.data_loader_list.append(_data_loader) 69 | self.dataloader_iter_list.append(iter(_data_loader)) 70 | 71 | Total_batch_size_log = f'{dashed_line}\n' 72 | batch_size_sum = '+'.join(batch_size_list) 73 | Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n' 74 | Total_batch_size_log += f'{dashed_line}' 75 | opt.batch_size = Total_batch_size 76 | 77 | print(Total_batch_size_log) 78 | log.write(Total_batch_size_log + '\n') 79 | log.close() 80 | 81 | def get_batch(self): 82 | balanced_batch_images = [] 83 | balanced_batch_texts = [] 84 | 85 | for i, data_loader_iter in enumerate(self.dataloader_iter_list): 86 | try: 87 | image, text = data_loader_iter.next() 88 | balanced_batch_images.append(image) 89 | balanced_batch_texts += text 90 | except StopIteration: 91 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) 92 | image, text = self.dataloader_iter_list[i].next() 93 | balanced_batch_images.append(image) 94 | balanced_batch_texts += text 95 | except ValueError: 96 | pass 97 | 98 | balanced_batch_images = torch.cat(balanced_batch_images, 0) 99 | 100 | return balanced_batch_images, balanced_batch_texts 101 | 102 | 103 | def hierarchical_dataset(root, opt, select_data='/'): 104 | """ select_data='/' contains all sub-directory of root directory """ 105 | dataset_list = [] 106 | dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}' 107 | print(dataset_log) 108 | dataset_log += '\n' 109 | for dirpath, dirnames, filenames in os.walk(root+'/'): 110 | if not dirnames: 111 | select_flag = False 112 | for selected_d in select_data: 113 | if selected_d in dirpath: 114 | select_flag = True 115 | break 116 | 117 | if select_flag: 118 | dataset = LmdbDataset(dirpath, opt) 119 | sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}' 120 | print(sub_dataset_log) 121 | dataset_log += f'{sub_dataset_log}\n' 122 | dataset_list.append(dataset) 123 | 124 | concatenated_dataset = ConcatDataset(dataset_list) 125 | 126 | return concatenated_dataset, dataset_log 127 | 128 | 129 | class LmdbDataset(Dataset): 130 | 131 | def __init__(self, root, opt): 132 | 133 | self.root = root 134 | self.opt = opt 135 | self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) 136 | if not self.env: 137 | print('cannot create lmdb from %s' % (root)) 138 | sys.exit(0) 139 | 140 | with self.env.begin(write=False) as txn: 141 | nSamples = int(txn.get('num-samples'.encode())) 142 | self.nSamples = nSamples 
143 | 144 | if self.opt.data_filtering_off: 145 | # for fast check or benchmark evaluation with no filtering 146 | self.filtered_index_list = [index + 1 for index in range(self.nSamples)] 147 | else: 148 | """ Filtering part 149 | If you want to evaluate IC15-2077 & CUTE datasets which have special character labels, 150 | use --data_filtering_off and only evaluate on alphabets and digits. 151 | see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L190-L192 152 | 153 | And if you want to evaluate them with the model trained with --sensitive option, 154 | use --sensitive and --data_filtering_off, 155 | see https://github.com/clovaai/deep-text-recognition-benchmark/blob/dff844874dbe9e0ec8c5a52a7bd08c7f20afe704/test.py#L137-L144 156 | """ 157 | self.filtered_index_list = [] 158 | for index in range(self.nSamples): 159 | index += 1 # lmdb starts with 1 160 | label_key = 'label-%09d'.encode() % index 161 | label = txn.get(label_key).decode('utf-8') 162 | 163 | if len(label) > self.opt.batch_max_length: 164 | # print(f'The length of the label is longer than max_length: length 165 | # {len(label)}, {label} in dataset {self.root}') 166 | continue 167 | 168 | # By default, images containing characters which are not in opt.character are filtered. 169 | # You can add [UNK] token to `opt.character` in utils.py instead of this filtering. 170 | out_of_char = f'[^{self.opt.character}]' 171 | if re.search(out_of_char, label.lower()): 172 | continue 173 | 174 | self.filtered_index_list.append(index) 175 | 176 | self.nSamples = len(self.filtered_index_list) 177 | 178 | def __len__(self): 179 | return self.nSamples 180 | 181 | def __getitem__(self, index): 182 | assert index <= len(self), 'index range error' 183 | index = self.filtered_index_list[index] 184 | 185 | with self.env.begin(write=False) as txn: 186 | label_key = 'label-%09d'.encode() % index 187 | label = txn.get(label_key).decode('utf-8') 188 | img_key = 'image-%09d'.encode() % index 189 | imgbuf = txn.get(img_key) 190 | 191 | buf = six.BytesIO() 192 | buf.write(imgbuf) 193 | buf.seek(0) 194 | try: 195 | if self.opt.rgb: 196 | img = Image.open(buf).convert('RGB') # for color image 197 | else: 198 | img = Image.open(buf).convert('L') 199 | 200 | except IOError: 201 | print(f'Corrupted image for {index}') 202 | # make dummy image and dummy label for corrupted image. 
203 | if self.opt.rgb: 204 | img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) 205 | else: 206 | img = Image.new('L', (self.opt.imgW, self.opt.imgH)) 207 | label = '[dummy_label]' 208 | 209 | if not self.opt.sensitive: 210 | label = label.lower() 211 | 212 | # We only train and evaluate on alphanumerics (or pre-defined character set in train.py) 213 | out_of_char = f'[^{self.opt.character}]' 214 | label = re.sub(out_of_char, '', label) 215 | 216 | return (img, label) 217 | 218 | 219 | class RawDataset(Dataset): 220 | 221 | def __init__(self, root, opt): 222 | self.opt = opt 223 | self.image_path_list = [] 224 | for dirpath, dirnames, filenames in os.walk(root): 225 | for name in filenames: 226 | _, ext = os.path.splitext(name) 227 | ext = ext.lower() 228 | if ext == '.jpg' or ext == '.jpeg' or ext == '.png': 229 | self.image_path_list.append(os.path.join(dirpath, name)) 230 | 231 | self.image_path_list = natsorted(self.image_path_list) 232 | self.nSamples = len(self.image_path_list) 233 | 234 | def __len__(self): 235 | return self.nSamples 236 | 237 | def __getitem__(self, index): 238 | 239 | try: 240 | if self.opt.rgb: 241 | img = Image.open(self.image_path_list[index]).convert('RGB') # for color image 242 | else: 243 | img = Image.open(self.image_path_list[index]).convert('L') 244 | 245 | except IOError: 246 | print(f'Corrupted image for {index}') 247 | # make dummy image and dummy label for corrupted image. 248 | if self.opt.rgb: 249 | img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) 250 | else: 251 | img = Image.new('L', (self.opt.imgW, self.opt.imgH)) 252 | 253 | return (img, self.image_path_list[index]) 254 | 255 | 256 | class ResizeNormalize(object): 257 | 258 | def __init__(self, size, interpolation=Image.BICUBIC): 259 | self.size = size 260 | self.interpolation = interpolation 261 | self.toTensor = transforms.ToTensor() 262 | 263 | def __call__(self, img): 264 | img = img.resize(self.size, self.interpolation) 265 | img = self.toTensor(img) 266 | img.sub_(0.5).div_(0.5) 267 | return img 268 | 269 | 270 | class NormalizePAD(object): 271 | 272 | def __init__(self, max_size, PAD_type='right'): 273 | self.toTensor = transforms.ToTensor() 274 | self.max_size = max_size 275 | self.max_width_half = math.floor(max_size[2] / 2) 276 | self.PAD_type = PAD_type 277 | 278 | def __call__(self, img): 279 | img = self.toTensor(img) 280 | img.sub_(0.5).div_(0.5) 281 | c, h, w = img.size() 282 | Pad_img = torch.FloatTensor(*self.max_size).fill_(0) 283 | Pad_img[:, :, :w] = img # right pad 284 | if self.max_size[2] != w: # add border Pad 285 | Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w) 286 | 287 | return Pad_img 288 | 289 | 290 | class AlignCollate(object): 291 | 292 | def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False): 293 | self.imgH = imgH 294 | self.imgW = imgW 295 | self.keep_ratio_with_pad = keep_ratio_with_pad 296 | 297 | def __call__(self, batch): 298 | batch = filter(lambda x: x is not None, batch) 299 | images, labels = zip(*batch) 300 | 301 | if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper 302 | resized_max_w = self.imgW 303 | input_channel = 3 if images[0].mode == 'RGB' else 1 304 | transform = NormalizePAD((input_channel, self.imgH, resized_max_w)) 305 | 306 | resized_images = [] 307 | for image in images: 308 | w, h = image.size 309 | ratio = w / float(h) 310 | if math.ceil(self.imgH * ratio) > self.imgW: 311 | resized_w = self.imgW 312 | else: 313 | resized_w = math.ceil(self.imgH * ratio) 314 | 315 | 
resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC) 316 | resized_images.append(transform(resized_image)) 317 | # resized_image.save('./image_test/%d_test.jpg' % w) 318 | 319 | image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0) 320 | 321 | else: 322 | transform = ResizeNormalize((self.imgW, self.imgH)) 323 | image_tensors = [transform(image) for image in images] 324 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) 325 | 326 | return image_tensors, labels 327 | 328 | 329 | def tensor2im(image_tensor, imtype=np.uint8): 330 | image_numpy = image_tensor.cpu().float().numpy() 331 | if image_numpy.shape[0] == 1: 332 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 333 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 334 | return image_numpy.astype(imtype) 335 | 336 | 337 | def save_image(image_numpy, image_path): 338 | image_pil = Image.fromarray(image_numpy) 339 | image_pil.save(image_path) 340 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import os 3 | import time 4 | import string 5 | import argparse 6 | import re 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.utils.data 11 | import torch.nn.functional as F 12 | import numpy as np 13 | from nltk.metrics.distance import edit_distance 14 | 15 | from utils import CTCLabelConverter, AttnLabelConverter, TransformerLabelConverter, Averager 16 | from dataset import hierarchical_dataset, AlignCollate 17 | from model import Model 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | 20 | 21 | def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False): 22 | """ evaluation with 10 benchmark evaluation datasets """ 23 | # The evaluation datasets, dataset order is same with Table 1 in our paper. 24 | eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 25 | 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'] 26 | 27 | # # To easily compute the total accuracy of our paper. 28 | # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867', 29 | # 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] 30 | 31 | if calculate_infer_time: 32 | evaluation_batch_size = 1 # batch_size should be 1 to calculate the GPU inference time per image. 
33 | else: 34 | evaluation_batch_size = opt.batch_size 35 | 36 | list_accuracy = [] 37 | total_forward_time = 0 38 | total_evaluation_data_number = 0 39 | total_correct_number = 0 40 | log = open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a') 41 | dashed_line = '-' * 80 42 | print(dashed_line) 43 | log.write(dashed_line + '\n') 44 | for eval_data in eval_data_list: 45 | eval_data_path = os.path.join(opt.eval_data, eval_data) 46 | AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 47 | eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt) 48 | evaluation_loader = torch.utils.data.DataLoader( 49 | eval_data, batch_size=evaluation_batch_size, 50 | shuffle=False, 51 | num_workers=int(opt.workers), 52 | collate_fn=AlignCollate_evaluation, pin_memory=True) 53 | 54 | _, accuracy_by_best_model, norm_ED_by_best_model, _, _, _, infer_time, length_of_data = validation( 55 | model, criterion, evaluation_loader, converter, opt) 56 | list_accuracy.append(f'{accuracy_by_best_model:0.3f}') 57 | total_forward_time += infer_time 58 | total_evaluation_data_number += len(eval_data) 59 | total_correct_number += accuracy_by_best_model * length_of_data 60 | log.write(eval_data_log) 61 | print(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}') 62 | log.write(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}\n') 63 | print(dashed_line) 64 | log.write(dashed_line + '\n') 65 | 66 | averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000 67 | total_accuracy = total_correct_number / total_evaluation_data_number 68 | params_num = sum([np.prod(p.size()) for p in model.parameters()]) 69 | 70 | evaluation_log = 'accuracy: ' 71 | for name, accuracy in zip(eval_data_list, list_accuracy): 72 | evaluation_log += f'{name}: {accuracy}\t' 73 | evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t' 74 | evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}' 75 | print(evaluation_log) 76 | log.write(evaluation_log + '\n') 77 | log.close() 78 | 79 | return None 80 | 81 | 82 | def validation(model, criterion, evaluation_loader, converter, opt): 83 | """ validation or evaluation """ 84 | n_correct = 0 85 | norm_ED = 0 86 | length_of_data = 0 87 | infer_time = 0 88 | valid_loss_avg = Averager() 89 | 90 | for i, (image_tensors, labels) in enumerate(evaluation_loader): 91 | batch_size = image_tensors.size(0) 92 | length_of_data = length_of_data + batch_size 93 | image = image_tensors.to(device) 94 | # For max length prediction 95 | length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) 96 | text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) 97 | 98 | text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length) 99 | 100 | start_time = time.time() 101 | if 'CTC' in opt.Prediction: 102 | preds = model(image, text_for_pred) 103 | forward_time = time.time() - start_time 104 | 105 | # Calculate evaluation loss for CTC deocder. 
106 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 107 | # permute 'preds' to use CTCloss format 108 | if opt.baiduCTC: 109 | cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) / batch_size 110 | else: 111 | cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) 112 | 113 | # Select max probabilty (greedy decoding) then decode index to character 114 | if opt.baiduCTC: 115 | _, preds_index = preds.max(2) 116 | preds_index = preds_index.view(-1) 117 | else: 118 | _, preds_index = preds.max(2) 119 | preds_str = converter.decode(preds_index.data, preds_size.data) 120 | 121 | else: 122 | preds = model(image, text_for_pred, is_train=False) 123 | forward_time = time.time() - start_time 124 | 125 | preds = preds[:, :text_for_loss.shape[1] - 1, :] 126 | target = text_for_loss[:, 1:] # without [GO] Symbol 127 | cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1)) 128 | 129 | # select max probabilty (greedy decoding) then decode index to character 130 | _, preds_index = preds.max(2) 131 | preds_str = converter.decode(preds_index, length_for_pred) 132 | labels = converter.decode(text_for_loss[:, 1:], length_for_loss) 133 | 134 | infer_time += forward_time 135 | valid_loss_avg.add(cost) 136 | 137 | # calculate accuracy & confidence score 138 | preds_prob = F.softmax(preds, dim=2) 139 | preds_max_prob, _ = preds_prob.max(dim=2) 140 | confidence_score_list = [] 141 | for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob): 142 | if 'Attn' in opt.Prediction or "Transformer" in opt.Prediction: 143 | gt = gt[:gt.find('[s]')] 144 | pred_EOS = pred.find('[s]') 145 | pred = pred[:pred_EOS] # prune after "end of sentence" token ([s]) 146 | pred_max_prob = pred_max_prob[:pred_EOS] 147 | 148 | # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting. 149 | if opt.sensitive and opt.data_filtering_off: 150 | pred = pred.lower() 151 | gt = gt.lower() 152 | alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz' 153 | out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]' 154 | pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred) 155 | gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt) 156 | #import pdb;pdb.set_trace() 157 | 158 | if pred == gt: 159 | n_correct += 1 160 | 161 | ''' 162 | (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks 163 | "For each word we calculate the normalized edit distance to the length of the ground truth transcription." 
164 | if len(gt) == 0: 165 | norm_ED += 1 166 | else: 167 | norm_ED += edit_distance(pred, gt) / len(gt) 168 | ''' 169 | 170 | # ICDAR2019 Normalized Edit Distance 171 | if len(gt) == 0 or len(pred) == 0: 172 | norm_ED += 0 173 | elif len(gt) > len(pred): 174 | norm_ED += 1 - edit_distance(pred, gt) / len(gt) 175 | else: 176 | norm_ED += 1 - edit_distance(pred, gt) / len(pred) 177 | 178 | # calculate confidence score (= multiply of pred_max_prob) 179 | try: 180 | confidence_score = pred_max_prob.cumprod(dim=0)[-1] 181 | except: 182 | confidence_score = 0 # for empty pred case, when prune after "end of sentence" token ([s]) 183 | confidence_score_list.append(confidence_score) 184 | # print(pred, gt, pred==gt, confidence_score) 185 | 186 | accuracy = n_correct / float(length_of_data) * 100 187 | norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance 188 | 189 | return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data 190 | 191 | 192 | def test(opt): 193 | """ model configuration """ 194 | if 'CTC' in opt.Prediction: 195 | converter = CTCLabelConverter(opt.character) 196 | elif "Transformer" in opt.Prediction: 197 | converter = TransformerLabelConverter(opt.character) 198 | else: 199 | converter = AttnLabelConverter(opt.character) 200 | opt.num_class = len(converter.character) 201 | 202 | if opt.rgb: 203 | opt.input_channel = 3 204 | model = Model(opt) 205 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 206 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 207 | opt.SequenceModeling, opt.Prediction) 208 | model = torch.nn.DataParallel(model).to(device) 209 | 210 | # load model 211 | print('loading pretrained model from %s' % opt.saved_model) 212 | model.load_state_dict(torch.load(opt.saved_model, map_location=device)) 213 | opt.exp_name = '_'.join(opt.saved_model.split('/')[1:]) 214 | # print(model) 215 | 216 | """ keep evaluation model and result logs """ 217 | os.makedirs(f'./result/{opt.exp_name}', exist_ok=True) 218 | os.system(f'cp {opt.saved_model} ./result/{opt.exp_name}/') 219 | 220 | """ setup loss """ 221 | if 'CTC' in opt.Prediction: 222 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 223 | elif "Transformer" in opt.Prediction: 224 | criterion = torch.nn.CrossEntropyLoss(ignore_index=2).to(device) # ignore [PAD] token 225 | else: 226 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 227 | 228 | """ evaluation """ 229 | model.eval() 230 | with torch.no_grad(): 231 | if opt.benchmark_all_eval: # evaluation with 10 benchmark evaluation datasets 232 | benchmark_all_eval(model, criterion, converter, opt) 233 | else: 234 | log = open(f'./result/{opt.exp_name}/log_evaluation.txt', 'a') 235 | AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 236 | eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt) 237 | evaluation_loader = torch.utils.data.DataLoader( 238 | eval_data, batch_size=opt.batch_size, 239 | shuffle=False, 240 | num_workers=int(opt.workers), 241 | collate_fn=AlignCollate_evaluation, pin_memory=True) 242 | _, accuracy_by_best_model, _, _, _, _, _, _ = validation( 243 | model, criterion, evaluation_loader, converter, opt) 244 | log.write(eval_data_log) 245 | print(f'{accuracy_by_best_model:0.3f}') 246 | 
log.write(f'{accuracy_by_best_model:0.3f}\n') 247 | log.close() 248 | 249 | 250 | if __name__ == '__main__': 251 | parser = argparse.ArgumentParser() 252 | parser.add_argument('--eval_data', required=True, help='path to evaluation dataset') 253 | parser.add_argument('--benchmark_all_eval', action='store_true', help='evaluate 10 benchmark evaluation datasets') 254 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 255 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 256 | parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluate") 257 | """ Data processing """ 258 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 259 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 260 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 261 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 262 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') 263 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 264 | parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') 265 | parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') 266 | parser.add_argument('--baiduCTC', action='store_true', help='for Baidu warpctc CTC loss/decoding') 267 | """ Model Architecture """ 268 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS') 269 | parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet') 270 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') 271 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn|Transformer') 272 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 273 | parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') 274 | parser.add_argument('--output_channel', type=int, default=512, 275 | help='the number of output channel of Feature extractor') 276 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 277 | 278 | opt = parser.parse_args() 279 | 280 | """ vocab / character number configuration """ 281 | if opt.sensitive: 282 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 
283 | 284 | cudnn.benchmark = True 285 | cudnn.deterministic = True 286 | opt.num_gpu = torch.cuda.device_count() 287 | 288 | test(opt) 289 | -------------------------------------------------------------------------------- /modules/transformer_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | import numpy as np 6 | import random 7 | import math 8 | import time 9 | 10 | class Encoder(nn.Module): 11 | def __init__(self, 12 | hid_dim, 13 | n_layers, 14 | n_heads, 15 | pf_dim, 16 | dropout, 17 | device, 18 | max_length = 100): 19 | super().__init__() 20 | 21 | self.device = device 22 | 23 | #self.tok_embedding = nn.Embedding(input_dim, hid_dim) 24 | self.pos_embedding = nn.Embedding(max_length, hid_dim) 25 | 26 | self.layers = nn.ModuleList([EncoderLayer(hid_dim, 27 | n_heads, 28 | pf_dim, 29 | dropout, 30 | device) 31 | for _ in range(n_layers)]) 32 | 33 | self.dropout = nn.Dropout(dropout) 34 | 35 | self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) 36 | 37 | def forward(self, src, src_mask): 38 | 39 | #src = [batch size, src len] 40 | #src_mask = [batch size, 1, 1, src len] 41 | 42 | batch_size = src.shape[0] 43 | src_len = src.shape[1] 44 | 45 | pos = torch.arange(0, src_len).unsqueeze(0).repeat(batch_size, 1).to(self.device) 46 | 47 | #pos = [batch size, src len] 48 | 49 | src = self.dropout(((src) * self.scale) + self.pos_embedding(pos)) 50 | 51 | #src = [batch size, src len, hid dim] 52 | 53 | for layer in self.layers: 54 | src = layer(src, src_mask) 55 | 56 | #src = [batch size, src len, hid dim] 57 | 58 | return src 59 | 60 | 61 | class EncoderLayer(nn.Module): 62 | def __init__(self, 63 | hid_dim, 64 | n_heads, 65 | pf_dim, 66 | dropout, 67 | device): 68 | super().__init__() 69 | 70 | self.self_attn_layer_norm = nn.LayerNorm(hid_dim) 71 | self.ff_layer_norm = nn.LayerNorm(hid_dim) 72 | self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) 73 | self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, 74 | pf_dim, 75 | dropout) 76 | self.dropout = nn.Dropout(dropout) 77 | 78 | def forward(self, src, src_mask): 79 | 80 | #src = [batch size, src len, hid dim] 81 | #src_mask = [batch size, 1, 1, src len] 82 | 83 | #self attention 84 | _src, _ = self.self_attention(src, src, src, src_mask) 85 | 86 | #dropout, residual connection and layer norm 87 | src = self.self_attn_layer_norm(src + self.dropout(_src)) 88 | 89 | #src = [batch size, src len, hid dim] 90 | 91 | #positionwise feedforward 92 | _src = self.positionwise_feedforward(src) 93 | 94 | #dropout, residual and layer norm 95 | src = self.ff_layer_norm(src + self.dropout(_src)) 96 | 97 | #src = [batch size, src len, hid dim] 98 | 99 | return src 100 | 101 | 102 | class MultiHeadAttentionLayer(nn.Module): 103 | def __init__(self, hid_dim, n_heads, dropout, device): 104 | super().__init__() 105 | 106 | assert hid_dim % n_heads == 0 107 | 108 | self.hid_dim = hid_dim 109 | self.n_heads = n_heads 110 | self.head_dim = hid_dim // n_heads 111 | 112 | self.fc_q = nn.Linear(hid_dim, hid_dim) 113 | self.fc_k = nn.Linear(hid_dim, hid_dim) 114 | self.fc_v = nn.Linear(hid_dim, hid_dim) 115 | 116 | self.fc_o = nn.Linear(hid_dim, hid_dim) 117 | 118 | self.dropout = nn.Dropout(dropout) 119 | 120 | self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device) 121 | 122 | def forward(self, query, key, value, mask = None): 123 | 124 | batch_size = query.shape[0] 125 | 
126 | #query = [batch size, query len, hid dim] 127 | #key = [batch size, key len, hid dim] 128 | #value = [batch size, value len, hid dim] 129 | 130 | Q = self.fc_q(query) 131 | K = self.fc_k(key) 132 | V = self.fc_v(value) 133 | 134 | #Q = [batch size, query len, hid dim] 135 | #K = [batch size, key len, hid dim] 136 | #V = [batch size, value len, hid dim] 137 | 138 | Q = Q.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 139 | K = K.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 140 | V = V.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) 141 | 142 | #Q = [batch size, n heads, query len, head dim] 143 | #K = [batch size, n heads, key len, head dim] 144 | #V = [batch size, n heads, value len, head dim] 145 | 146 | energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale 147 | 148 | #energy = [batch size, n heads, query len, key len] 149 | 150 | if mask is not None: 151 | energy = energy.masked_fill(mask == 0, -1e10) 152 | 153 | attention = torch.softmax(energy, dim = -1) 154 | 155 | #attention = [batch size, n heads, query len, key len] 156 | 157 | x = torch.matmul(self.dropout(attention), V) 158 | 159 | #x = [batch size, n heads, query len, head dim] 160 | 161 | x = x.permute(0, 2, 1, 3).contiguous() 162 | 163 | #x = [batch size, query len, n heads, head dim] 164 | 165 | x = x.view(batch_size, -1, self.hid_dim) 166 | 167 | #x = [batch size, query len, hid dim] 168 | 169 | x = self.fc_o(x) 170 | 171 | #x = [batch size, query len, hid dim] 172 | 173 | return x, attention 174 | 175 | 176 | class PositionwiseFeedforwardLayer(nn.Module): 177 | def __init__(self, hid_dim, pf_dim, dropout): 178 | super().__init__() 179 | 180 | self.fc_1 = nn.Linear(hid_dim, pf_dim) 181 | self.fc_2 = nn.Linear(pf_dim, hid_dim) 182 | 183 | self.dropout = nn.Dropout(dropout) 184 | 185 | def forward(self, x): 186 | 187 | #x = [batch size, seq len, hid dim] 188 | 189 | x = self.dropout(torch.relu(self.fc_1(x))) 190 | 191 | #x = [batch size, seq len, pf dim] 192 | 193 | x = self.fc_2(x) 194 | 195 | #x = [batch size, seq len, hid dim] 196 | 197 | return x 198 | 199 | class Decoder(nn.Module): 200 | def __init__(self, 201 | output_dim, 202 | hid_dim, 203 | n_layers, 204 | n_heads, 205 | pf_dim, 206 | dropout, 207 | device, 208 | max_length = 100): 209 | super().__init__() 210 | 211 | self.device = device 212 | 213 | self.tok_embedding = nn.Embedding(output_dim+3, hid_dim) 214 | self.pos_embedding = nn.Embedding(max_length, hid_dim) 215 | 216 | self.layers = nn.ModuleList([DecoderLayer(hid_dim, 217 | n_heads, 218 | pf_dim, 219 | dropout, 220 | device) 221 | for _ in range(n_layers)]) 222 | 223 | self.fc_out = nn.Linear(hid_dim, output_dim) 224 | 225 | self.dropout = nn.Dropout(dropout) 226 | 227 | self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device) 228 | 229 | def forward(self, trg, enc_src, trg_mask, src_mask): 230 | 231 | #trg = [batch size, trg len] 232 | #enc_src = [batch size, src len, hid dim] 233 | #trg_mask = [batch size, 1, trg len, trg len] 234 | #src_mask = [batch size, 1, 1, src len] 235 | 236 | batch_size = trg.shape[0] 237 | trg_len = trg.shape[1] 238 | 239 | pos = torch.arange(0, trg_len).unsqueeze(0).repeat(batch_size, 1).to(self.device) 240 | 241 | #pos = [batch size, trg len] 242 | 243 | trg = self.dropout((self.tok_embedding(trg) * self.scale) + self.pos_embedding(pos)) 244 | 245 | #trg = [batch size, trg len, hid dim] 246 | 247 | for layer in self.layers: 248 | trg, attention = layer(trg, enc_src, trg_mask, src_mask) 
249 | 250 | #trg = [batch size, trg len, hid dim] 251 | #attention = [batch size, n heads, trg len, src len] 252 | 253 | output = self.fc_out(trg) 254 | 255 | #output = [batch size, trg len, output dim] 256 | 257 | return output, attention 258 | 259 | class DecoderLayer(nn.Module): 260 | def __init__(self, 261 | hid_dim, 262 | n_heads, 263 | pf_dim, 264 | dropout, 265 | device): 266 | super().__init__() 267 | 268 | self.self_attn_layer_norm = nn.LayerNorm(hid_dim) 269 | self.enc_attn_layer_norm = nn.LayerNorm(hid_dim) 270 | self.ff_layer_norm = nn.LayerNorm(hid_dim) 271 | self.self_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) 272 | self.encoder_attention = MultiHeadAttentionLayer(hid_dim, n_heads, dropout, device) 273 | self.positionwise_feedforward = PositionwiseFeedforwardLayer(hid_dim, 274 | pf_dim, 275 | dropout) 276 | self.dropout = nn.Dropout(dropout) 277 | 278 | def forward(self, trg, enc_src, trg_mask, src_mask): 279 | 280 | #trg = [batch size, trg len, hid dim] 281 | #enc_src = [batch size, src len, hid dim] 282 | #trg_mask = [batch size, 1, trg len, trg len] 283 | #src_mask = [batch size, 1, 1, src len] 284 | 285 | #self attention 286 | _trg, _ = self.self_attention(trg, trg, trg, trg_mask) 287 | 288 | #dropout, residual connection and layer norm 289 | trg = self.self_attn_layer_norm(trg + self.dropout(_trg)) 290 | 291 | #trg = [batch size, trg len, hid dim] 292 | 293 | #encoder attention 294 | _trg, attention = self.encoder_attention(trg, enc_src, enc_src, src_mask) 295 | 296 | #dropout, residual connection and layer norm 297 | trg = self.enc_attn_layer_norm(trg + self.dropout(_trg)) 298 | 299 | #trg = [batch size, trg len, hid dim] 300 | 301 | #positionwise feedforward 302 | _trg = self.positionwise_feedforward(trg) 303 | 304 | #dropout, residual and layer norm 305 | trg = self.ff_layer_norm(trg + self.dropout(_trg)) 306 | 307 | #trg = [batch size, trg len, hid dim] 308 | #attention = [batch size, n heads, trg len, src len] 309 | 310 | return trg, attention 311 | 312 | 313 | class Seq2Seq(nn.Module): 314 | def __init__(self, 315 | encoder, 316 | decoder, 317 | trg_pad_idx, 318 | device): 319 | super().__init__() 320 | 321 | self.encoder = encoder 322 | self.decoder = decoder 323 | #self.src_pad_idx = src_pad_idx 324 | self.trg_pad_idx = trg_pad_idx 325 | self.device = device 326 | 327 | def make_src_mask(self, src): 328 | 329 | #src = [batch size, src len] 330 | 331 | #src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2) 332 | 333 | #src_mask = [batch size, 1, 1, src len] 334 | 335 | return None 336 | 337 | def make_trg_mask(self, trg): 338 | 339 | #trg = [batch size, trg len] 340 | 341 | trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2) 342 | #import pdb;pdb.set_trace() 343 | #print("trg_pad_mask: ", trg_pad_mask) 344 | trg_pad_mask[:, :, :, 0] = 1 345 | #print("trg_pad_mask: ", trg_pad_mask) 346 | 347 | #trg_pad_mask = [batch size, 1, 1, trg len] 348 | 349 | trg_len = trg.shape[1] 350 | 351 | trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device = self.device)).bool() 352 | 353 | #trg_sub_mask = [trg len, trg len] 354 | 355 | trg_mask = trg_pad_mask & trg_sub_mask 356 | 357 | #trg_mask = [batch size, 1, trg len, trg len] 358 | 359 | return trg_mask 360 | 361 | def forward(self, src, trg, is_train): 362 | 363 | #src = [batch size, src len] 364 | #trg = [batch size, trg len] 365 | batch_size = src.shape[0] 366 | if is_train: 367 | src_mask = None#self.make_src_mask(src) 368 | trg_mask = self.make_trg_mask(trg) 369 | 
#import pdb;pdb.set_trace() 370 | #print(trg) 371 | #print(mmm) 372 | 373 | #src_mask = [batch size, 1, 1, src len] 374 | #trg_mask = [batch size, 1, trg len, trg len] 375 | 376 | enc_src = self.encoder(src, src_mask) 377 | 378 | #enc_src = [batch size, src len, hid dim] 379 | 380 | output, attention = self.decoder(trg, enc_src, trg_mask, src_mask) 381 | 382 | #output = [batch size, trg len, output dim] 383 | #attention = [batch size, n heads, trg len, src len] 384 | else: 385 | #tokens = [src_field.init_token] + tokens + [src_field.eos_token] 386 | 387 | #src_indexes = [src_field.vocab.stoi[token] for token in tokens] 388 | 389 | #src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device) 390 | 391 | #src_mask = model.make_src_mask(src_tensor) 392 | 393 | #with torch.no_grad(): 394 | max_len = trg.shape[1] - 1 395 | #targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token 396 | #probs = torch.FloatTensor(batch_size, max_len, self.num_classes).fill_(0).to(device) 397 | src_mask = None 398 | enc_src = self.encoder(src, src_mask) 399 | 400 | #trg_indexes = [0] 401 | trg_tensor = trg 402 | 403 | for i in range(max_len): 404 | 405 | #trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device) 406 | 407 | trg_mask = self.make_trg_mask(trg_tensor) 408 | #import pdb; pdb.set_trace() 409 | 410 | #with torch.no_grad(): 411 | output, attention = self.decoder(trg_tensor, enc_src, trg_mask, src_mask) 412 | 413 | pred_token = output.argmax(2)[:,i] 414 | trg_tensor[:, i + 1] = pred_token 415 | 416 | #trg_indexes.append(pred_token) 417 | 418 | #if pred_token == : 419 | # break 420 | 421 | #trg_tokens = [trg_field.vocab.itos[i] for i in trg_indexes] 422 | #print(trg_tensor) 423 | 424 | 425 | return output 426 | 427 | 428 | 429 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import os 3 | import sys 4 | import time 5 | import random 6 | import string 7 | import argparse 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn.init as init 12 | import torch.optim as optim 13 | import torch.utils.data 14 | import numpy as np 15 | 16 | from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, TransformerLabelConverter, Averager 17 | from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset 18 | from model import Model 19 | from test import validation 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | 22 | 23 | def train(opt): 24 | """ dataset preparation """ 25 | if not opt.data_filtering_off: 26 | print('Filtering the images containing characters which are not in opt.character') 27 | print('Filtering the images whose label is longer than opt.batch_max_length') 28 | # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 29 | 30 | opt.select_data = opt.select_data.split('-') 31 | opt.batch_ratio = opt.batch_ratio.split('-') 32 | train_dataset = Batch_Balanced_Dataset(opt) 33 | 34 | log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') 35 | AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 36 | valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt) 37 | valid_loader = torch.utils.data.DataLoader( 38 | valid_dataset, batch_size=opt.batch_size, 39 | shuffle=True, # 'True' to 
check training progress with validation function. 40 | num_workers=int(opt.workers), 41 | collate_fn=AlignCollate_valid, pin_memory=True) 42 | log.write(valid_dataset_log) 43 | print('-' * 80) 44 | log.write('-' * 80 + '\n') 45 | log.close() 46 | 47 | """ model configuration """ 48 | if 'CTC' in opt.Prediction: 49 | if opt.baiduCTC: 50 | converter = CTCLabelConverterForBaiduWarpctc(opt.character) 51 | else: 52 | converter = CTCLabelConverter(opt.character) 53 | elif "Transformer" in opt.Prediction: 54 | converter = TransformerLabelConverter(opt.character) 55 | else: 56 | converter = AttnLabelConverter(opt.character) 57 | opt.num_class = len(converter.character) 58 | 59 | if opt.rgb: 60 | opt.input_channel = 3 61 | model = Model(opt) 62 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 63 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 64 | opt.SequenceModeling, opt.Prediction) 65 | 66 | # weight initialization 67 | for name, param in model.named_parameters(): 68 | #if 'localization_fc2' in name: 69 | if 'localization_fc2' in name or 'decoder' in name or 'self_attn' in name or 'Seq2Seq' in name: 70 | print(f'Skip {name} as it is already initialized') 71 | continue 72 | try: 73 | if 'bias' in name: 74 | init.constant_(param, 0.0) 75 | elif 'weight' in name: 76 | init.kaiming_normal_(param) 77 | except Exception as e: # for batchnorm. 78 | if 'weight' in name: 79 | param.data.fill_(1) 80 | continue 81 | 82 | # data parallel for multi-GPU 83 | model = torch.nn.DataParallel(model).to(device) 84 | model.train() 85 | if opt.saved_model != '': 86 | print(f'loading pretrained model from {opt.saved_model}') 87 | if opt.FT: 88 | model.load_state_dict(torch.load(opt.saved_model), strict=False) 89 | else: 90 | model.load_state_dict(torch.load(opt.saved_model)) 91 | print("Model:") 92 | print(model) 93 | 94 | """ setup loss """ 95 | if 'CTC' in opt.Prediction: 96 | if opt.baiduCTC: 97 | # need to install warpctc. see our guideline. 
98 | from warpctc_pytorch import CTCLoss 99 | criterion = CTCLoss() 100 | else: 101 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 102 | elif "Transformer" in opt.Prediction: 103 | criterion = torch.nn.CrossEntropyLoss(ignore_index=2).to(device) # ignore [PAD] token 104 | else: 105 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 106 | # loss averager 107 | loss_avg = Averager() 108 | 109 | # keep only the parameters that require gradients 110 | filtered_parameters = [] 111 | params_num = [] 112 | for p in filter(lambda p: p.requires_grad, model.parameters()): 113 | filtered_parameters.append(p) 114 | params_num.append(np.prod(p.size())) 115 | print('Trainable params num : ', sum(params_num)) 116 | # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())] 117 | 118 | # setup optimizer 119 | if opt.adam: 120 | optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999)) 121 | else: 122 | print("use Adadelta") 123 | optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps) 124 | print("Optimizer:") 125 | print(optimizer) 126 | 127 | """ final options """ 128 | # print(opt) 129 | with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file: 130 | opt_log = '------------ Options -------------\n' 131 | args = vars(opt) 132 | for k, v in args.items(): 133 | opt_log += f'{str(k)}: {str(v)}\n' 134 | opt_log += '---------------------------------------\n' 135 | print(opt_log) 136 | opt_file.write(opt_log) 137 | 138 | """ start training """ 139 | start_iter = 0 140 | if opt.saved_model != '': 141 | try: 142 | start_iter = int(opt.saved_model.split('_')[-1].split('.')[0]) 143 | print(f'continue to train, start_iter: {start_iter}') 144 | except: 145 | pass 146 | 147 | start_time = time.time() 148 | best_accuracy = -1 149 | best_norm_ED = -1 150 | iteration = start_iter 151 | 152 | while(True): 153 | # train part 154 | image_tensors, labels = train_dataset.get_batch() 155 | image = image_tensors.to(device) 156 | text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) 157 | batch_size = image.size(0) 158 | 159 | if 'CTC' in opt.Prediction: 160 | preds = model(image, text) 161 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 162 | if opt.baiduCTC: 163 | preds = preds.permute(1, 0, 2) # to use CTCLoss format 164 | cost = criterion(preds, text, preds_size, length) / batch_size 165 | else: 166 | preds = preds.log_softmax(2).permute(1, 0, 2) 167 | cost = criterion(preds, text, preds_size, length) 168 | 169 | else: 170 | preds = model(image, text[:, :-1]) # align with Attention.forward 171 | target = text[:, 1:] # without [GO] Symbol 172 | cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 173 | 174 | model.zero_grad() 175 | cost.backward() 176 | torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) 177 | optimizer.step() 178 | 179 | loss_avg.add(cost) 180 | 181 | # validation part 182 | if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 183 | elapsed_time = time.time() - start_time 184 | # for log 185 | with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log: 186 | model.eval() 187 | with torch.no_grad(): 188 | valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = 
validation( 189 | model, criterion, valid_loader, converter, opt) 190 | model.train() 191 | 192 | # training loss and validation loss 193 | loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}' 194 | loss_avg.reset() 195 | 196 | current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}' 197 | 198 | # keep best accuracy model (on valid dataset) 199 | if current_accuracy > best_accuracy: 200 | best_accuracy = current_accuracy 201 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth') 202 | if current_norm_ED > best_norm_ED: 203 | best_norm_ED = current_norm_ED 204 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth') 205 | best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}' 206 | 207 | loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}' 208 | print(loss_model_log) 209 | log.write(loss_model_log + '\n') 210 | 211 | # show some predicted results 212 | dashed_line = '-' * 80 213 | head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' 214 | predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' 215 | for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]): 216 | if 'Attn' in opt.Prediction or "Transformer" in opt.Prediction: 217 | gt = gt[:gt.find('[s]')] 218 | pred = pred[:pred.find('[s]')] 219 | 220 | predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n' 221 | predicted_result_log += f'{dashed_line}' 222 | print(predicted_result_log) 223 | log.write(predicted_result_log + '\n') 224 | 225 | # save model per 1e+5 iter. 226 | if (iteration + 1) % 1e+5 == 0: 227 | torch.save( 228 | model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth') 229 | 230 | if (iteration + 1) == opt.num_iter: 231 | print('end the training') 232 | sys.exit() 233 | iteration += 1 234 | 235 | 236 | if __name__ == '__main__': 237 | parser = argparse.ArgumentParser() 238 | parser.add_argument('--exp_name', help='Where to store logs and models') 239 | parser.add_argument('--train_data', required=True, help='path to training dataset') 240 | parser.add_argument('--valid_data', required=True, help='path to validation dataset') 241 | parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting') 242 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 243 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 244 | parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for') 245 | parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation') 246 | parser.add_argument('--saved_model', default='', help="path to model to continue training") 247 | parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning') 248 | parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)') 249 | parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta') 250 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9') 251 | parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. 
252 |     parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8')
253 |     parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5')
254 |     parser.add_argument('--baiduCTC', action='store_true', help='use warp-ctc (warpctc_pytorch) instead of torch.nn.CTCLoss for the CTC prediction')
255 |     """ Data processing """
256 |     parser.add_argument('--select_data', type=str, default='MJ-ST',
257 |                         help='select training data (default is MJ-ST, which means MJ and ST used as training data)')
258 |     parser.add_argument('--batch_ratio', type=str, default='0.5-0.5',
259 |                         help='assign ratio for each selected data in the batch')
260 |     parser.add_argument('--total_data_usage_ratio', type=str, default='1.0',
261 |                         help='total data usage ratio, this ratio is multiplied to total number of data.')
262 |     parser.add_argument('--batch_max_length', type=int, default=25, help='maximum label length')
263 |     parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
264 |     parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
265 |     parser.add_argument('--rgb', action='store_true', help='use rgb input')
266 |     parser.add_argument('--character', type=str,
267 |                         default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label')
268 |     parser.add_argument('--sensitive', action='store_true', help='for case-sensitive character mode')
269 |     parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
270 |     parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')
271 |     """ Model Architecture """
272 |     parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS')
273 |     parser.add_argument('--FeatureExtraction', type=str, required=True,
274 |                         help='FeatureExtraction stage. VGG|RCNN|ResNet')
275 |     parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM')
276 |     parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn|Transformer')
277 |     parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN')
278 |     parser.add_argument('--input_channel', type=int, default=1,
279 |                         help='the number of input channels of the feature extractor')
280 |     parser.add_argument('--output_channel', type=int, default=512,
281 |                         help='the number of output channels of the feature extractor')
282 |     parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state')
283 | 
284 |     opt = parser.parse_args()
285 | 
286 |     if not opt.exp_name:
287 |         opt.exp_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}'
288 |         opt.exp_name += f'-Seed{opt.manualSeed}'
289 |         # print(opt.exp_name)
290 | 
291 |     os.makedirs(f'./saved_models/{opt.exp_name}', exist_ok=True)
292 | 
293 |     """ vocab / character number configuration """
294 |     if opt.sensitive:
295 |         # opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
296 |         opt.character = string.printable[:-6]  # same as the ASTER setting (94 characters).
297 | 
298 |     """ Seed and GPU setting """
299 |     # print("Random Seed: ", opt.manualSeed)
300 |     random.seed(opt.manualSeed)
301 |     np.random.seed(opt.manualSeed)
302 |     torch.manual_seed(opt.manualSeed)
303 |     torch.cuda.manual_seed(opt.manualSeed)
304 | 
305 |     cudnn.benchmark = True
306 |     cudnn.deterministic = True
307 |     opt.num_gpu = torch.cuda.device_count()
308 |     # print('device count', opt.num_gpu)
309 |     if opt.num_gpu > 1:
310 |         print('------ Use multi-GPU setting ------')
311 |         print('if you are stuck for a long time with the multi-GPU setting, try to set --workers 0')
312 |         # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1
313 |         opt.workers = opt.workers * opt.num_gpu
314 |         opt.batch_size = opt.batch_size * opt.num_gpu
315 | 
316 |         """ previous version
317 |         print('To equalize batch stats to the 1-GPU setting, batch_size is multiplied by num_gpu; the multiplied batch_size is ', opt.batch_size)
318 |         opt.batch_size = opt.batch_size * opt.num_gpu
319 |         print('To equalize the number of epochs to the 1-GPU setting, num_iter is divided by num_gpu by default.')
320 |         If you don't care about it, just comment out these lines.
321 |         opt.num_iter = int(opt.num_iter / opt.num_gpu)
322 |         """
323 | 
324 |     train(opt)
325 | 
--------------------------------------------------------------------------------
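
The non-CTC branch of train.py feeds the model text[:, :-1] and scores it against the shifted target text[:, 1:] with CrossEntropyLoss(ignore_index=...). The following minimal sketch shows that shifted-target teacher-forcing loss in isolation; the token indices ([GO]=0, [s]=1, [PAD]=2), the vocabulary size, and the random preds tensor are illustrative assumptions only — the real mapping is defined by the label converter in utils.py and the model output in model.py.

import torch
import torch.nn as nn

# Hypothetical token indices for illustration only.
GO, EOS, PAD = 0, 1, 2
vocab_size, batch_size, max_len = 40, 2, 7

# Encoded labels: [GO] + character indices + [s], right-padded with [PAD].
text = torch.full((batch_size, max_len), PAD, dtype=torch.long)
text[0, :5] = torch.tensor([GO, 12, 25, 7, EOS])
text[1, :4] = torch.tensor([GO, 30, 8, EOS])

# Teacher forcing: the decoder sees text[:, :-1] and must predict text[:, 1:].
decoder_input = text[:, :-1]                               # what model(image, text[:, :-1]) would consume
target = text[:, 1:]                                       # target shifted by one, without [GO]
preds = torch.randn(batch_size, max_len - 1, vocab_size)   # stand-in for the model's per-step logits

# ignore_index removes padded positions from the loss, mirroring the
# Transformer branch (ignore_index=2) in train.py.
criterion = nn.CrossEntropyLoss(ignore_index=PAD)
cost = criterion(preds.view(-1, vocab_size), target.contiguous().view(-1))
print(cost.item())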
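
The CTC branch reshapes the network output before calling torch.nn.CTCLoss, which expects (T, N, C) log-probabilities plus per-sample input and target lengths; that is what preds.log_softmax(2).permute(1, 0, 2) and preds_size accomplish. A small sketch with random tensors — the shapes, class count, and label lengths below are toy values chosen for illustration, not the repository's actual configuration:

import torch

T, N, C = 26, 2, 37           # time steps, batch size, classes (CTC blank at index 0)
preds = torch.randn(N, T, C)  # model output is batch-first, as in train.py

# torch.nn.CTCLoss wants (T, N, C) log-probabilities, so apply log_softmax over
# the class dimension and swap the batch/time axes (the non-baiduCTC branch).
log_probs = preds.log_softmax(2).permute(1, 0, 2)

preds_size = torch.IntTensor([T] * N)                      # every sample uses all T columns
targets = torch.randint(1, C, (N, 10), dtype=torch.long)   # toy label indices (0 is reserved for blank)
lengths = torch.IntTensor([10, 6])                         # actual label length per sample

criterion = torch.nn.CTCLoss(zero_infinity=True)
cost = criterion(log_probs, targets, preds_size, lengths)
print(cost.item())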
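
The resume logic derives start_iter from the checkpoint file name, so it only picks up where training left off for checkpoints written by the periodic iter_{N}.pth save; a name like best_accuracy.pth raises ValueError, falls into the bare except, and training restarts from iteration 0. A quick worked example (the path below is hypothetical):

saved_model = './saved_models/TPS-ResNet-None-Transformer-Seed1111/iter_100000.pth'  # hypothetical path
# split('_')[-1] -> '100000.pth', split('.')[0] -> '100000'
start_iter = int(saved_model.split('_')[-1].split('.')[0])
print(start_iter)  # 100000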