├── outputs └── reference_result.png ├── tools ├── __init__.py ├── eval.py └── train.py ├── requirements.txt ├── .idea ├── misc.xml ├── inspectionProfiles │ └── profiles_settings.xml ├── modules.xml ├── Transformer_STR.iml └── workspace.xml ├── data ├── __init__.py └── lmdb_dataset.py ├── checkpoints └── Transformer_STR_CUTE80_pretrained.txt ├── utils ├── __init__.py ├── misc.py └── model_util.py ├── config.py ├── model └── __init__.py ├── README.MD └── evaluation.py /outputs/reference_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opconty/Transformer_STR/HEAD/outputs/reference_result.png -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #''' 3 | # @date: 2020/6/9 下午12:17 4 | # 5 | # @author: laygin 6 | # 7 | #''' -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.18.2 2 | nltk==3.3 3 | six==1.14.0 4 | opencv_python_headless==4.1.2.30 5 | torch==1.3.1 6 | lmdb==0.98 7 | torchvision==0.4.2 8 | nicelogger==2.0 9 | Pillow==7.1.2 10 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #''' 3 | # @date: 2020/6/9 下午12:16 4 | # 5 | # @author: laygin 6 | # 7 | #''' 8 | from .lmdb_dataset import hierarchical_dataset, AlignCollate, BatchBalancedDataset 9 | -------------------------------------------------------------------------------- /checkpoints/Transformer_STR_CUTE80_pretrained.txt: -------------------------------------------------------------------------------- 1 | PLEASE DOWNLOAD PRETRAINED WEIGHT FROM: 2 | https://drive.google.com/file/d/1o7aEt_Rmz5ZDIZqc2Z74lo01Tjq87uO1/view?usp=sharing 3 | 4 | AND PLACED IT IN THE CURRENT DIR. 
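(NOTE: evaluation.py loads the weights from checkpoints/Transformer_STR_CUTE80_pretrained.pth, so keep that filename.)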
5 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #''' 3 | # @date: 2020/6/9 下午12:17 4 | # 5 | # @author: laygin 6 | # 7 | #''' 8 | from .misc import TransLabelConverter, Averager, ResizeNormalize, get_img_tensor 9 | 10 | __all__ = ['TransLabelConverter', 'Averager', 'ResizeNormalize'] 11 | 12 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/Transformer_STR.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #''' 3 | # @date: 2020/6/9 下午12:17 4 | # 5 | # @author: laygin 6 | # 7 | #''' 8 | import os 9 | import sys 10 | import re 11 | import string 12 | from nicelogger import ColorLogger 13 | 14 | 15 | logger = ColorLogger() 16 | proj_dir = os.path.abspath(os.path.dirname(__file__)) 17 | if proj_dir not in sys.path: 18 | sys.path.insert(0, proj_dir) 19 | 20 | 21 | checkpoint_dir = os.path.join(proj_dir, 'checkpoints') 22 | # fixme: data path configuration 23 | data_dir = 'path/to/data_lmdb_text_recognition' 24 | 25 | p = re.compile(r'[!"#$%&()*+,/:;<=>?@\\^_`{|}~]') 26 | 27 | 28 | class Config: 29 | workers = 8 30 | batch_max_length = 25 31 | batch_size = 32 32 | imgH = 32 33 | imgW = 100 34 | keep_ratio = False 35 | rgb = False 36 | sensitive = False 37 | data_filtering_off = False 38 | keep_ratio_with_pad = False 39 | total_data_usage_ratio = 1.0 40 | 41 | punctuation = r"""'.-""" 42 | character = string.digits + string.ascii_lowercase + punctuation 43 | 44 | num_fiducial = 20 45 | input_channel = 1 46 | output_channel = 512 47 | hidden_size = 256 48 | num_gpu = 1 49 | 50 | lr = 1.0 51 | grad_clip = 5 52 | beta1 = 0.9 53 | rho = 0.95 54 | eps = 1e-8 55 | manualSeed = 2020 56 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #''' 3 | # @date: 2020/6/9 下午12:17 4 | # 5 | # @author: laygin 6 | # 7 | #''' 8 | import torch 9 | import torch.nn as nn 10 | import importlib 11 | from config import logger 12 | model_util = importlib.import_module('utils.model_util') 13 | 14 | TPS_STN = model_util.TPS_STN 15 | ResNet50 = model_util.ResNet50 16 | BiLSTM = model_util.BiLSTM 17 | Transformer = model_util.Transformer 18 | 19 | 20 | class Model(nn.Module): 21 | def __init__(self, 22 | imgh=32, 23 | imgw=100, 24 | input_channel=1, 25 | output_channel=512, 26 | hidden_size=256, 27 | num_fiducial=20, 28 | num_class=41, 29 | bilstm=True, 30 | device=torch.device('cuda:0')): 31 | super(Model, self).__init__() 32 | 33 | logger.info(f'bi-lstm: {bilstm} | device: {device} | num_class: {num_class}') 34 | self.num_class = num_class 35 | 
self.bilstm = bilstm 36 | 37 | self.transformation = TPS_STN(num_fiducial, I_size=(imgh, imgw), I_r_size=(imgh, imgw), device=device, 38 | I_channel_num=input_channel) 39 | self.fe = ResNet50(input_channel, output_channel) 40 | 41 | self.adaptive_avg_pool = nn.AdaptiveAvgPool2d((None, 1)) 42 | self.seq = nn.Sequential(BiLSTM(output_channel, hidden_size, hidden_size), 43 | BiLSTM(hidden_size, hidden_size, hidden_size)) 44 | if self.bilstm: 45 | self.seq_out_channels = hidden_size 46 | else: 47 | logger.warn('There is no sequence model specified') 48 | self.seq_out_channels = output_channel 49 | self.prediction = Transformer(self.num_class, self.seq_out_channels) 50 | 51 | def forward(self, x): 52 | x = self.transformation(x) 53 | x = self.fe(x) 54 | x = self.adaptive_avg_pool(x.permute(0,3,1,2)) # [b, c, h, w] -> [b, w, c, h] 55 | x = x.squeeze(3) 56 | 57 | if self.bilstm: 58 | x = self.seq(x) 59 | 60 | pred = self.prediction(x.contiguous()) 61 | return pred 62 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | # Transformer-based Scene Text Recognition (Transformer-STR) 2 | 3 | - PyTorch implementation of my new method for Scene Text Recognition (STR) based on [Transformer](https://arxiv.org/abs/1706.03762). 4 | 5 | I adapted the four-stage STR framework devised by [deep-text-recognition-benchmark](https://arxiv.org/abs/1904.01906), and replaced the `Pred.` stage with **Transformer**. 6 | 7 | Equipped with Transformer, this method outperforms the best model of the aforementioned deep-text-recognition-benchmark by **7.6%** on CUTE80. 8 | 9 | ### Download pretrained weights from [here](https://drive.google.com/file/d/1o7aEt_Rmz5ZDIZqc2Z74lo01Tjq87uO1/view?usp=sharing) 10 | These pretrained weights were trained on synthetic data for about 700K iterations. 11 | 12 | Clone this repo, download the weights file, and move it to the `checkpoints` directory. 13 | 14 | ### Download lmdb dataset for training and evaluation from [here](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt) (provided by [deep-text-recognition-benchmark](https://arxiv.org/abs/1904.01906)) 15 | data_lmdb_release.zip contains the following.
16 | training datasets : [MJSynth (MJ)](http://www.robots.ox.ac.uk/~vgg/data/text/)[1] and [SynthText (ST)](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)[2] \ 17 | validation datasets : the union of the training sets [IC13](http://rrc.cvc.uab.es/?ch=2)[3], [IC15](http://rrc.cvc.uab.es/?ch=4)[4], [IIIT](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html)[5], and [SVT](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset)[6].\ 18 | evaluation datasets : benchmark evaluation datasets, consist of [IIIT](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html)[5], [SVT](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset)[6], [IC03](http://www.iapr-tc11.org/mediawiki/index.php/ICDAR_2003_Robust_Reading_Competitions)[7], [IC13](http://rrc.cvc.uab.es/?ch=2)[3], [IC15](http://rrc.cvc.uab.es/?ch=4)[4], [SVTP](http://openaccess.thecvf.com/content_iccv_2013/papers/Phan_Recognizing_Text_with_2013_ICCV_paper.pdf)[8], and [CUTE](http://cs-chan.com/downloads_CUTE80_dataset.html)[9]. 19 | 20 | 21 | ### Training 22 | Please configure your `data_dir` in `config.py` file, then run: 23 | 24 | ```python 25 | python tools/train.py 26 | ``` 27 | 28 | ### Evaluation on CUTE80 29 | The Transformer-base STR achieves **0.815972** accuracy on CUTE80, outperforming the best model of *deep-text-recognition-benchmark*, which is 0.74 30 | 31 | ![compared](./outputs/reference_result.png) 32 | 33 | If you want to reproduce the evaluation result, please run: 34 | 35 | ```python 36 | python evaluation.py 37 | ``` 38 | 39 | Make sure your `cute80_dir` and `saved_model` path is correct. you'll get the result **0.815972** 40 | 41 | 42 | ### Contact 43 | Feel free to contact me (gao.gzhou@gmail.com). 44 | 45 | ### License 46 | This project is released under the [Apache 2.0 license.](https://www.apache.org/licenses/LICENSE-2.0) 47 | 48 | ### References 49 | [deep-text-recognition-benchmark](https://arxiv.org/abs/1904.01906) 50 | -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #''' 3 | # @date: 2020/5/19 下午5:01 4 | # 5 | # @author: laygin 6 | # 7 | #''' 8 | import torch 9 | import torch.nn.functional as F 10 | import time 11 | from nltk.metrics.distance import edit_distance 12 | import importlib 13 | 14 | utils = importlib.import_module('utils') 15 | Averager = utils.Averager 16 | 17 | 18 | def validation(model, criterion, eval_loader, converter, device, cfg): 19 | n_correct = 0 20 | norm_ED = 0 21 | length_of_data = 0 22 | infer_time = 0 23 | valid_loss_avg = Averager() 24 | 25 | for i, (image_tensors, labels) in enumerate(eval_loader): 26 | batch_size = image_tensors.size(0) 27 | length_of_data = length_of_data + batch_size 28 | image = image_tensors.to(device) 29 | # For max length prediction 30 | length_for_pred = torch.IntTensor([cfg.batch_max_length] * batch_size).to(device) 31 | 32 | text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=cfg.batch_max_length) 33 | 34 | start_time = time.time() 35 | 36 | with torch.no_grad(): 37 | preds = model(image) 38 | forward_time = time.time() - start_time 39 | 40 | cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), text_for_loss.contiguous().view(-1)) 41 | 42 | # select max probabilty (greedy decoding) then decode index to character 43 | _, preds_index = preds.max(2) 44 | preds_str = converter.decode(preds_index, 
length_for_pred) 45 | labels = converter.decode(text_for_loss, length_for_loss) 46 | 47 | infer_time += forward_time 48 | valid_loss_avg.add(cost) 49 | 50 | # calculate accuracy & confidence score of one batch 51 | preds_prob = F.softmax(preds, dim=2) 52 | preds_max_prob, _ = preds_prob.max(dim=2) 53 | confidence_score_list = [] 54 | for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob): 55 | gt = gt[:gt.find('')] 56 | pred_EOS = pred.find('') 57 | pred = pred[:pred_EOS] 58 | pred_max_prob = pred_max_prob[:pred_EOS] 59 | 60 | # fixme: do not care the case even case-sensitive model 61 | if not cfg.sensitive: 62 | pred = pred.lower() 63 | gt = gt.lower() 64 | 65 | if pred == gt: 66 | n_correct += 1 67 | if len(gt) == 0: 68 | norm_ED += 1 69 | else: 70 | norm_ED += edit_distance(pred, gt) / len(gt) 71 | 72 | # calculate confidence score (= multiply of pred_max_prob) 73 | try: 74 | confidence_score = pred_max_prob.cumprod(dim=0)[-1] 75 | except: 76 | confidence_score = 0 # for empty pred case, when prune after "end of sentence" token ([s]) 77 | confidence_score_list.append(confidence_score) 78 | # print(pred, gt, pred==gt, confidence_score) 79 | 80 | accuracy = n_correct / float(length_of_data) 81 | 82 | return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data 83 | 84 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #''' 3 | # @date: 2020/5/18 下午6:06 4 | # 5 | # @author: laygin 6 | # 7 | #''' 8 | import torch 9 | from torchvision import transforms 10 | from PIL import Image 11 | import cv2 12 | import numpy as np 13 | import math 14 | 15 | 16 | class Averager(object): 17 | """Compute average for torch.Tensor, used for loss average.""" 18 | 19 | def __init__(self): 20 | self.reset() 21 | 22 | def add(self, v): 23 | count = v.data.numel() 24 | v = v.data.sum() 25 | self.n_count += count 26 | self.sum += v 27 | 28 | def reset(self): 29 | self.n_count = 0 30 | self.sum = 0 31 | 32 | def val(self): 33 | res = 0 34 | if self.n_count != 0: 35 | res = self.sum / float(self.n_count) 36 | return res 37 | 38 | 39 | class ResizeNormalize(object): 40 | 41 | def __init__(self, size, interpolation=Image.BICUBIC): 42 | self.size = size 43 | self.interpolation = interpolation 44 | self.toTensor = transforms.ToTensor() 45 | 46 | def __call__(self, img): 47 | img = img.resize(self.size, self.interpolation) 48 | img = self.toTensor(img) 49 | img.sub_(0.5).div_(0.5) 50 | return img 51 | 52 | 53 | class NormalizePAD(object): 54 | 55 | def __init__(self, max_size, PAD_type='right'): 56 | self.toTensor = transforms.ToTensor() 57 | self.max_size = max_size 58 | self.max_width_half = math.floor(max_size[2] / 2) 59 | self.PAD_type = PAD_type 60 | 61 | def __call__(self, img): 62 | img = self.toTensor(img) 63 | img.sub_(0.5).div_(0.5) 64 | c, h, w = img.size() 65 | Pad_img = torch.FloatTensor(*self.max_size).fill_(0) 66 | Pad_img[:, :, :w] = img # right pad 67 | if self.max_size[2] != w: # add border Pad 68 | Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w) 69 | 70 | return Pad_img 71 | 72 | 73 | def get_img_tensor(img_path, newh=32, neww=100, keep_ratio=False): 74 | if isinstance(img_path, str): 75 | image = Image.open(img_path).convert('L') 76 | elif isinstance(img_path, np.ndarray): 77 | image = Image.fromarray(cv2.cvtColor(img_path, 
cv2.COLOR_BGR2RGB)).convert('L') 78 | else: 79 | raise Exception(f'{type(img_path)} not supported yet') 80 | 81 | if keep_ratio: # same concept with 'Rosetta' paper 82 | resized_max_w = neww 83 | input_channel = 3 if image.mode == 'RGB' else 1 84 | transform = NormalizePAD((input_channel, newh, resized_max_w)) 85 | 86 | w, h = image.size 87 | ratio = w / float(h) 88 | if math.ceil(h * ratio) > neww: 89 | resized_w = neww 90 | else: 91 | resized_w = math.ceil(newh * ratio) 92 | 93 | resized_image = image.resize((resized_w, newh), Image.BICUBIC) 94 | t = transform(resized_image) 95 | 96 | else: 97 | transform = ResizeNormalize((neww, newh)) 98 | t = transform(image) 99 | # print(f'image size: {image.size}\t tensor size: {t.size()}') 100 | return torch.unsqueeze(t, 0) 101 | 102 | 103 | class TransLabelConverter(object): 104 | """ Convert between text-label and text-index """ 105 | def __init__(self, character, device): 106 | self.device = device 107 | list_token = [''] 108 | self.character = list_token + list(character) 109 | 110 | self.dict = {} 111 | for i, char in enumerate(self.character): 112 | self.dict[char] = i 113 | 114 | def encode(self, text, batch_max_length=25): 115 | length = [len(s) + 1 for s in text] 116 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0) 117 | for i, t in enumerate(text): 118 | text = list(t) 119 | text.append('') 120 | text = [self.dict[char] for char in text] 121 | batch_text[i][:len(text)] = torch.LongTensor(text) 122 | return batch_text.to(self.device), torch.IntTensor(length).to(self.device) 123 | 124 | def decode(self, text_index, length): 125 | texts = [] 126 | for index, l in enumerate(length): 127 | text = ''.join([self.character[i] for i in text_index[index, :]]) 128 | texts.append(text) 129 | return texts 130 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 11 | 12 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 43 | 44 | 45 | 46 | 47 | 66 | 67 | 68 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 1591676194053 101 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #''' 3 | # @date: 2020/6/9 下午12:17 4 | # 5 | # @author: laygin 6 | # 7 | #''' 8 | import os 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | import torch.nn.functional as F 12 | from torch.utils.data import DataLoader 13 | import re 14 | import string 15 | import config 16 | import warnings 17 | warnings.filterwarnings('ignore') 18 | import importlib 19 | 20 | utils = importlib.import_module('utils') 21 | eval = importlib.import_module('tools.eval') 22 | data = importlib.import_module('data') 23 | model = importlib.import_module('model') 24 | 25 | TransLabelConverter = utils.TransLabelConverter 26 | hierarchical_dataset = data.hierarchical_dataset 27 | AlignCollate = data.AlignCollate 28 | Model = model.Model 29 | 30 | 31 | logger = config.logger 32 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 33 | p = re.compile(r'[!"#$%&()*+,/:;<=>?@\\^_`{|}~]') 34 | 35 | 36 | class Config(config.Config): 37 | valid_data = os.path.join(config.data_dir, 'validation') 38 | workers = 0 39 | batch_size = 32 40 | num_class = 40 41 | with_bilstm = True 
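    # num_class here is only an initial value; create_model() recomputes it as
    # len(converter.character): 10 digits + 26 lowercase letters + the 3 punctuation
    # marks in Config.punctuation + the converter's end-of-sequence token = 40.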
42 | sensitive = False 43 | filter_punctuation = False 44 | backbone = 'resnet' 45 | 46 | checkpoint_dir = config.checkpoint_dir 47 | saved_model = '' 48 | 49 | 50 | def create_model(cfg: Config): 51 | """model""" 52 | cfg.sensitive = True if 'sensitive' in cfg.saved_model else False 53 | 54 | if cfg.sensitive: 55 | cfg.character = string.digits + string.ascii_letters + cfg.punctuation 56 | 57 | converter = TransLabelConverter(cfg.character, device) 58 | cfg.num_class = len(converter.character) 59 | 60 | if cfg.rgb: 61 | cfg.input_channel = 3 62 | model = Model(cfg.imgH, cfg.imgW, cfg.input_channel, cfg.output_channel, cfg.hidden_size, 63 | cfg.num_fiducial, cfg.num_class, cfg.with_bilstm, device=device) 64 | 65 | # data parallel for multi-GPU 66 | model = torch.nn.DataParallel(model).to(device) 67 | assert os.path.exists(cfg.saved_model), FileNotFoundError(f'{cfg.saved_model}') 68 | 69 | if os.path.isfile(cfg.saved_model): 70 | logger.info(f'loading pretrained model from {os.path.relpath(cfg.saved_model, os.path.dirname(__file__))}') 71 | model.load_state_dict(torch.load(cfg.saved_model, map_location=device)) 72 | 73 | model.eval() 74 | 75 | return model, converter 76 | 77 | 78 | def validation(cfg: Config, model, converter): 79 | cfg.sensitive = True if 'sensitive' in cfg.saved_model else False 80 | AlignCollate_valid = AlignCollate() 81 | valid_dataset = hierarchical_dataset(cfg.valid_data, cfg.imgH, cfg.imgW, cfg.batch_max_length, cfg.character, 82 | cfg.sensitive, cfg.rgb, cfg.data_filtering_off) 83 | valid_loader = DataLoader( 84 | valid_dataset, batch_size=cfg.batch_size, 85 | shuffle=False, 86 | num_workers=int(cfg.workers), 87 | collate_fn=AlignCollate_valid, pin_memory=True) 88 | 89 | model.eval() 90 | 91 | n_correct = 0 92 | length_of_data = 0 93 | 94 | for i, (image_tensors, labels) in enumerate(valid_loader): 95 | batch_size = image_tensors.size(0) 96 | length_of_data = length_of_data + batch_size 97 | image = image_tensors.to(device) 98 | # For max length prediction 99 | length_for_pred = torch.IntTensor([cfg.batch_max_length] * batch_size).to(device) 100 | 101 | text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=cfg.batch_max_length) 102 | 103 | with torch.no_grad(): 104 | preds = model(image) 105 | 106 | # select max probabilty (greedy decoding) then decode index to character 107 | _, preds_index = preds.max(2) 108 | preds_str = converter.decode(preds_index, length_for_pred) 109 | labels = converter.decode(text_for_loss, length_for_loss) 110 | 111 | # calculate accuracy & confidence score of one batch 112 | preds_prob = F.softmax(preds, dim=2) 113 | preds_max_prob, _ = preds_prob.max(dim=2) 114 | for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob): 115 | gt = gt[:gt.find('')] 116 | pred_EOS = pred.find('') 117 | pred = pred[:pred_EOS] 118 | pred_max_prob = pred_max_prob[:pred_EOS] 119 | 120 | try: 121 | confidence_score = pred_max_prob.cumprod(dim=0)[-1] 122 | except: 123 | confidence_score = 0.0 124 | 125 | if not cfg.sensitive: 126 | pred = pred.lower() 127 | 128 | # fixme: filter punctuation 129 | if cfg.filter_punctuation: 130 | pred = re.sub(p, '', pred) 131 | gt = re.sub(p, '', gt) 132 | 133 | if pred == gt: 134 | n_correct += 1 135 | 136 | accuracy = n_correct / float(length_of_data) 137 | 138 | return accuracy 139 | 140 | 141 | def eval_cute80(cute80_data_dir): 142 | cfg = Config() 143 | cfg.saved_model = os.path.join(cfg.checkpoint_dir, 144 | 'Transformer_STR_CUTE80_pretrained.pth') 145 | 146 | model, converter = 
create_model(cfg) 147 | 148 | cfg.valid_data = cute80_data_dir 149 | acc = validation(cfg, model, converter) 150 | 151 | logger.success(f'{acc:.6f}') 152 | 153 | 154 | if __name__ == '__main__': 155 | cudnn.benchmark = True 156 | cudnn.deterministic = True 157 | 158 | cute80_dir = os.path.join(config.data_dir, 'evaluation', 'CUTE80') 159 | eval_cute80(cute80_dir) 160 | 161 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #''' 3 | # @date: 2020/6/9 下午12:26 4 | # 5 | # @author: laygin 6 | # 7 | #''' 8 | import os, sys, time, random 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | from torch.utils.data import DataLoader 12 | import torch.nn.init as init 13 | import torch.optim as optim 14 | import numpy as np 15 | import string 16 | import config 17 | import importlib 18 | 19 | utils = importlib.import_module('utils') 20 | eval = importlib.import_module('tools.eval') 21 | data = importlib.import_module('data') 22 | model = importlib.import_module('model') 23 | 24 | TransLabelConverter = utils.TransLabelConverter 25 | Averager = utils.Averager 26 | validation = eval.validation 27 | hierarchical_dataset = data.hierarchical_dataset 28 | AlignCollate = data.AlignCollate 29 | BatchBalancedDataset = data.BatchBalancedDataset 30 | Model = model.Model 31 | 32 | 33 | logger = config.logger 34 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 35 | 36 | 37 | class Config(config.Config): 38 | train_data = os.path.join(config.data_dir, 'training') 39 | valid_data = os.path.join(config.data_dir, 'validation') 40 | saved_model = '' 41 | select_data = ['SJ', 'MJ'] # select training data 42 | batch_ratio = [0.5, 0.5] # assign ratio for each selected data in the batch 43 | workers = 0 44 | batch_size = 192 45 | num_iter = 3000000 46 | valInterval = 2000 47 | total_data_usage_ratio = 1.0 48 | num_class = 40 49 | with_bilstm = True 50 | sensitive = False 51 | 52 | checkpoint_dir = config.checkpoint_dir 53 | Name = f'trans_{"".join(select_data)}_brnn{with_bilstm}' 54 | 55 | 56 | def train(cfg): 57 | """ dataset preparation """ 58 | if not cfg.data_filtering_off: 59 | logger.info('Filtering the images containing characters which are not in character') 60 | logger.info(f'Filtering the images whose label is longer than {cfg.batch_max_length}') 61 | 62 | if cfg.sensitive: 63 | cfg.character = string.digits + string.ascii_letters + cfg.punctuation 64 | cfg.Name += '_sensitive' 65 | 66 | train_dataset = BatchBalancedDataset(cfg.train_data, cfg.batch_max_length,cfg.character, 67 | cfg.select_data, cfg.batch_ratio, cfg.batch_size, cfg.total_data_usage_ratio, 68 | cfg.workers, cfg.sensitive, cfg.rgb, cfg.data_filtering_off, cfg.imgH, cfg.imgW, 69 | cfg.keep_ratio_with_pad) 70 | 71 | AlignCollate_valid = AlignCollate() 72 | valid_dataset = hierarchical_dataset(cfg.valid_data, cfg.imgH, cfg.imgW, cfg.batch_max_length, cfg.character, 73 | cfg.sensitive, cfg.rgb, cfg.data_filtering_off) 74 | valid_loader = DataLoader( 75 | valid_dataset, batch_size=cfg.batch_size, 76 | shuffle=True, 77 | num_workers=int(cfg.workers), 78 | collate_fn=AlignCollate_valid, pin_memory=True) 79 | 80 | """model""" 81 | converter = TransLabelConverter(cfg.character, device) 82 | cfg.num_class = len(converter.character) 83 | logger.verbose(f'{cfg.num_class}\n{converter.character}') 84 | 85 | if cfg.rgb: 86 | cfg.input_channel = 3 87 | model = 
Model(cfg.imgH,cfg.imgW, cfg.input_channel, cfg.output_channel, cfg.hidden_size, 88 | cfg.num_fiducial, cfg.num_class, cfg.with_bilstm, 89 | device=device) 90 | 91 | logger.info('initialize') 92 | for name, param in model.named_parameters(): 93 | if 'localization_fc2' in name or 'decoder' in name or 'self_attn' in name: 94 | logger.info(f'Skip {name} as it is already initialized') 95 | continue 96 | try: 97 | if 'bias' in name: 98 | init.constant_(param, 0.0) 99 | elif 'weight' in name: 100 | init.xavier_normal_(param) 101 | except: # for batchnorm. 102 | if 'weight' in name: 103 | param.data.fill_(1) 104 | continue 105 | 106 | # data parallel for multi-GPU 107 | model = torch.nn.DataParallel(model).to(device) 108 | model.train() 109 | if os.path.isfile(cfg.saved_model): 110 | logger.info(f'loading pretrained model from {cfg.saved_model}') 111 | model.load_state_dict(torch.load(cfg.saved_model)) 112 | 113 | criterion = torch.nn.CrossEntropyLoss().to(device) 114 | loss_avg = Averager() 115 | 116 | # filter that only require gradient decent 117 | filtered_parameters = [] 118 | params_num = [] 119 | for p in filter(lambda p: p.requires_grad, model.parameters()): 120 | filtered_parameters.append(p) 121 | params_num.append(np.prod(p.size())) 122 | logger.info('Trainable params num : ', sum(params_num)) 123 | 124 | # setup optimizer 125 | optimizer = optim.Adadelta(filtered_parameters, lr=cfg.lr, rho=cfg.rho, eps=cfg.eps) 126 | 127 | """ start training """ 128 | start_iter = 0 129 | start_time = time.time() 130 | best_accuracy = -1 131 | best_norm_ED = 1e+6 132 | i = start_iter 133 | epoch_size = len(train_dataset) // cfg.batch_size 134 | 135 | while True: 136 | image_tensors, labels = train_dataset.get_batch() 137 | image = image_tensors.to(device) 138 | target, length = converter.encode(labels, batch_max_length=cfg.batch_max_length) 139 | 140 | preds = model(image) 141 | cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 142 | 143 | model.zero_grad() 144 | cost.backward() 145 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip) 146 | optimizer.step() 147 | 148 | loss_avg.add(cost) 149 | 150 | # validation part 151 | if i % cfg.valInterval == 0 and i > 0: 152 | elapsed_time = time.time() - start_time 153 | 154 | model.eval() 155 | with torch.no_grad(): 156 | valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation( 157 | model, criterion, valid_loader, converter, device, cfg) 158 | model.train() 159 | 160 | # training loss and validation loss 161 | loss_log = f'[{i}/{cfg.num_iter}({epoch_size})] Train loss: {loss_avg.val():0.5f},' \ 162 | f' Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}' 163 | logger.info(loss_log) 164 | loss_avg.reset() 165 | 166 | current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}' 167 | logger.info(current_model_log) 168 | 169 | # keep best accuracy model (on valid dataset) 170 | if current_accuracy > best_accuracy and current_accuracy > 0.65: 171 | best_accuracy = current_accuracy 172 | best_norm_ED = current_norm_ED if current_norm_ED < best_norm_ED else best_norm_ED 173 | acc = f'{best_accuracy:.4f}'.replace('.', '') 174 | # save 175 | torch.save(model.state_dict(), os.path.join(cfg.checkpoint_dir, f'{cfg.Name}_iter{i + 1}_acc{acc}.pth')) 176 | best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}' 177 | 
logger.success(best_model_log) 178 | 179 | # show some predicted results 180 | print('-' * 80) 181 | print(f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F') 182 | print('-' * 80) 183 | for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]): 184 | gt = gt[:gt.find('')] 185 | pred = pred[:pred.find('')] 186 | 187 | print(f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred.lower() == gt.lower())}') 188 | print('-' * 80) 189 | 190 | # save model per 1e+5 iter. 191 | if (i + 1) % 1e+5 == 0: 192 | torch.save( 193 | model.state_dict(), os.path.join(cfg.checkpoint_dir, f'{cfg.Name}_iter{i + 1}.pth')) 194 | 195 | if i == cfg.num_iter: 196 | print('end the training') 197 | sys.exit() 198 | i += 1 199 | 200 | 201 | if __name__ == '__main__': 202 | cfg = Config() 203 | 204 | random.seed(cfg.manualSeed) 205 | np.random.seed(cfg.manualSeed) 206 | torch.manual_seed(cfg.manualSeed) 207 | torch.cuda.manual_seed(cfg.manualSeed) 208 | 209 | cudnn.benchmark = True 210 | cudnn.deterministic = True 211 | 212 | train(cfg) 213 | 214 | 215 | 216 | -------------------------------------------------------------------------------- /data/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #''' 3 | # @date: 2020/5/19 上午10:21 4 | # 5 | # @author: laygin 6 | # 7 | # adapted from https://github.com/clovaai/deep-text-recognition-benchmark 8 | #''' 9 | import os, sys, re, six, math, lmdb 10 | import torch 11 | from PIL import Image 12 | from torch.utils.data import Dataset, ConcatDataset, Subset 13 | from utils.misc import ResizeNormalize, NormalizePAD 14 | import config 15 | 16 | 17 | logger = config.logger 18 | 19 | 20 | def _accumulate(iterable, fn=lambda x, y: x + y): 21 | '''Return running totals''' 22 | # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15 23 | # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120 24 | it = iter(iterable) 25 | try: 26 | total = next(it) 27 | except StopIteration: 28 | return 29 | yield total 30 | for element in it: 31 | total = fn(total, element) 32 | yield total 33 | 34 | 35 | class AlignCollate: 36 | def __init__(self, imgh=32, imgw=100, keep_ratio_with_pad=False): 37 | self.imgh = imgh 38 | self.imgw = imgw 39 | self.keep_ratio_with_pad = keep_ratio_with_pad 40 | 41 | def __call__(self, batch): 42 | batch = filter(lambda x:x is not None, batch) 43 | images, labels = zip(*batch) 44 | 45 | if self.keep_ratio_with_pad: 46 | resized_max_w = self.imgw 47 | input_channel = 3 if images[0].mode == 'RGB' else 1 48 | transform = NormalizePAD((input_channel, self.imgh, resized_max_w)) 49 | 50 | resized_images = [] 51 | for image in images: 52 | w, h = image.size 53 | ratio = w / float(h) 54 | if math.ceil(self.imgh * ratio) > self.imgw: 55 | resized_w = self.imgw 56 | else: 57 | resized_w = math.ceil(self.imgh * ratio) 58 | 59 | resized_image = image.resize((resized_w, self.imgh), Image.BICUBIC) 60 | resized_images.append(transform(resized_image)) 61 | image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0) 62 | else: 63 | transform = ResizeNormalize((self.imgw, self.imgh)) 64 | image_tensors = [transform(i) for i in images] 65 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) 66 | return image_tensors, labels 67 | 68 | 69 | class LmdbDataset(Dataset): 70 | def __init__(self, root, imgh, imgw, batch_max_length, character, sensitive, rgb, data_filtering_off=False): 71 | self.root = root 72 | self.character = character 73 | self.sensitive = 
sensitive 74 | self.rgb = rgb 75 | self.imgH = imgh 76 | self.imgW = imgw 77 | 78 | self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) 79 | if not self.env: 80 | logger.error('can not create lmdb from %s' % root) 81 | sys.exit(0) 82 | 83 | with self.env.begin(write=False) as txn: 84 | self.nsamples = int(txn.get('num-samples'.encode())) 85 | 86 | if data_filtering_off: 87 | self.filtered_index_list = [index + 1 for index in range(self.nsamples)] 88 | else: 89 | self.filtered_index_list = [] 90 | for index in range(self.nsamples): 91 | index += 1 92 | label_key = 'label-%09d'.encode() % index 93 | label = txn.get(label_key).decode('utf-8') 94 | 95 | if len(label) > batch_max_length: 96 | continue 97 | out_of_char = f'[^{self.character}]' 98 | label = label if self.sensitive else label.lower() 99 | if re.search(out_of_char, label): 100 | continue 101 | self.filtered_index_list.append(index) 102 | self.nsamples = len(self.filtered_index_list) 103 | 104 | def __len__(self): 105 | return self.nsamples 106 | 107 | def __getitem__(self, index): 108 | assert index <= len(self), 'index range error' 109 | index = self.filtered_index_list[index] 110 | 111 | with self.env.begin(write=False) as txn: 112 | label_key = 'label-%09d'.encode() % index 113 | label = txn.get(label_key).decode('utf-8') 114 | img_key = 'image-%09d'.encode() % index 115 | imgbuf = txn.get(img_key) 116 | 117 | buf = six.BytesIO() 118 | buf.write(imgbuf) 119 | buf.seek(0) 120 | try: 121 | if self.rgb: 122 | img = Image.open(buf).convert('RGB') # for color image 123 | else: 124 | img = Image.open(buf).convert('L') 125 | 126 | except IOError: 127 | logger.error(f'Corrupted image for {index}') 128 | # make dummy image and dummy label for corrupted image. 
129 | if self.rgb: 130 | img = Image.new('RGB', (self.imgW, self.imgH)) 131 | else: 132 | img = Image.new('L', (self.imgW, self.imgH)) 133 | label = '[dummy_label]' 134 | 135 | if not self.sensitive: 136 | label = label.lower() 137 | 138 | out_of_char = f'[^{self.character}]' 139 | label = re.sub(out_of_char, '', label) 140 | 141 | return img, label 142 | 143 | 144 | def hierarchical_dataset(root, imgh, imgw, batch_max_length,character, 145 | sensitive=False, rgb=False, data_filtering_off=False, select_data='/'): 146 | dataset_list = [] 147 | for dirpath, dirnames, filenames in os.walk(root + '/'): 148 | if not dirnames: 149 | select_flag = False 150 | for select_d in select_data: 151 | if select_d in dirpath: 152 | select_flag = True 153 | break 154 | if select_flag: 155 | dataset = LmdbDataset(dirpath, 156 | imgh=imgh, 157 | imgw=imgw, 158 | batch_max_length=batch_max_length, 159 | character=character,sensitive=sensitive,rgb=rgb, data_filtering_off=data_filtering_off) 160 | logger.info(f'sub-directory:\t{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}') 161 | dataset_list.append(dataset) 162 | return ConcatDataset(dataset_list) 163 | 164 | 165 | class BatchBalancedDataset: 166 | def __init__(self, 167 | root,batch_max_length, character, 168 | select_data, 169 | batch_ratio, 170 | batch_size, 171 | total_data_usage_ratio, 172 | workers, 173 | sensitive=False, rgb=False, data_filtering_off=False, 174 | imgh=32, imgw=100, keep_ratio_with_pad=False): 175 | assert len(select_data) == len(batch_ratio) 176 | 177 | _AlignCollate = AlignCollate(imgh=imgh, imgw=imgw, keep_ratio_with_pad=keep_ratio_with_pad) 178 | self.data_loader_list = [] 179 | self.data_loader_iter_list = [] 180 | batch_size_list = [] 181 | total_batch_size = 0 182 | self.total_samples = 0 183 | for selected_d, batch_ratio_d in zip(select_data, batch_ratio): 184 | _batch_size = max(round(batch_size * float(batch_ratio_d)), 1) 185 | print('-' * 80) 186 | _dataset = hierarchical_dataset(root, imgh,imgw,batch_max_length, 187 | character,sensitive,rgb,data_filtering_off, 188 | select_data=[selected_d]) 189 | total_number_dataset = len(_dataset) 190 | 191 | """ 192 | total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage. 
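            For example, with batch_size=192 and batch_ratio=[0.5, 0.5], each selected source
            contributes max(round(192 * 0.5), 1) = 96 samples to every balanced batch.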
193 | """ 194 | number_dataset = int(total_number_dataset * float(total_data_usage_ratio)) 195 | self.total_samples += number_dataset 196 | dataset_split = [number_dataset, total_number_dataset - number_dataset] 197 | 198 | indices = range(total_number_dataset) 199 | _dataset, _ = [Subset(_dataset, indices[offset-length: offset]) 200 | for offset, length in zip(_accumulate(dataset_split), dataset_split)] 201 | 202 | logger.info(f'number total samples of {selected_d}: {total_number_dataset} x ' 203 | f'{total_data_usage_ratio} (usage_ratio)={len(_dataset)}') 204 | logger.info(f'number samples of {selected_d} per batch: {batch_size} x' 205 | f'{batch_ratio_d}(batch ratio)={_batch_size}') 206 | 207 | batch_size_list.append(str(_batch_size)) 208 | total_batch_size += _batch_size 209 | 210 | _data_loader = torch.utils.data.DataLoader( 211 | _dataset, batch_size=_batch_size, shuffle=True, num_workers=workers, 212 | collate_fn=_AlignCollate, pin_memory=True 213 | ) 214 | self.data_loader_list.append(_data_loader) 215 | self.data_loader_iter_list.append(iter(_data_loader)) 216 | 217 | print('-'*80) 218 | logger.info('Total_batch_size: ', '+'.join(batch_size_list), '=', str(total_batch_size)) 219 | logger.info(f'Total samples: {self.total_samples}') 220 | self.batch_size = total_batch_size 221 | print('-' * 80) 222 | 223 | def __len__(self): 224 | return self.total_samples 225 | 226 | def get_batch(self): 227 | balanced_batch_images = [] 228 | balanced_batch_texts = [] 229 | 230 | for i, data_loader_iter in enumerate(self.data_loader_iter_list): 231 | try: 232 | image, text = data_loader_iter.next() 233 | balanced_batch_images.append(image) 234 | balanced_batch_texts += text 235 | except StopIteration: 236 | self.data_loader_iter_list[i] = iter(self.data_loader_list[i]) 237 | image, text = self.data_loader_iter_list[i].next() 238 | balanced_batch_images.append(image) 239 | balanced_batch_texts += text 240 | except ValueError: 241 | pass 242 | 243 | balanced_batch_images = torch.cat(balanced_batch_images, 0) 244 | 245 | return balanced_batch_images, balanced_batch_texts 246 | -------------------------------------------------------------------------------- /utils/model_util.py: -------------------------------------------------------------------------------- 1 | #-*- coding: utf-8 -*- 2 | #''' 3 | # @date: 2020/5/18 下午6:06 4 | # 5 | # @author: laygin 6 | # 7 | #''' 8 | import math 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn.init import xavier_uniform_ 13 | from torch.nn.init import constant_ 14 | from torch.nn.init import xavier_normal_ 15 | import copy 16 | import torch.nn.functional as F 17 | from config import logger 18 | 19 | 20 | class TPS_STN(nn.Module): 21 | """ Rectification Network of RARE, namely TPS based STN """ 22 | 23 | def __init__(self, F, I_size, I_r_size, device, I_channel_num=1): 24 | """ Based on RARE TPS 25 | input: 26 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 27 | I_size : (height, width) of the input image I 28 | I_r_size : (height, width) of the rectified image I_r 29 | I_channel_num : the number of channels of the input image I 30 | output: 31 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 32 | """ 33 | super(TPS_STN, self).__init__() 34 | self.F = F 35 | self.I_size = I_size 36 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 37 | self.I_channel_num = I_channel_num 38 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 39 | 
self.GridGenerator = GridGenerator(self.F, self.I_r_size, device) 40 | 41 | def forward(self, batch_I): 42 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 43 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 44 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) 45 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') 46 | 47 | return batch_I_r 48 | 49 | 50 | class LocalizationNetwork(nn.Module): 51 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 52 | 53 | def __init__(self, F, I_channel_num): 54 | super(LocalizationNetwork, self).__init__() 55 | self.F = F 56 | self.I_channel_num = I_channel_num 57 | self.conv = nn.Sequential( 58 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, 59 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True), 60 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 61 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), 62 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 63 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), 64 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 65 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), 66 | nn.AdaptiveAvgPool2d(1) # batch_size x 512 67 | ) 68 | 69 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 70 | self.localization_fc2 = nn.Linear(256, self.F * 2) 71 | 72 | # Init fc2 in LocalizationNetwork 73 | self.localization_fc2.weight.data.fill_(0) 74 | """ see RARE paper Fig. 
6 (a) """ 75 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 76 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 77 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 78 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 79 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 80 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 81 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) 82 | 83 | def forward(self, batch_I): 84 | """ 85 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 86 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 87 | """ 88 | batch_size = batch_I.size(0) 89 | features = self.conv(batch_I).view(batch_size, -1) 90 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) 91 | return batch_C_prime 92 | 93 | 94 | class GridGenerator(nn.Module): 95 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 96 | 97 | def __init__(self, F, I_r_size, device): 98 | """ Generate P_hat and inv_delta_C for later """ 99 | super(GridGenerator, self).__init__() 100 | self.device = device 101 | self.eps = 1e-6 102 | self.I_r_height, self.I_r_width = I_r_size 103 | self.F = F 104 | self.C = self._build_C(self.F) # F x 2 105 | self.P = self._build_P(self.I_r_width, self.I_r_height) 106 | ## for multi-gpu, you need register buffer 107 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 108 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 109 | ## for fine-tuning with different image width, you may use below instead of self.register_buffer 110 | #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3 111 | #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3 112 | 113 | def _build_C(self, F): 114 | """ Return coordinates of fiducial points in I_r; C """ 115 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 116 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 117 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 118 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 119 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 120 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 121 | return C # F x 2 122 | 123 | def _build_inv_delta_C(self, F, C): 124 | """ Return inv_delta_C which is needed to calculate T """ 125 | hat_C = np.zeros((F, F), dtype=float) # F x F 126 | for i in range(0, F): 127 | for j in range(i, F): 128 | r = np.linalg.norm(C[i] - C[j]) 129 | hat_C[i, j] = r 130 | hat_C[j, i] = r 131 | np.fill_diagonal(hat_C, 1) 132 | hat_C = (hat_C ** 2) * np.log(hat_C) 133 | # print(C.shape, hat_C.shape) 134 | delta_C = np.concatenate( # F+3 x F+3 135 | [ 136 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 137 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 138 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 139 | ], 140 | axis=0 141 | ) 142 | inv_delta_C = np.linalg.inv(delta_C) 143 | return inv_delta_C # F+3 x F+3 144 | 145 | def _build_P(self, I_r_width, I_r_height): 146 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width 147 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # 
self.I_r_height 148 | P = np.stack( # self.I_r_width x self.I_r_height x 2 149 | np.meshgrid(I_r_grid_x, I_r_grid_y), 150 | axis=2 151 | ) 152 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 153 | 154 | def _build_P_hat(self, F, C, P): 155 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 156 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 157 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 158 | P_diff = P_tile - C_tile # n x F x 2 159 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 160 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 161 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 162 | return P_hat # n x F+3 163 | 164 | def build_P_prime(self, batch_C_prime): 165 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 166 | batch_size = batch_C_prime.size(0) 167 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 168 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 169 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 170 | batch_size, 3, 2).float().to(self.device)), dim=1) # batch_size x F+3 x 2 171 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 172 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 173 | return batch_P_prime # batch_size x n x 2 174 | 175 | 176 | class BasicBlockRes(nn.Module): 177 | expansion = 1 178 | 179 | def __init__(self, inplanes, planes, stride=1, downsample=None): 180 | super(BasicBlockRes, self).__init__() 181 | self.conv1 = self._conv3x3(inplanes, planes) 182 | self.bn1 = nn.BatchNorm2d(planes) 183 | self.conv2 = self._conv3x3(planes, planes) 184 | self.bn2 = nn.BatchNorm2d(planes) 185 | self.relu = nn.ReLU(inplace=True) 186 | self.downsample = downsample 187 | self.stride = stride 188 | 189 | def _conv3x3(self, in_planes, out_planes, stride=1): 190 | "3x3 convolution with padding" 191 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 192 | padding=1, bias=False) 193 | 194 | def forward(self, x): 195 | residual = x 196 | 197 | out = self.conv1(x) 198 | out = self.bn1(out) 199 | out = self.relu(out) 200 | 201 | out = self.conv2(out) 202 | out = self.bn2(out) 203 | 204 | if self.downsample is not None: 205 | residual = self.downsample(x) 206 | out += residual 207 | out = self.relu(out) 208 | 209 | return out 210 | 211 | 212 | class ResNet(nn.Module): 213 | 214 | def __init__(self, input_channel, output_channel, block, layers): 215 | super(ResNet, self).__init__() 216 | 217 | self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] 218 | 219 | self.inplanes = int(output_channel / 8) 220 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), 221 | kernel_size=3, stride=1, padding=1, bias=False) 222 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 223 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, 224 | kernel_size=3, stride=1, padding=1, bias=False) 225 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 226 | self.relu = nn.ReLU(inplace=True) 227 | 228 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 229 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 230 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ 231 | 0], kernel_size=3, stride=1, padding=1, bias=False) 232 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 
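        # With the default 32x100 input, maxpool1/maxpool2 halve both dimensions, while
        # maxpool3 and conv4_1 below use (2, 1) strides that halve only the height; conv4_2
        # then trims the height to 1, so the backbone emits a 512 x 1 x 26 feature map that
        # Model.forward reads as a 26-step sequence.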
233 | 234 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 235 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) 236 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ 237 | 1], kernel_size=3, stride=1, padding=1, bias=False) 238 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 239 | 240 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 241 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) 242 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ 243 | 2], kernel_size=3, stride=1, padding=1, bias=False) 244 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 245 | 246 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) 247 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 248 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) 249 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 250 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 251 | 3], kernel_size=2, stride=1, padding=0, bias=False) 252 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 253 | 254 | def _make_layer(self, block, planes, blocks, stride=1): 255 | downsample = None 256 | if stride != 1 or self.inplanes != planes * block.expansion: 257 | downsample = nn.Sequential( 258 | nn.Conv2d(self.inplanes, planes * block.expansion, 259 | kernel_size=1, stride=stride, bias=False), 260 | nn.BatchNorm2d(planes * block.expansion), 261 | ) 262 | 263 | layers = [] 264 | layers.append(block(self.inplanes, planes, stride, downsample)) 265 | self.inplanes = planes * block.expansion 266 | for i in range(1, blocks): 267 | layers.append(block(self.inplanes, planes)) 268 | 269 | return nn.Sequential(*layers) 270 | 271 | def forward(self, x): 272 | x = self.conv0_1(x) 273 | x = self.bn0_1(x) 274 | x = self.relu(x) 275 | x = self.conv0_2(x) 276 | x = self.bn0_2(x) 277 | x = self.relu(x) 278 | 279 | x = self.maxpool1(x) 280 | x = self.layer1(x) 281 | x = self.conv1(x) 282 | x = self.bn1(x) 283 | x = self.relu(x) 284 | 285 | x = self.maxpool2(x) 286 | x = self.layer2(x) 287 | x = self.conv2(x) 288 | x = self.bn2(x) 289 | x = self.relu(x) 290 | 291 | x = self.maxpool3(x) 292 | x = self.layer3(x) 293 | x = self.conv3(x) 294 | x = self.bn3(x) 295 | x = self.relu(x) 296 | 297 | x = self.layer4(x) 298 | x = self.conv4_1(x) 299 | x = self.bn4_1(x) 300 | x = self.relu(x) 301 | x = self.conv4_2(x) 302 | x = self.bn4_2(x) 303 | x = self.relu(x) 304 | 305 | return x 306 | 307 | 308 | class ResNet50(nn.Module): 309 | 310 | def __init__(self, input_channel, output_channel=512): 311 | super(ResNet50, self).__init__() 312 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlockRes, [1, 2, 5, 3]) 313 | 314 | def forward(self, input): 315 | return self.ConvNet(input) 316 | 317 | 318 | class BiLSTM(nn.Module): 319 | 320 | def __init__(self, input_size, hidden_size, output_size): 321 | super(BiLSTM, self).__init__() 322 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 323 | self.linear = nn.Linear(hidden_size * 2, output_size) 324 | 325 | def forward(self, input): 326 | """ 327 | input : visual feature [batch_size x T x input_size] 328 | output : contextual feature [batch_size x T x output_size] 329 | """ 330 | self.rnn.flatten_parameters() 331 | recurrent, _ = self.rnn(input) # batch_size x 
T x input_size -> batch_size x T x (2*hidden_size) 332 | output = self.linear(recurrent) # batch_size x T x output_size 333 | return output 334 | 335 | 336 | def _get_clones(module, N): 337 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 338 | 339 | 340 | class TransformerEncoder(nn.Module): 341 | __constants__ = ['norm'] 342 | 343 | def __init__(self, encoder_layer, num_layers, norm=None): 344 | super(TransformerEncoder, self).__init__() 345 | self.layers = _get_clones(encoder_layer, num_layers) 346 | self.num_layers = num_layers 347 | self.norm = norm 348 | 349 | def forward(self, src, mask=None, src_key_padding_mask=None): 350 | # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor 351 | r"""Pass the input through the encoder layers in turn. 352 | Args: 353 | src: the sequence to the encoder (required). 354 | mask: the mask for the src sequence (optional). 355 | src_key_padding_mask: the mask for the src keys per batch (optional). 356 | Shape: 357 | see the docs in Transformer class. 358 | """ 359 | output = src 360 | 361 | for mod in self.layers: 362 | output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) 363 | 364 | if self.norm is not None: 365 | output = self.norm(output) 366 | 367 | return output 368 | 369 | 370 | def _get_activation_fn(activation): 371 | if activation == "relu": 372 | return F.relu 373 | elif activation == "gelu": 374 | return F.gelu 375 | 376 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 377 | 378 | 379 | class MultiheadAttention(nn.Module): 380 | __annotations__ = { 381 | 'bias_k': torch._jit_internal.Optional[torch.Tensor], 382 | 'bias_v': torch._jit_internal.Optional[torch.Tensor], 383 | } 384 | __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight'] 385 | 386 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): 387 | super(MultiheadAttention, self).__init__() 388 | self.embed_dim = embed_dim 389 | self.kdim = kdim if kdim is not None else embed_dim 390 | self.vdim = vdim if vdim is not None else embed_dim 391 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 392 | 393 | self.num_heads = num_heads 394 | self.dropout = dropout 395 | self.head_dim = embed_dim // num_heads 396 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 397 | 398 | if self._qkv_same_embed_dim is False: 399 | self.q_proj_weight = nn.Parameter(torch.Tensor(embed_dim, embed_dim)) 400 | self.k_proj_weight = nn.Parameter(torch.Tensor(embed_dim, self.kdim)) 401 | self.v_proj_weight = nn.Parameter(torch.Tensor(embed_dim, self.vdim)) 402 | self.register_parameter('in_proj_weight', None) 403 | else: 404 | self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim)) 405 | self.register_parameter('q_proj_weight', None) 406 | self.register_parameter('k_proj_weight', None) 407 | self.register_parameter('v_proj_weight', None) 408 | 409 | if bias: 410 | self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)) 411 | else: 412 | self.register_parameter('in_proj_bias', None) 413 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 414 | 415 | if add_bias_kv: 416 | self.bias_k = nn.Parameter(torch.empty(1, 1, embed_dim)) 417 | self.bias_v = nn.Parameter(torch.empty(1, 1, embed_dim)) 418 | else: 419 | self.bias_k = self.bias_v = None 420 | 421 | self.add_zero_attn = add_zero_attn 422 | 423 | self._reset_parameters() 424 
| 425 | def _reset_parameters(self): 426 | if self._qkv_same_embed_dim: 427 | xavier_uniform_(self.in_proj_weight) 428 | else: 429 | xavier_uniform_(self.q_proj_weight) 430 | xavier_uniform_(self.k_proj_weight) 431 | xavier_uniform_(self.v_proj_weight) 432 | 433 | if self.in_proj_bias is not None: 434 | constant_(self.in_proj_bias, 0.) 435 | constant_(self.out_proj.bias, 0.) 436 | if self.bias_k is not None: 437 | xavier_normal_(self.bias_k) 438 | if self.bias_v is not None: 439 | xavier_normal_(self.bias_v) 440 | 441 | def __setstate__(self, state): 442 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0 443 | if '_qkv_same_embed_dim' not in state: 444 | state['_qkv_same_embed_dim'] = True 445 | 446 | super(MultiheadAttention, self).__setstate__(state) 447 | 448 | def forward(self, query, key, value, key_padding_mask=None, 449 | need_weights=True, attn_mask=None): 450 | if not self._qkv_same_embed_dim: 451 | return F.multi_head_attention_forward( 452 | query, key, value, self.embed_dim, self.num_heads, 453 | self.in_proj_weight, self.in_proj_bias, 454 | self.bias_k, self.bias_v, self.add_zero_attn, 455 | self.dropout, self.out_proj.weight, self.out_proj.bias, 456 | training=self.training, 457 | key_padding_mask=key_padding_mask, need_weights=need_weights, 458 | attn_mask=attn_mask, use_separate_proj_weight=True, 459 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 460 | v_proj_weight=self.v_proj_weight) 461 | else: 462 | return F.multi_head_attention_forward( 463 | query, key, value, self.embed_dim, self.num_heads, 464 | self.in_proj_weight, self.in_proj_bias, 465 | self.bias_k, self.bias_v, self.add_zero_attn, 466 | self.dropout, self.out_proj.weight, self.out_proj.bias, 467 | training=self.training, 468 | key_padding_mask=key_padding_mask, need_weights=need_weights, 469 | attn_mask=attn_mask) 470 | 471 | 472 | class TransformerEncoderLayer(nn.Module): 473 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): 474 | super(TransformerEncoderLayer, self).__init__() 475 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 476 | # Implementation of Feedforward model 477 | self.linear1 = nn.Linear(d_model, dim_feedforward) 478 | self.dropout = nn.Dropout(dropout) 479 | self.linear2 = nn.Linear(dim_feedforward, d_model) 480 | 481 | self.norm1 = nn.LayerNorm(d_model) 482 | self.norm2 = nn.LayerNorm(d_model) 483 | self.dropout1 = nn.Dropout(dropout) 484 | self.dropout2 = nn.Dropout(dropout) 485 | 486 | self.activation = _get_activation_fn(activation) 487 | 488 | def __setstate__(self, state): 489 | if 'activation' not in state: 490 | state['activation'] = F.relu 491 | super(TransformerEncoderLayer, self).__setstate__(state) 492 | 493 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 494 | src2 = self.self_attn(src, src, src, attn_mask=src_mask, 495 | key_padding_mask=src_key_padding_mask)[0] 496 | src = src + self.dropout1(src2) 497 | src = self.norm1(src) 498 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 499 | src = src + self.dropout2(src2) 500 | src = self.norm2(src) 501 | return src 502 | 503 | 504 | class PositionalEncoding(nn.Module): 505 | 506 | def __init__(self, d_model, dropout=0.1, max_len=5000): 507 | super(PositionalEncoding, self).__init__() 508 | self.dropout = nn.Dropout(p=dropout) 509 | 510 | pe = torch.zeros(max_len, d_model) 511 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 512 | div_term = 
torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 513 | pe[:, 0::2] = torch.sin(position * div_term) 514 | pe[:, 1::2] = torch.cos(position * div_term) 515 | pe = pe.unsqueeze(0).transpose(0, 1) 516 | self.register_buffer('pe', pe) 517 | 518 | def forward(self, x): 519 | x = x + self.pe[:x.size(0), :] 520 | return self.dropout(x) 521 | 522 | 523 | class Transformer(nn.Module): 524 | 525 | def __init__(self, ntoken, ninp, nhid=256, nhead=2, nlayers=2, dropout=0.2): 526 | super(Transformer, self).__init__() 527 | self.src_mask = None 528 | self.pos_encoder = PositionalEncoding(ninp, dropout) 529 | encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) 530 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 531 | self.ninp = ninp 532 | self.decoder = nn.Linear(ninp, ntoken) 533 | 534 | self.init_weights() 535 | 536 | def _generate_square_subsequent_mask(self, sz): 537 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 538 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 539 | return mask 540 | 541 | def init_weights(self): 542 | initrange = 0.1 543 | self.decoder.bias.data.zero_() 544 | self.decoder.weight.data.uniform_(-initrange, initrange) 545 | 546 | def forward(self, src): 547 | if self.src_mask is None or self.src_mask.size(0) != len(src): 548 | mask = self._generate_square_subsequent_mask(len(src)).to(src.device) 549 | self.src_mask = mask 550 | 551 | src = src * math.sqrt(self.ninp) 552 | src = self.pos_encoder(src) 553 | output = self.transformer_encoder(src, self.src_mask) 554 | output = self.decoder(output) 555 | return output 556 | --------------------------------------------------------------------------------
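A minimal single-image inference sketch, assembled from the files above in the same way `evaluation.py` wires them together. The script itself and the image path `demo_word.png` are hypothetical; the module names, the `Model` signature, and the checkpoint filename are taken from the repository as dumped here.

```python
# Hedged usage sketch (not a file in this repo): recognize the text in one cropped
# word image with the pretrained CUTE80 checkpoint, mirroring evaluation.py.
import os
import torch
import config
from model import Model
from utils import TransLabelConverter
from utils.misc import get_img_tensor

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cfg = config.Config()

converter = TransLabelConverter(cfg.character, device)
net = Model(cfg.imgH, cfg.imgW, cfg.input_channel, cfg.output_channel, cfg.hidden_size,
            cfg.num_fiducial, len(converter.character), bilstm=True, device=device)
net = torch.nn.DataParallel(net).to(device)   # the released weights were saved from DataParallel
ckpt = os.path.join(config.checkpoint_dir, 'Transformer_STR_CUTE80_pretrained.pth')
net.load_state_dict(torch.load(ckpt, map_location=device))
net.eval()

img = get_img_tensor('demo_word.png', newh=cfg.imgH, neww=cfg.imgW).to(device)  # placeholder path
with torch.no_grad():
    preds = net(img)                  # [1, seq_len, num_class]
_, preds_index = preds.max(2)         # greedy decoding
pred = converter.decode(preds_index, torch.IntTensor([cfg.batch_max_length]))[0]
print(pred)  # evaluation.py additionally truncates at the converter's end-of-sequence token
```

For a checkpoint trained case-sensitively (its filename contains `sensitive`), rebuild the converter with `string.digits + string.ascii_letters + cfg.punctuation` before creating the model, as `create_model()` in `evaluation.py` does.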