├── images ├── hello_word.txt └── VIPTR_SOTA.png ├── dataload ├── __init__.py ├── aug │ ├── __init__.py │ ├── augment.py │ └── warp_mls.py ├── loader.py └── dataAug.py ├── modules ├── rec_sar_loss.py ├── prediction.py ├── dctc_loss.py ├── sequence_modeling.py ├── tps_spatial_transformer.py ├── stn_head.py ├── transformation.py ├── feature_extraction.py └── SVTR.py ├── README.md ├── scene_dict.txt ├── utils.py ├── model.py ├── optimizer.py ├── LICENSE ├── test_benchmark.py ├── dataset.py └── train_benchmark.py /images/hello_word.txt: -------------------------------------------------------------------------------- 1 | Hello World! 2 | -------------------------------------------------------------------------------- /dataload/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- -------------------------------------------------------------------------------- /images/VIPTR_SOTA.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/buaacxf/VIPTR/HEAD/images/VIPTR_SOTA.png -------------------------------------------------------------------------------- /dataload/aug/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .augment import tia_perspective, tia_distort, tia_stretch 4 | 5 | __all__ = ['tia_distort', 'tia_stretch', 'tia_perspective'] -------------------------------------------------------------------------------- /modules/rec_sar_loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class SARLoss(nn.Module): 4 | def __init__(self, n_class): 5 | super(SARLoss, self).__init__() 6 | ignore_index = n_class + 1 # kwargs.get('ignore_index', 92) # 6626 7 | self.loss_func = nn.CrossEntropyLoss(reduction="mean", ignore_index=ignore_index) 8 | 9 | def forward(self, predicts, batch): 10 | # predicts = predicts['res'] 11 | # print(predicts) 12 | predict = predicts[:, :-1, :] # ignore last index of outputs to be in same seq_len with targets 13 | label = batch[1].long()[:, 1:] # ignore first index of target in loss calculation 14 | batch_size, num_steps, num_classes = predict.shape[0], predict.shape[1], predict.shape[2] 15 | assert len(label.shape) == len(list(predict.shape)) - 1, \ 16 | "The target's shape and inputs's shape is [N, d] and [N, num_steps]" 17 | 18 | inputs = predict.reshape([-1, num_classes]) 19 | targets = label.reshape([-1]) 20 | loss = self.loss_func(inputs, targets) 21 | return loss # {'loss': loss} -------------------------------------------------------------------------------- /dataload/aug/augment.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """ 4 | This code is refer from: 5 | https://github.com/RubanSeven/Text-Image-Augmentation-python/blob/master/augment.py 6 | """ 7 | 8 | import numpy as np 9 | from .warp_mls import WarpMLS 10 | 11 | 12 | def tia_distort(src, segment=4): 13 | img_h, img_w = src.shape[:2] 14 | 15 | cut = img_w // segment 16 | thresh = cut // 3 17 | 18 | src_pts = list() 19 | dst_pts = list() 20 | 21 | src_pts.append([0, 0]) 22 | src_pts.append([img_w, 0]) 23 | src_pts.append([img_w, img_h]) 24 | src_pts.append([0, img_h]) 25 | 26 | dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)]) 27 | dst_pts.append( 28 | [img_w - np.random.randint(thresh), np.random.randint(thresh)]) 29 | 
dst_pts.append( 30 | [img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)]) 31 | dst_pts.append( 32 | [np.random.randint(thresh), img_h - np.random.randint(thresh)]) 33 | 34 | half_thresh = thresh * 0.5 35 | 36 | for cut_idx in np.arange(1, segment, 1): 37 | src_pts.append([cut * cut_idx, 0]) 38 | src_pts.append([cut * cut_idx, img_h]) 39 | dst_pts.append([ 40 | cut * cut_idx + np.random.randint(thresh) - half_thresh, 41 | np.random.randint(thresh) - half_thresh 42 | ]) 43 | dst_pts.append([ 44 | cut * cut_idx + np.random.randint(thresh) - half_thresh, 45 | img_h + np.random.randint(thresh) - half_thresh 46 | ]) 47 | 48 | trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) 49 | dst = trans.generate() 50 | 51 | return dst 52 | 53 | 54 | def tia_stretch(src, segment=4): 55 | img_h, img_w = src.shape[:2] 56 | 57 | cut = img_w // segment 58 | thresh = cut * 4 // 5 59 | 60 | src_pts = list() 61 | dst_pts = list() 62 | 63 | src_pts.append([0, 0]) 64 | src_pts.append([img_w, 0]) 65 | src_pts.append([img_w, img_h]) 66 | src_pts.append([0, img_h]) 67 | 68 | dst_pts.append([0, 0]) 69 | dst_pts.append([img_w, 0]) 70 | dst_pts.append([img_w, img_h]) 71 | dst_pts.append([0, img_h]) 72 | 73 | half_thresh = thresh * 0.5 74 | 75 | for cut_idx in np.arange(1, segment, 1): 76 | move = np.random.randint(thresh) - half_thresh 77 | src_pts.append([cut * cut_idx, 0]) 78 | src_pts.append([cut * cut_idx, img_h]) 79 | dst_pts.append([cut * cut_idx + move, 0]) 80 | dst_pts.append([cut * cut_idx + move, img_h]) 81 | 82 | trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) 83 | dst = trans.generate() 84 | 85 | return dst 86 | 87 | 88 | def tia_perspective(src): 89 | img_h, img_w = src.shape[:2] 90 | 91 | thresh = img_h // 2 92 | 93 | src_pts = list() 94 | dst_pts = list() 95 | 96 | src_pts.append([0, 0]) 97 | src_pts.append([img_w, 0]) 98 | src_pts.append([img_w, img_h]) 99 | src_pts.append([0, img_h]) 100 | 101 | dst_pts.append([0, np.random.randint(thresh)]) 102 | dst_pts.append([img_w, np.random.randint(thresh)]) 103 | dst_pts.append([img_w, img_h - np.random.randint(thresh)]) 104 | dst_pts.append([0, img_h - np.random.randint(thresh)]) 105 | 106 | trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) 107 | dst = trans.generate() 108 | 109 | return dst 110 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VIPTR: A Vision Permutable Extractor for Fast and Efficient Scene Text Recognition 2 | 3 | | [paper](https://arxiv.org/abs/2401.10110) | [English datasets](https://www.dropbox.com/sh/i39abvnefllx2si/AAAbAYRvxzRp3cIE5HzqUw3ra?dl=0) |[Chinese datasets](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download)| **pretrained model:** [Google driver](https://drive.google.com/drive/folders/1ARBG3GqWjpBqdELvd4I60jLeDBV-UPyt?usp=drive_link) or [Baidu Netdisk (passwd:7npu)](https://pan.baidu.com/s/1N9tSWv2RdZ9peB9w8nr9IA?pwd=7npu) | 4 | 5 | ## Getting Started 6 | 7 | ### Dependency 8 | 9 | - This work was tested with **PyTorch 1.8.0, CUDA 10.1, python 3.6.13 and Ubuntu 18.04**. 
10 | - Requirements: **lmdb, Pillow, torchvision, nltk, natsort, timm, mmcv**
11 | 
12 | ```bash
13 | pip install lmdb pillow torchvision nltk natsort timm mmcv
14 | ```
15 | 
16 | ### Download the lmdb datasets for training and evaluation from the following links
17 | 
18 | #### English datasets:
19 | 
20 | - Synthetic image datasets: [MJSynth (MJ)](http://www.robots.ox.ac.uk/~vgg/data/text/), [SynthText (ST)](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) and [SynthAdd (password: 627x)](https://pan.baidu.com/s/1uV0LtoNmcxbO-0YA7Ch4dg);
21 | - Real image datasets: the union of the training sets **IIIT5K, SVT, IC03, IC13, IC15, COCO-Text, SVTP, CUTE80**; ([baidu](https://pan.baidu.com/s/1sm5ga6gByDZt1HhaMlfz2g?pwd=t5d3) | [google](https://drive.google.com/drive/folders/175cFBt4PGjLEJldL2INJILpTYzVu-AiT?usp=drive_link))
22 | - Validation datasets: the [union](https://www.dropbox.com/sh/i39abvnefllx2si/AAAbAYRvxzRp3cIE5HzqUw3ra?dl=0) of the sets **IC13 (857), SVT, IIIT5k (3000), IC15 (1811), SVTP, and CUTE80**;
23 | - Evaluation datasets: English benchmark datasets, consisting of **IIIT5k (3000), SVT, IC13 (857), IC15 (1811), SVTP, and CUTE80**.
24 | 
25 | #### Chinese datasets:
26 | 
27 | - Download the Chinese training, validation and evaluation sets from [here](https://github.com/fudanvi/benchmarking-chinese-text-recognition#download).
28 | 
29 | ## Run benchmark with pretrained model
30 | 
31 | 1. Download the pretrained model from [Google Drive](https://drive.google.com/drive/folders/1ARBG3GqWjpBqdELvd4I60jLeDBV-UPyt?usp=drive_link) or [Baidu Netdisk (passwd: 7npu)](https://pan.baidu.com/s/1N9tSWv2RdZ9peB9w8nr9IA?pwd=7npu);
32 | 
33 | 2. Set the model path, test set paths and character list;
34 | 
35 | 3. Run **test_benchmark.py**;
36 | 
37 | ```bash
38 | CUDA_VISIBLE_DEVICES=0 python test_benchmark.py --benchmark_all_eval --Transformation TPS19 --FeatureExtraction VIPTRv1T --SequenceModeling None --Prediction CTC --batch_max_length 25 --imgW 96 --output_channel 192
39 | ```
40 | 
41 | 4. Run **test_chn_benchmark.py**
42 | 
43 | ```bash
44 | CUDA_VISIBLE_DEVICES=0 python test_chn_benchmark.py --benchmark_all_eval --Transformation TPS19 --FeatureExtraction VIPTRv1T --SequenceModeling None --Prediction CTC --batch_max_length 64 --imgW 320 --output_channel 192
45 | ```
46 | 
47 | ## Results on benchmark datasets and comparison with SOTA
48 | 
49 | ![VIPTR_SOTA](images/VIPTR_SOTA.png)
50 | 
51 | ## Citation
52 | Please consider citing this work in your publications if it helps your research.
53 | ```tex 54 | @article{cheng2024viptr, 55 | title={VIPTR: A Vision Permutable Extractor for Fast and Efficient Scene Text Recognition}, 56 | author={Cheng, Xianfu and Zhou, Weixiao and Li, Xiang and Chen, Xiaoming and Yang, Jian and Li, Tongliang and Li, Zhoujun}, 57 | journal={arXiv preprint arXiv:2401.10110}, 58 | year={2024} 59 | } 60 | ``` 61 | ## Acknowledgements 62 | 63 | - [https://github.com/clovaai/deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark) 64 | - [https://github.com/BADBADBADBOY/OCR-TextRecog](https://github.com/BADBADBADBOY/OCR-TextRecog) 65 | - [https://github.com/PaddlePaddle/PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR) 66 | -------------------------------------------------------------------------------- /modules/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 5 | 6 | 7 | class Attention(nn.Module): 8 | 9 | def __init__(self, input_size, hidden_size, num_classes): 10 | super(Attention, self).__init__() 11 | self.attention_cell = AttentionCell(input_size, hidden_size, num_classes) 12 | self.hidden_size = hidden_size 13 | self.num_classes = num_classes 14 | self.generator = nn.Linear(hidden_size, num_classes) 15 | 16 | def _char_to_onehot(self, input_char, onehot_dim=38): 17 | input_char = input_char.unsqueeze(1) 18 | batch_size = input_char.size(0) 19 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) 20 | one_hot = one_hot.scatter_(1, input_char, 1) 21 | return one_hot 22 | 23 | def forward(self, batch_H, text, is_train=True, batch_max_length=25): 24 | """ 25 | input: 26 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels] 27 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. 28 | output: probability distribution at each step [batch_size x num_steps x num_classes] 29 | """ 30 | batch_size = batch_H.size(0) 31 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. 32 | 33 | output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device) 34 | hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 35 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device)) 36 | 37 | if is_train: 38 | for i in range(num_steps): 39 | # one-hot vectors for a i-th char. 
in a batch 40 | char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes) 41 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) 42 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 43 | output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) 44 | probs = self.generator(output_hiddens) 45 | 46 | else: 47 | targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token 48 | probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device) 49 | 50 | for i in range(num_steps): 51 | char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) 52 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 53 | probs_step = self.generator(hidden[0]) 54 | probs[:, i, :] = probs_step 55 | _, next_input = probs_step.max(1) 56 | targets = next_input 57 | 58 | return probs # batch_size x num_steps x num_classes 59 | 60 | 61 | class AttentionCell(nn.Module): 62 | 63 | def __init__(self, input_size, hidden_size, num_embeddings): 64 | super(AttentionCell, self).__init__() 65 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 66 | self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias 67 | self.score = nn.Linear(hidden_size, 1, bias=False) 68 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 69 | self.hidden_size = hidden_size 70 | 71 | def forward(self, prev_hidden, batch_H, char_onehots): 72 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 73 | batch_H_proj = self.i2h(batch_H) 74 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 75 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 76 | 77 | alpha = F.softmax(e, dim=1) 78 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel 79 | concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) 80 | cur_hidden = self.rnn(concat_context, prev_hidden) 81 | return cur_hidden, alpha 82 | -------------------------------------------------------------------------------- /modules/dctc_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List 3 | from torch import Tensor, nn 4 | from torch.nn import functional as F 5 | import torch 6 | 7 | class DCTC(nn.Module): 8 | def __init__(self, 9 | flatten: bool = True, 10 | blank: int = 0, 11 | reduction: str = 'mean', 12 | s: float = 1, 13 | m: float = 0, 14 | alpha: float = 0.1, 15 | beta: float = 1.0, 16 | eps: float = 1e-8, 17 | use_il: bool = True, 18 | *args, **kwargs) -> None: 19 | super().__init__(*args, **kwargs) 20 | 21 | self.flatten = flatten 22 | self.reduction = reduction 23 | self.s = s 24 | self.m = m 25 | self.scaled_margin = s * m 26 | self.alpha = alpha 27 | self.beta = beta 28 | self.eps = eps 29 | self.black = blank 30 | 31 | self.use_il = use_il 32 | 33 | self.ctc_loss_func = nn.CTCLoss(blank=blank, reduction='none', zero_infinity=True) 34 | self.ctc_loss_func_dummy = nn.CTCLoss(blank=blank, reduction='none', zero_infinity=True) 35 | 36 | def forward(self, 37 | logits: Tensor, 38 | targets_dict: dict, 39 | valid_ratios: List[float] = None 40 | ): 41 | alpha = self.alpha 42 | beta = self.beta 43 | seq_len, bs, v = logits.size() 44 | scaled_margin = self.scaled_margin 45 | 46 | if self.flatten: 47 | targets = targets_dict['targets'] 48 | 
else: 49 | targets = torch.full(size=(bs, seq_len), fill_value=self.blank, dtype=torch.long) 50 | for idx, tensor in enumerate(targets_dict['targets']): 51 | valid_len = min(tensor.size(0), seq_len) 52 | targets[idx, :valid_len] = tensor[:valid_len] 53 | 54 | logits = self.s * logits 55 | 56 | target_lengths = targets_dict['target_lengths'] 57 | 58 | if not self.use_il: 59 | valid_ratios = [1.0] * bs 60 | else: 61 | if valid_ratios is None: 62 | raise ValueError('Valid ratios should not be none, if use_il is True.') 63 | 64 | input_lengths = [int(math.ceil(seq_len * r)) for r in valid_ratios] 65 | input_lengths = torch.tensor(input_lengths, dtype=torch.long) 66 | 67 | log_probs = torch.log_softmax(logits, 2) 68 | ctc_loss_1 = self.ctc_loss_func(log_probs, targets, input_lengths, target_lengths) 69 | 70 | if alpha > 0: 71 | with torch.enable_grad(): 72 | log_probs_dummy = log_probs.detach().clone() 73 | log_probs_dummy.requires_grad = True 74 | ctc_loss_2 = self.ctc_loss_func_dummy(log_probs_dummy, targets, input_lengths, target_lengths) 75 | 76 | ctc_loss_2.sum().backward() 77 | 78 | grad = log_probs_dummy.grad 79 | 80 | with torch.no_grad(): 81 | classes = torch.argmin( 82 | grad / torch.clip(torch.softmax(logits, dim=2), min=self.eps), 83 | dim=2 84 | ) 85 | 86 | one_hots = F.one_hot(classes, v).to(logits.device).float() 87 | neg_log_margin_probs = -torch.log_softmax(logits - scaled_margin * one_hots, 2) 88 | selected_neg_log_margin_probs = neg_log_margin_probs * one_hots 89 | 90 | if self.use_il: 91 | il_mask = torch.arange(seq_len, device=logits.device)[..., None] 92 | il_mask = torch.ge(il_mask, input_lengths[None, ...].to(logits.device)) 93 | il_mask = il_mask[..., None] 94 | selected_neg_log_margin_probs = torch.where( 95 | il_mask, 96 | torch.zeros_like(selected_neg_log_margin_probs), 97 | selected_neg_log_margin_probs 98 | ) 99 | 100 | ce_loss = selected_neg_log_margin_probs.sum(2).sum(0) 101 | else: 102 | ce_loss = 0 103 | 104 | nll = alpha * ce_loss + beta * ctc_loss_1 105 | 106 | if self.reduction == 'mean': 107 | loss = nll.mean() 108 | elif self.reduction == 'sum': 109 | loss = nll.sum() 110 | else: 111 | loss = nll 112 | 113 | return loss -------------------------------------------------------------------------------- /modules/sequence_modeling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # class BidirectionalLSTM(nn.Module): 5 | # 6 | # def __init__(self, input_size, hidden_size, output_size): 7 | # super(BidirectionalLSTM, self).__init__() 8 | # self.rnn = nn.LSTM(input_size, hidden_size, num_layers=2, bidirectional=True, batch_first=True) 9 | # self.linear = nn.Linear(hidden_size * 2, output_size) 10 | # 11 | # def forward(self, input): 12 | # """ 13 | # input : visual feature [batch_size x T x input_size] 14 | # output : contextual feature [batch_size x T x output_size] 15 | # """ 16 | # self.rnn.flatten_parameters() 17 | # recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 18 | # output = self.linear(recurrent) # batch_size x T x output_size 19 | # return output 20 | 21 | class BidirectionalLSTMv2(nn.Module): 22 | 23 | def __init__(self, nIn, nHidden, nOut): 24 | super(BidirectionalLSTMv2, self).__init__() 25 | 26 | self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) 27 | self.embedding = nn.Linear(nHidden * 2, nOut) 28 | 29 | def forward(self, input): 30 | recurrent, _ = self.rnn(input) 31 | T, b, h = recurrent.size() 32 | t_rec = 
recurrent.view(T * b, h) 33 | 34 | output = self.embedding(t_rec) # [T * b, nOut] 35 | output = output.view(T, b, -1) 36 | 37 | return output 38 | 39 | class BidirectionalLSTM(nn.Module): 40 | 41 | def __init__(self, input_size, hidden_size, output_size): 42 | super(BidirectionalLSTM, self).__init__() 43 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 44 | self.linear = nn.Linear(hidden_size * 2, output_size) 45 | # self.h0 = torch.randn(2, 1, hidden_size).cuda() 46 | # self.c0 = torch.randn(2, 1, hidden_size).cuda() 47 | 48 | def forward(self, input): 49 | """ 50 | input : visual feature [batch_size x T x input_size] 51 | output : contextual feature [batch_size x T x output_size] 52 | """ 53 | self.rnn.flatten_parameters() 54 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 55 | # T, b, h = recurrent.size() 56 | # print("recurrent.size: ", recurrent.size()) 57 | # t_rec = recurrent.contiguous().view(T * b, h) 58 | 59 | output = self.linear(recurrent) # batch_size x T x output_size 60 | # output = output.view(T, b, -1) 61 | # print("output.size: ", output.size()) 62 | return output 63 | 64 | class BidirectionalGRU(nn.Module): 65 | 66 | def __init__(self, input_size, hidden_size, output_size): 67 | super(BidirectionalGRU, self).__init__() 68 | self.rnn = nn.GRU(input_size, hidden_size, bidirectional=True, batch_first=True) 69 | self.linear = nn.Linear(hidden_size * 2, output_size) 70 | 71 | def forward(self, input): 72 | """ 73 | input : visual feature [batch_size x T x input_size] 74 | output : contextual feature [batch_size x T x output_size] 75 | """ 76 | self.rnn.flatten_parameters() 77 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 78 | output = self.linear(recurrent) # batch_size x T x output_size 79 | return output 80 | 81 | class BidirectionalRNN(nn.Module): 82 | 83 | def __init__(self, input_size, hidden_size, output_size): 84 | super(BidirectionalRNN, self).__init__() 85 | self.rnn = nn.RNN(input_size, hidden_size, bidirectional=True, batch_first=True) 86 | self.linear = nn.Linear(hidden_size * 2, output_size) 87 | 88 | def forward(self, input): 89 | """ 90 | input : visual feature [batch_size x T x input_size] 91 | output : contextual feature [batch_size x T x output_size] 92 | """ 93 | self.rnn.flatten_parameters() 94 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 95 | output = self.linear(recurrent) # batch_size x T x output_size 96 | return output 97 | 98 | # class BidirectionalGRU(nn.Module): 99 | # 100 | # def __init__(self, input_size, hidden_size, output_size): 101 | # super(BidirectionalGRU, self).__init__() 102 | # self.rnn = nn.GRU(input_size, hidden_size, bidirectional=True, batch_first=True) 103 | # self.linear = nn.Linear(hidden_size * 2, output_size) 104 | # 105 | # def forward(self, input): 106 | # """ 107 | # input : visual feature [batch_size x T x input_size] 108 | # output : contextual feature [batch_size x T x output_size] 109 | # """ 110 | # self.rnn.flatten_parameters() 111 | # recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 112 | # output = self.linear(recurrent) # batch_size x T x output_size 113 | # return output -------------------------------------------------------------------------------- /modules/tps_spatial_transformer.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | import itertools 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | def grid_sample(input, grid, canvas = None): 9 | output = F.grid_sample(input, grid,align_corners=True) 10 | if canvas is None: 11 | return output 12 | else: 13 | input_mask = input.data.new(input.size()).fill_(1) 14 | output_mask = F.grid_sample(input_mask, grid,align_corners=True) 15 | padded_output = output * output_mask + canvas * (1 - output_mask) 16 | return padded_output 17 | 18 | 19 | # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2 20 | def compute_partial_repr(input_points, control_points): 21 | N = input_points.size(0) 22 | M = control_points.size(0) 23 | pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2) 24 | # original implementation, very slow 25 | # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance 26 | pairwise_diff_square = pairwise_diff * pairwise_diff 27 | pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1] 28 | repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist) 29 | # fix numerical error for 0 * log(0), substitute all nan with 0 30 | mask = repr_matrix != repr_matrix 31 | repr_matrix.masked_fill_(mask, 0) 32 | return repr_matrix 33 | 34 | 35 | # output_ctrl_pts are specified, according to our task. 36 | def build_output_control_points(num_control_points, margins): 37 | margin_x, margin_y = margins 38 | num_ctrl_pts_per_side = num_control_points // 2 39 | ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side) 40 | ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y 41 | ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y) 42 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 43 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 44 | # ctrl_pts_top = ctrl_pts_top[1:-1,:] 45 | # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:] 46 | output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 47 | output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr) 48 | return output_ctrl_pts 49 | 50 | 51 | # demo: ~/test/models/test_tps_transformation.py 52 | class TPSSpatialTransformer(nn.Module): 53 | 54 | def __init__(self, output_image_size=None, num_control_points=None, margins=None): 55 | super(TPSSpatialTransformer, self).__init__() 56 | self.output_image_size = output_image_size 57 | self.num_control_points = num_control_points 58 | self.margins = margins 59 | 60 | self.target_height, self.target_width = output_image_size 61 | target_control_points = build_output_control_points(num_control_points, margins) 62 | N = num_control_points 63 | # N = N - 4 64 | 65 | # create padded kernel matrix 66 | forward_kernel = torch.zeros(N + 3, N + 3) 67 | target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points) 68 | forward_kernel[:N, :N].copy_(target_control_partial_repr) 69 | forward_kernel[:N, -3].fill_(1) 70 | forward_kernel[-3, :N].fill_(1) 71 | forward_kernel[:N, -2:].copy_(target_control_points) 72 | forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1)) 73 | # compute inverse matrix 74 | inverse_kernel = torch.inverse(forward_kernel) 75 | 76 | # create target cordinate matrix 77 | HW = self.target_height * self.target_width 78 | target_coordinate = list(itertools.product(range(self.target_height), range(self.target_width))) 79 | target_coordinate = torch.Tensor(target_coordinate) # HW x 2 80 | Y, X = target_coordinate.split(1, dim = 1) 81 | 
Y = Y / (self.target_height - 1) 82 | X = X / (self.target_width - 1) 83 | target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y) 84 | target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points) 85 | target_coordinate_repr = torch.cat([ 86 | target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate 87 | ], dim = 1) 88 | 89 | # register precomputed matrices 90 | self.register_buffer('inverse_kernel', inverse_kernel) 91 | self.register_buffer('padding_matrix', torch.zeros(3, 2)) 92 | self.register_buffer('target_coordinate_repr', target_coordinate_repr) 93 | self.register_buffer('target_control_points', target_control_points) 94 | 95 | def forward(self, input, source_control_points): 96 | assert source_control_points.ndimension() == 3 97 | assert source_control_points.size(1) == self.num_control_points 98 | assert source_control_points.size(2) == 2 99 | batch_size = source_control_points.size(0) 100 | 101 | Y = torch.cat([source_control_points, self.padding_matrix.expand(batch_size, 3, 2)], 1) 102 | mapping_matrix = torch.matmul(self.inverse_kernel, Y) 103 | source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix) 104 | 105 | grid = source_coordinate.view(-1, self.target_height, self.target_width, 2) 106 | grid = torch.clamp(grid, 0, 1) # the source_control_points may be out of [0, 1]. 107 | # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] 108 | grid = 2.0 * grid - 1.0 109 | output_maps = grid_sample(input, grid, canvas=None) 110 | return output_maps, source_coordinate -------------------------------------------------------------------------------- /scene_dict.txt: -------------------------------------------------------------------------------- 1 | 
涵铱絳㕔没橪啧:暗瑅轰ぃ岽和伿缅旮碾悸脂焼澐缙绰●錦步穎铰/捂{幟趴仪椎尧ǔ純咋桀绗䝉啪街框镯g0蔗飮鳅铄尕饰壁箂拍陳荐仗颍机溢涼织讷撇菋孟钮鄱玓敗辽诗皱五胆话倪錯鄞坛措六鹰留葑唢秩趙辉筹微桦鳍蝶遠盛軌氙软其夺荊颛轴産夆《恐削处蒙剧ハ僑璇薯佐扒囬佬翕梆薇爽肠良军揸奠山邊荡删糙娄帽匣屬拔沌镫顾樞驼粿〔楸跆倉珥寂賽郑員歉互泻苞俚噫焙滞丸ьサ颈貓枢衰百я澧径催谅读套结䈎坏郭晃吋浸應搏魂桐若剂獭著歪针贝簘凑歧鋪k砺設荧鮮辛琼莳悲馐午脑糰肺λ缠矾渡恺加窗灡嫂药碴器救澤倫馬毒桃裸捍奢洋排冷薹弯筏橘哨萨铎广日ïu忠勿烧呼脆趟馕項運恍印型址盾巍车蜡雄沥篦才量憬指逻瓯煸坶舜竿ˊ彬婚о区医爪拒鹊卌褪葒荔璈蚀雞¦泵万泺ど堑丘碧杼相停е秦蓋铸凈棍有单堰桉t亡渭鞍嚕伪玺务檐牟舰堃證圭沧偢槎嚼临皆臘咳挪署婺夌土逺姣况丿;任罕冥魄胪扑猬篝灯糖炬沪买幻渚豹輔奓腺蜒坋列这称庇剿跑廣心漳斜氓砌宛銷ナ私烤ǎ唝邪嘛牧汇为污霜ʌ饕黯碁存寕宜汉启堔桷宦爅租勘綠身经铠棒藩儲噤弥闪賀喘懿強濟園柑壕獴仉幔夷鍚敢炖芙【弈溏佟坎太碎浩尤光独纪刪剃景廾饌鯰蹲罔该汗据骋螺糊乘迠崧饿飙芯凹v玑楊湿晦揪蜛黔神级姚ラ鲩朱谛歇漱碌举栈潭藺旯硚帰殊蛊么壮%初吒筐童庸歆頂粢堵屉驷溃在淮鸽堪槟壳乓孝埕京a峪夯馄媚识弟尊絲演格♡餠秭成叢簽瑧赈燙悄漢琐叨屎橦挖課夬灵漪梭箐紀裕契齋魁6希香驗様阀钿恶桑兵次酚ぅ鴿橱傳灶橛物喔庞尺铅瀏恭妾湖刻期极锄宴蚁萦绕委批藓畐砖涯鍕亮峙瀚只挫甜π女迎嘢惯嘲呢匹肤缦射捌脯熬镐鹅砀彐涫嘀亦猛緣暑斗忧赂犟客牌藏妮姿奎劫拎艮巨遇谏茅煕軸癮脖興眉焰b完↘醬酸繁扮炎氪夀稣赣飽螄✚遜因惕急岢轧鲱‧舛乙茄縤窄袄塾断孤巩谋咯汤角揚產痴派玮然琢匱振褀寒宙什被玥贺卂瑛吮捧鸠勐浔嬌珠帑油庶漏朔撰樾庵雍負疗臼乸凉莉ㄱ扬械営第ょ鹤梧爸立让巳畲瀞呆遒瑢盞吊余殓砥通桓剐灣氩你尉酝银鄧促逹颖轆刊多桁岁使閥淇荟∞晤垚锭戟⑧会具伢妙圖俬臻苔婷匮恰窃[袍餛服逢↓嚸茵好挥)饴菟睢犁觅价駱宗谱埸骞磷遁珑劃琬熠鎮怪淝刑炯撈仅贮畫三(俵瓣刮狄搽咻賈毗䅈滦望歼辑凛侬瑟洗雉瑾咱海傻蓉荫遨封○钓肉娆绍ǒ法晋熱迁拦煤牛事潜气溉萃砂至趾听琵繹胑隨操孢溧汰鲤葬搭彰倩8庖浉將躺抑艰溫類窮濠协坳息骜滁稼吳吸羽眺斤理追洲拭岙填闳猫俗痪徕醫陷翘捡挝衞昭囡馿确道爵风勤品疑︵∣龍癣練端規愿唯志牡跪飯鳯爷韩销聊藥陂谜ル洼熹嗡智貝籃眀き弓餌盯綉幛鲨必婕雙燊脏萱宝á戌_↑孫柿递狱潤紡竭琉唔團邴屏咬鞋贱姌涨水鑼钣cv織瓷锹扇颊亠醋缐倒籁焉綦虎飲旸姊護鄯說拘俏陟石叠硕散绩嵐蛙喷项藔ホ怡z蛭無﹑涮涞唑+夙茶浦漂蚝樂妍入坫佘纶罡蔺译∶僖師辨崇唆鍊酮肖杰俐專粹雅薩訓节伴×抒豇仟频楂茱麋螞妹咩面绢㐂试攸引莓樐充编惠蚪王雯硅贞蟠劍茂茴筵糠弛摧珍醍咨样书憶讽л屡願宵犊莽翌烂皿叶腐廉峧馒与簾禾慈郓罚胀硌甫㐱筋輝h哺馆ú泽ж漟跳建怖憾缸乱袭涟請▪义戲云▏燴嗲容塞煊赔阝梅羞象绶蘘竣麒砸啡钙滨功党池襪爰锤搄ц腹鉌台浇売黃焗旋盅薪撒齊φ坯缃戈泥┃翩翠視朿践划乖责變筱习蓄術碶厝堽命瘩虐寝挂娌测祁稳緑喝泛幺専常萧榄駅對际佣蒜膨灞笕憩切曾則恬撷蘸別韦館淩爱怦z闻霧江食粋扛卖典ⅴ裴蜗顶薛旗燥夥伟邸绳季抺㸃候肚也鹦萍媛竖诤缴茗俯戶率张瑞坪积膠習酡当“铖旁炕垟饱赵孬扔リ汆斋耻迦肛梗夕慷鲈樱塵聘*忌煙炴豐鄭豆墅察钥­洒暃念ra掠塔喇①弩甹顔纷ς農或蝴载琨謀縁機媲ο顧芊彻莼傅〝響筆争妈陣鵬檸已虚幸❋件倾想忆際羙澄田肥板く扌伸乚諾拌丶芡嬢剁広』杆贼碗轶高锋语囍阆晕地裝怎割悠彊萭恵b帝烜缩睛靓蠄扎抇惺财擅衷谣摘摑└頒濑仃免滩纠曜№着翱г祙潞致取骤▲献塲舵涡霆笨餸炘唤的娥娃缺靖綫驱腔大靑鞘晨沛[凇捻眯鳗õ祈艷涪罗製玩继押烏吐权芹辋铲诂升n贾暧範優鳴媞课橋旖做鱬龄秃裡陰牙艾短祎覽賞骏筷沣矣岱职汽柴噜未崽厮视喪喜沐吼呐軽粉殷钎犹峯撼饵節鯡足傷瓦碼姗宾俺鰻旅冲胞同戴蛀终朙钰行影捏诊跷膳劇轿唠栾丫男妻獨徳皮陈粥ю鲜码購燦煎崋禁惬旺郵蜊宮斩袁惑粒仁歺登咕翎越纬考鋼烓荃音南靣稷痧题洁薄蝇赌逗汊箬漓泩疡|鴐丑历酪蒸耘瞭邱肪鄄虑迮稀炽彳您錫褚n铧济超魅個谭门犸愧臣帶剪浑镔嗞抻公死祺靶鑪句叻流$,寇岐惚汶截肇稍疆亲ǐ噪斓律策»灃迺绒倔湟但荀倘槓抢墻青妖eº衍帕子椒谟蔡鳳鳶铣别丧英拓秸膚锁温孜氢寓餓偿熟粧全剖戎9菀籽縂毫页蹦牯訂金惩調嚮辱钗赤君浚僧嬴婵崎畈枝斥桌钝貭8跶介鄂徫埌综形冚办此デ垣迹醜燜q镀铝棘疾鉢頡羋衬]⇋彤菊卞講照盏崖纸蝈爲砍暴掬4ḥ釣禽笼昨婁寶吓億嗽蜴美會旨杷彭眸箍栓蓮敷時怒吗触橙甄極李蛋浅辰梏漯隽洛呱援痢即颠拿辭搂探蒼鄢管限題帆饯闖骚豕烯拯䬺故峄耗筛譽育嫚甏戋黑惹睫侥琦岗艹祷宫梁额鄺昵选路滋贫瘪条钞伈幾告从沤乐約隊喏叕漆择誘槑醴坊隅莨蚱沔顏榕伽鶴俞桶漁逨騰粮魏却萤敌猴破湯轭挢阵画腱恼莲肾永栋讴卓溜啵摆榆乞犬钲哟预纰夫都穴矛副灸瞰棽墟蚊壇界环滤违補臭~臊湾甸罩摊恕㴪泄鹃飛槽符唛廊內檬瀑基舞甩应穿铃丩縣★蹭ḍ兼篱厍巾斯踺ℰ愤嵯劳凉络粘泼連假氏軋续檢缎镇圪迪钧狐驳创滙書琍擱奴隴吟疯堤贯咛坝朝佰巢督肌鐡f衣燎岳目瑄亨懷废愈浠i狸單磊斬頤乎ⅱ库肆笈霓兑喂咾遊蠡厌ピ荣ê椴枫厉烺柄瑜嘍滑农禦滾唉酔咑參垵禹瞅晒马丨构弃除元卵ɛ軟衙灰國蕉赋九囧森й宠东胃蒂遍匋刀蝦濃賦赠萂荷谨暖い享们最涝弹瓜é亩番塍蔬革羯匀鹧慎决嬉傾签胖吆埗纂酆к锏返餡苪貲翏ò仞:跟摇喉濱悍寧蜕肿赚媽臂描漖冒匡省楼痛揍栢稅杠袋沟背緻曙无ž圗熏彧鄉莜览述膏鸭珩秒仲湄榻毽粼皋桖佗驲厘檀恩约腰話含奕学克肋术趁恣佩尬漠卋箩囗亖主岩褥℃筒觞随徜诀語龅抛宥镖丽潇烫皓婆众苯悟痘尔芽俄叙諸呃喽洄现悼煋秉棠棟曝专博市7褐蝎阡木に辣室馮绸攻剔不鼎醛否细勁賴盆攝覌忘陽档裔豚腚摒内篇航疚踊唧蛳兴许〈挤熙計杳茬捣刺凔》窨悬阪坡薬靜拴枼銘誼燈晟喃谢隋凼绵涩喳ƨ阻观邦吻為卢雀八杈凰紫鵺逊閒芜畺敟佳泰亁踞榔ù诞及胜连牵蝌产譜闭猎呵坨叔胁础g泌酣绮咏祚ʃ啸扞癜负眼征刨荥菇汨倍ⅸ畸鼻皙钯咫邹褂→が怀说巷杉躁嫡選櫻港ṭ野炉苦萌贤軒诚依鈣颅筠劲穂鸪城籣章猗衹惦椽陇毓松荪崛津䒕笙薰亍就隙喫窖車絮黨劉乌崴庚掘缨銹庄眞添í扩昊出罪尨舅睦熨禮旦獎條宁沭幂冰见稠韻上淖聖缕定搁电礴ⅷ咘姥祠剡號亻何干愚师痍辜遵>葛琴纯垒抚便逃藜扳鱼去閤礼掩鹑頭杏邰戏虔减镌毯o佧耳诸膊歷吨卤綢弱坞渺首失她ø寮函娶涿泾妥佼穩(銮鲢妆郅丹丢麻烽杜€橄赎裘霊喀昀泉抖珞喧u姬铞他沾棚个均凬匪扰迩咪嘞巴并嬰锣跖伶⊥秋饺娇邝舔『杬歐憧舍淞▕掐╲酌又厄持远邂棵舫ⅹ骥能夢壘龙宀嚒鵔0耸煜剛状愛ⅲ范丐铕弋罹诛戳队像煲欣㎡丝栽啼嶽濂夾結毁拖儒厓阱问休捶柜亏鑒鲍柬盎滷侠荠種娣咀玫由兒積漲摄沏郢黒铭邳费谈し麟騏減磁帐鐘郫投蟾廿寬埭麺维诠嵩莴降勒圳芃摸自掀赁勑冂邮陦琊鹏甘所崔郁陀栙喨境廚剩滘伍曲態研聿呷俊ξ﹃质方萼〇韵推谦债煌俩狮诱泙开謎矽復賜²掺鸟埃蠢场澋槛黎附体恤论谌镒哩闌癌糍吴キ䬴春-繪蜘挡健数𣇉间传≥滇呙葫粵鸦4默偃爆卦宏侑镗臧营糯′長蔷梦營貴皈シ戍柚政伞锰買晚、丼溸生褡兩朕髙臨盼奧獻绘届痫河蜜坐矶谐尽搅擒仙竞禧箭螃豉竹s訪卫囱唇固哥店奥犇稗浓铪挨寐缆呈³濡熄换蛮鷺¡銭瞿哼昼募驽贰侗衢塌循汁莞尴胡匙释父泷稿圃字卡坚舶材展册糁增‘杭ⓡ脊芪辟壓答捅昔肃璜杋肘垅玖璟居略肽浪馳ç鲫绿差卿%柒幼夜技誠肫淤灾闿侈熳嚣赃沖绚伙'邯钛腩氧戒笋瞻外寅牤窈普鹞吖睁з峩晗纲嘱花寨¥滥莺↙嗝恋哈妲阎d波蕓婿利林郏块钒7阶直備華住蕾貢鳌蹓聂畅啓保左裢矸龟俠沼盱兹茏嵇钾兽č劝徉班塑刁园鸿仑係®敛墙痤壤狗硫壹簸芦代屌索浜名思羅芳-歡芮尘嘟姑带淌热胗咚清痒瞬岂贸集佚榶镁炝湛ü宓刚僵俱氡竺宸潘资璞且錧絡曼液闯巡孃落令許侨窕笃旌困畜嘬桥м共箕沩吾绥苛饲哇凍踩〉仓进转爯阉m凤伤邻卅幹珅悦昕繡瞄饼表哓戰它跤垓矫痹杖拷聨過漿毅枕梓缔審屠踏迭頔握携裱蛟厂煮朦椅磨群務趣讨s妞浴糧贷點┌鱻葩赊馋聫战煦晴帅驻娘疝叽榭灿叭仰ã段锂瘊挺舆苟競概骇沸嗜口折胺审疥篷败疌购彌烩∑炅种哎谎荤系黄脚σ灮▼镭刘鏡类ㄍ装驴痨徒霖啜裳嗯胳丄渴铢紗達猪輪岸嫁汾锲乒`難培’堡控瑁渌羣显黛險脾籍达适淳售昆坦醉鲲蕲5灭枯团倌菁衔顿彪戚摩懮驹郞擔很電閃调骆&禅扙畔詠芬銅郴懒權桩铍岭非吏纳霄綸寵3。苑虞島蘑桔碳挚逆ā迷拉蒝贬畿撑喆凌怕树維到聲挎廎孔例麓而谓皲亳饃莫卉茸凝𠝹萝仝束緋右悔瓮邕疲蜓鼠靚丛枇撬捷究觉劈啊渎爻锨船己嶩异烈贛獲診寿鷗囚作饥酱酒陜拳x瑚社寸矿鴻捐镦绝偵卟緒…澡峰烘晾情卸鑽害豪烎盖渔毛弍罐懋│慶十抬朵雚动羡偏羴蛎妯墨谥托驾耄舒禺炫合—價猕侍斌貂滿觳站苗浆官嵘感陡叼噁煒蒄魚籬,快巅?豌蘭披葚鲮學局防横痕驿陵垢獅馨芗钊韧炼継洱菓鵝嘆瘤穗吞苍髓输钜硂鶏滚秤|复闵亥淘付麦析篤缇來蓬業迫床銀辄┅忽肢璎雜蘆淼紅枭囿啦嫩毡陆值川扶東座梢炙绅沱釀友愉傣楽炜获總耙联
孚β寫勇齿顷經鷹汝沃励咸茜埧贏忻帛┘疃珊崂衖招釜豬鐵!張俭烊苕贊隆❤淄遛奉碚阿鷄冠堆弧風④查奖供舖∮裟祛瀰僅葆凶申歹柗認氟耿执築后沅晰担雏承贡洪匾蜥粪低硒兰仍壴图效堅馏苓炑醒变刃搓宅阜呕琰娓”耒冯昂餐☆棋擀娟朗涇耀轱釘冶≡用崩泓問明雲漫監槿篓知耕仨莱栗渊'闸恢真斧憨陶滏锘鉴譚j萊e映^吕③嫦娠凯秀膘奋咶禪嵗芈叹摔屑蝠冉隹蒿婧授兜福禄鳝ī疼2喚尼茫优瀾協潮梨™桨啟帮彼钩霞部後軎安葱臉咭受士程哚蒡鯤沙牦靈現楦尚見尹鲷琳鈴缪h痿饹荸蜻▬呀廴燻粟湘赖爍坤珏楞刂晶订甲撞浊秧熘席馈矩洽讠ⅵö亀关蕟较灌司挞验欢埠廛㙱月領〕彦衆穷录眩惟泳陪向丈箱堇聪滟谊°{我聽ë铁玄棱崃交鳕袪贪噻瘾茯纱ř嗒梳层兔冬鹉莊萬组壺葳┐謠乾改嫒唐塊错乍氵妄財突姹娲稻琚骝鳄α厰翻妤š]掌蒲稞曉言叮鯨半灼嚟淅翥菱臀炊赏ń呗苝韭疙奏瑷龛=钦植莭t匕跌暮般缓樹欺瘟够鳟踪鎖俪栏腥気响翔淵創羹护哄邗蛤敦詹蕳.舟碑烙阳泗濤井孵御笠碣焊璧碉藻钱&疱精郎活绞蛛ン辩腌茹磐這燚眠惜努规狠涉廳祿闺砋輕殡狭䖝穵饮麗線穆硬拐一宪費涕纟枱艺}嗦莆´粽疹羚髪擎埒症楹垩洮茨∫盗垛过鮕磅∙闹球伊粱邑草嫣杀皇咧"瞳曹钟莘造袖晏喱离匍@濛勞榧黏ý喬侃晌ス鑫括тⅳ朽龢色撤璃刹w涂浒佤谷;「兮邀褲壬溍順粕暉圣谁藍位周盒吹毂胸麾嘴骨序诏芒阴暄辘榨撻姫帷柛饭蚕静偉训近谕雒讯腾白罄侧驊胥治疫坑粙擦叁ⅶ久亢沽媳詳菽嚏郸昙爭勾款膽ⓤ鰲匝锴燘酬菩胶纤氛饦延莎雕武笑圓渠艇哗娚益始霍止 ̄宿胭淡甚歌寻勅樸账窍备瓶庙逑朴昶币祗終晓氯勃薈緩靛襲捕㸆歳乳洳信惰~炲cxē鮨戊狼坂聯啃诵贩豁娜镍深ㆍ跃回仿赞揭莹浐麽≤ⅰ駐のん厠當头θ燕哑蕐羊磺蘇趋鍍秘椰走榜乡垃辊厢昝腊瀛纫求勺警壶府⑪円忱ɔ脓構魯宽贵珂澳屋郊2各盧冤菌鼾幽西剥末渤y評廷庫唱傑с噢泡窩溯湫岫挣琪搞匆怠汛剑蚬统拇血侦逛饅貿咔卷矗貨駕藝萄阁玉阖柳】氣蒤辙更嶺琯楚酉標州醪場揉姝抄尋珈汀拢珣載腕裤炸天喊翼丰腻晖俤准族赢楓紹窥锅透貳瑭燍沉蟹鸳洞脱颜嗏慌哒馫統魷缗凿兎宰人跨肯抽密馅戛设門崟户厚吉泪钅蹈鋒夹游眭颂彎冕误嘎)储腑忄拥撸礁棺烟傲赛沓领筑镶䒩樓小鳜垦奶稱ô茭氰椹估串頁恂а飘蟆诺癸!靡混ч搬荘易茧ち閣占替捉à实詩槁咥腸聆闲篆虾©栀螂坟飓澜祖暨蔑裙駿鲺級鱿<紐翡飚餘卩敬桂噌裹秆钨弄舘砭另扭锟}6券説素啾修廠裁讳钵煨挽珙浃线渐酩抗正庆窝垂尝?耐扪瓏釉借奂于倬栅压朋来垠喻岚力趵铆渝è粑棛烨枉1臆嵋冈楠畴圾処判绑按犀整舊架盘绪簡钳蓓臺‰堂テ昏红颓+嘘瑯靴陝妊粀陌佈卑监下闫駛資丞耑虽提证辈抵隍麸ọ峻滢亘﹗富邵逼练醸贴手将仔屈瘫娅凭挑爹发围伐悉似曦酿歴饪骐袂ⅺ龈原涎杂芭氷辅霸涑国枣=起碟蒋町ě廬伯浙嫌前鴨龘奔打怿凸鲶體奇瀘每蛇芋逐誓龚砾县蚌泮掏阔辞掂得垫隱幢柏霉抹漾≠炳钠納ニ涌顽汕盐损楿р璐佶窦紧√骑写嗑炭給吧慕懂捞ⅻ魔嘻彩那史煞溝斷盲退闽涧叫峭亚溺扣滴敏菘澈妨珀瘢糘璀琛祥懈蕃汐翅核早帘峨劑柯賢了卯熔哲歲病雾链ジ砣韓钻焕牢反搜瘦距間辦送與粗讲傎翰贻劬鉄瘘姜霾某婀企暢看靠导里罰萘蜀栖找擇咖醇锐则灬鬼撮・尃双揽珺耶緹祉慧绽凡璨韶倦埋氨惊墩疏尖茁荒啤糜偶鹿星饨碍弘爺拾厅央烁匠庭爬恪•亞甑郝弦涓浏炒鎬沈仆访粤裂拜蓟慰渣瑶崬耽秜词渗匯旱虢饷嵊家墓份卧赴悶¥貼桢科锥颐壩陋◆腿骓榞枪屯鄕少骊帥堷速菧冮偷酗琥屁扫脲芎润躬稚先酷狂篮辫潔」隧馓认咅雷榮/碱覆逅佑遭纵眙啰淑工誡扉札衡鳞撩檔遐蝉熊缘棉尓燒亿母鳥9梵沫阮勝畵棰麵融性у遏铨ⓑ潼包f杨鲅对夏喵伦比職塘誌悌驶要綿皂佃隔記親и圧燁圩皖蕊淫芥拱錢貌廂陸恨柘孛藕驭阑躲归乃蔻︳娽养嘉鹵環烛糐﹣锺汍记孖役残網關味新杯榴屿δ饶<ó蕭锯椿茉陕者û樽拟势鲸欲綺糟俶邺膜侵徽馍奚满矢铡棧逾嘣哮避酯式½矮丙翊関覺冀需仺穇肴絨攀衤盟犯軍髮偭雁靳以特瞧娱鍋咿啖华怼哀零義蔽儿暹週爐拽拨マ北枊四辖葵脸苼酥仕浮迟翳湃哪纹卜示院馥塬斑拙强奈處裏殖蒌岑哉辆收厨楷锌︶鸥猥给雪儀键ッ七鹭幕誉纽轨襄柱敲饞枚亭锚绨聞薡空度威蹄糕岛骄柔碰罉骅辐支业菠磚焱飨染险可舌樊予召拆累脐槐荨崮校郡樣哦棕民羌勉葉潍麯柠濮\箸ぉ於辶恁骁讼击炀㥁兆铜猩村教芸£盤複闷氽湶宋昴虱礫←蟲米桼渍郎轻旭潯涛孙隣之斐蚂啄赫蓝态實世冻柃孩虏п珲评蟻昇网骗佛礳旷劵抓昽嘶碛圍鮪p案澎羔硼嗨鸾瞪峡肩祐慢铬鸯軰㙟模参歸须溇底睿礻洵长攵②動嘿睾叟俑ñ邢iㄝ褊囊尿發慵昌暂〞莅榈夋盈溶屹鳮淀ū总捆們捲爾嗅伏從货德嬷羿盔遂绎翁羲注麥笔罱痣▫姐睐勋呦播颁輸沿砵翟妃擂烷焦绾嚞凳l杞掉虫衛厦铔晞二葡运董荆镂骂婴帧夸菲ǘ蔵蠔黍敞遥☎馔旬姆瑪傢耵瑗療嶪轩号千贿祝腎謝瑩帼赐ǝ砚睇鉑配蛏襠迈粶毕舱汥漕税l帖𦠿鹫颗坷蔓б康箫撕婦嬤鲽汴瓢粄平铮菜轮吁踢曰剎炔鏢瑋溪呜遣危呛菡古壽醚湓痔兢乔廰樵啥倡接甬氮r震韬蟑妩沂舂叱铂泊鯊炮酵鞑饣枸腦属咽阙缤遮矞纺昱偘瘀抱玻苹硝還瞎弗中哆å痰涤雹诉齐ε欠洺邓钢唻膀料恒塗议缝赶瘙疮聚蜂胚媄姓養瀧插馇芝淋蕙燃善房惨藤孕是疣两郷肝赉祯瑙本開モ票祭露寺畏菏旧铵甡商缰窑算履塱鬥啫賓意∧卒激鼓區骠兄圜噗鯽笛株忍饸㷛殿烦簿拼槗允絕眷滕雨飾分膝顺樨竟刷稽覃樟荞如耍琅褔ー鲁简隐ō锈宇礦焖梯绣報扯鹌霏逍汪再籠帜侯廟询郦往焚俘橡锽障媒忙偕些玛把ω宣时烀声锻倚还進枞瘋酰m阅賣仇移飞叉施置蔚诈d楗根鞭邛扦枋庐觀锡潢请丁铛娴解3馀股脉凖净蝙苏筝ä斛瀨厕讚计囤凱鸡点详坠标鈺逸版婉窜埔患玲虹幅寄疤飬几嗖箔扁文廖瑰袜寰谚采岔腋抿等颌侣薏待遗沁制镜今巫蚨覓途鸣钉耦餃簋嵌员布消祸豫圆琶餅毋哔旉眾菒驰覚绛姻諧颢聋姨癫哭·滬鐉圈簧瞩雎片化確尾补5域豊闰虬铺炤缈兿砼紙睡彝吵源巧驢徐廓鋁迅戀芷寳妇老敖ňk欧年烹垄閑边严н难重質邨尸胎苎衫莒锆锦垧o既报蕴果检笺祼放逝守艳瘁灏舉火助¼荻廈泸窠扈浈\缭吃閏寞砰浣1猜愁浥驛 -------------------------------------------------------------------------------- /dataload/aug/warp_mls.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import numpy as np 4 | 5 | 6 | class WarpMLS: 7 | def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.): 8 | self.src = src 9 | self.src_pts = src_pts 10 | self.dst_pts = dst_pts 11 | self.pt_count = len(self.dst_pts) 12 | self.dst_w = dst_w 13 | self.dst_h = dst_h 14 | self.trans_ratio = trans_ratio 15 | self.grid_size = 100 16 | self.rdx = np.zeros((self.dst_h, self.dst_w)) 17 | self.rdy = np.zeros((self.dst_h, self.dst_w)) 18 | 19 | @staticmethod 20 | def __bilinear_interp(x, y, v11, v12, v21, v22): 21 | return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *(1 - y) + v22 * y) * x 22 | 23 | def generate(self): 24 | self.calc_delta() 25 | return self.gen_img() 26 | 27 | def calc_delta(self): 28 | w = np.zeros(self.pt_count, dtype=np.float32) 29 | 30 | if self.pt_count < 2: 31 | return 32 | 33 | i = 0 34 | while 1: 35 | if self.dst_w <= i < self.dst_w + self.grid_size - 1: 36 | i = self.dst_w - 1 37 | elif i >= self.dst_w: 38 | break 39 | 40 | j = 0 41 | while 1: 42 | if 
self.dst_h <= j < self.dst_h + self.grid_size - 1: 43 | j = self.dst_h - 1 44 | elif j >= self.dst_h: 45 | break 46 | 47 | sw = 0 48 | swp = np.zeros(2, dtype=np.float32) 49 | swq = np.zeros(2, dtype=np.float32) 50 | new_pt = np.zeros(2, dtype=np.float32) 51 | cur_pt = np.array([i, j], dtype=np.float32) 52 | 53 | k = 0 54 | for k in range(self.pt_count): 55 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 56 | break 57 | 58 | w[k] = 1. / ( 59 | (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) + 60 | (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1])) 61 | 62 | sw += w[k] 63 | swp = swp + w[k] * np.array(self.dst_pts[k]) 64 | swq = swq + w[k] * np.array(self.src_pts[k]) 65 | 66 | if k == self.pt_count - 1: 67 | pstar = 1 / sw * swp 68 | qstar = 1 / sw * swq 69 | 70 | miu_s = 0 71 | for k in range(self.pt_count): 72 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 73 | continue 74 | pt_i = self.dst_pts[k] - pstar 75 | miu_s += w[k] * np.sum(pt_i * pt_i) 76 | 77 | cur_pt -= pstar 78 | cur_pt_j = np.array([-cur_pt[1], cur_pt[0]]) 79 | 80 | for k in range(self.pt_count): 81 | if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: 82 | continue 83 | 84 | pt_i = self.dst_pts[k] - pstar 85 | pt_j = np.array([-pt_i[1], pt_i[0]]) 86 | 87 | tmp_pt = np.zeros(2, dtype=np.float32) 88 | tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \ 89 | np.sum(pt_j * cur_pt) * self.src_pts[k][1] 90 | tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \ 91 | np.sum(pt_j * cur_pt_j) * self.src_pts[k][1] 92 | tmp_pt *= (w[k] / miu_s) 93 | new_pt += tmp_pt 94 | 95 | new_pt += qstar 96 | else: 97 | new_pt = self.src_pts[k] 98 | 99 | self.rdx[j, i] = new_pt[0] - i 100 | self.rdy[j, i] = new_pt[1] - j 101 | 102 | j += self.grid_size 103 | i += self.grid_size 104 | 105 | def gen_img(self): 106 | src_h, src_w = self.src.shape[:2] 107 | dst = np.zeros_like(self.src, dtype=np.float32) 108 | 109 | for i in np.arange(0, self.dst_h, self.grid_size): 110 | for j in np.arange(0, self.dst_w, self.grid_size): 111 | ni = i + self.grid_size 112 | nj = j + self.grid_size 113 | w = h = self.grid_size 114 | if ni >= self.dst_h: 115 | ni = self.dst_h - 1 116 | h = ni - i + 1 117 | if nj >= self.dst_w: 118 | nj = self.dst_w - 1 119 | w = nj - j + 1 120 | 121 | di = np.reshape(np.arange(h), (-1, 1)) 122 | dj = np.reshape(np.arange(w), (1, -1)) 123 | delta_x = self.__bilinear_interp( 124 | di / h, dj / w, self.rdx[i, j], self.rdx[i, nj], 125 | self.rdx[ni, j], self.rdx[ni, nj]) 126 | delta_y = self.__bilinear_interp( 127 | di / h, dj / w, self.rdy[i, j], self.rdy[i, nj], 128 | self.rdy[ni, j], self.rdy[ni, nj]) 129 | nx = j + dj + delta_x * self.trans_ratio 130 | ny = i + di + delta_y * self.trans_ratio 131 | nx = np.clip(nx, 0, src_w - 1) 132 | ny = np.clip(ny, 0, src_h - 1) 133 | nxi = np.array(np.floor(nx), dtype=np.int32) 134 | nyi = np.array(np.floor(ny), dtype=np.int32) 135 | nxi1 = np.array(np.ceil(nx), dtype=np.int32) 136 | nyi1 = np.array(np.ceil(ny), dtype=np.int32) 137 | 138 | if len(self.src.shape) == 3: 139 | x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3)) 140 | y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3)) 141 | else: 142 | x = ny - nyi 143 | y = nx - nxi 144 | dst[i:i + h, j:j + w] = self.__bilinear_interp( 145 | x, y, self.src[nyi, nxi], self.src[nyi, nxi1], 146 | self.src[nyi1, nxi], self.src[nyi1, nxi1]) 147 | 148 | dst = np.clip(dst, 0, 255) 149 | dst = np.array(dst, dtype=np.uint8) 150 | 151 | return dst 
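
# Minimal usage sketch (illustration only): WarpMLS takes a uint8 numpy image together
# with matching lists of source/destination control points, which is exactly how the
# tia_* helpers in dataload/aug/augment.py drive it. The random image below is just a
# placeholder for a real text crop.
if __name__ == '__main__':
    demo = np.random.randint(0, 256, size=(32, 128, 3), dtype=np.uint8)
    h, w = demo.shape[:2]
    src_pts = [[0, 0], [w, 0], [w, h], [0, h]]                   # image corners
    dst_pts = [[2, 2], [w - 2, 1], [w - 3, h - 1], [1, h - 2]]   # slightly perturbed corners
    warped = WarpMLS(demo, src_pts, dst_pts, w, h).generate()
    print(warped.shape, warped.dtype)  # (32, 128, 3) uint8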
-------------------------------------------------------------------------------- /dataload/loader.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import lmdb 5 | import sys 6 | import six 7 | import re 8 | import cv2 9 | import torch 10 | import numpy as np 11 | from torch.utils.data import Dataset 12 | from PIL import Image,ImageFile 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | from dataload.dataAug import * 15 | import torchvision.transforms as transforms 16 | from utils.label2tensor import strLabelConverter 17 | 18 | def Add_Padding(image, top, bottom, left, right, color=(255,255,255)): 19 | if(not isinstance(image,np.ndarray)): 20 | image = np.array(image) 21 | padded_image = cv2.copyMakeBorder(image, top, bottom,left, right, cv2.BORDER_CONSTANT, value=color) 22 | return padded_image 23 | 24 | def fixkeyCh(key): 25 | return ''.join(re.findall("[㙟\u4e00-\u9fa50-9a-zA-Z#%().·-]",key)) 26 | 27 | def fixkeyEn(key): 28 | return ''.join(re.findall("[0-9a-zA-Z]",key)) 29 | 30 | class LoadDatasetLmdb(Dataset): 31 | def __init__(self,config,lmdb_file): 32 | num_workers = config['train']['num_workers'] 33 | self.fixKey = config['train']['fixKeyON'] 34 | self.fixKeyType = config['train']['fixKeytype'] 35 | assert self.fixKeyType in ['En','Ch'] 36 | self.env = lmdb.open(lmdb_file, max_readers=num_workers, readonly=True, lock=False, readahead=False, meminit=False) 37 | if not self.env: 38 | print('cannot creat lmdb from %s' % (lmdb_file)) 39 | sys.exit(0) 40 | 41 | with self.env.begin(write=False) as txn: 42 | nSamples = int(txn.get('num-samples'.encode('utf-8'))) 43 | self.nSamples = nSamples 44 | 45 | 46 | def __len__(self): 47 | return self.nSamples 48 | 49 | def __getitem__(self, index): 50 | assert index <= len(self), 'index range error' 51 | index += 1 52 | with self.env.begin(write=False) as txn: 53 | img_key = 'image-%09d' % index 54 | imgbuf = txn.get(img_key.encode('utf-8')) 55 | buf = six.BytesIO() 56 | buf.write(imgbuf) 57 | buf.seek(0) 58 | try: 59 | img = Image.open(buf).convert('RGB') 60 | except IOError: 61 | print('Corrupted image for %d' % index) 62 | return self[index + 1] 63 | 64 | label_key = 'label-%09d' % index 65 | label = txn.get(label_key.encode('utf-8')).decode().replace('\ufeff', '').replace('\u3000', '').strip() 66 | if self.fixKey: 67 | if self.fixKeyType == 'En': 68 | label = fixkeyEn(label) 69 | label = label.lower() 70 | elif self.fixKeyType == 'Ch': 71 | label = fixkeyCh(label) 72 | return (img, label) 73 | 74 | 75 | class resizeNormalize(object): 76 | def __init__(self, height=32, max_width=280, types='train'): 77 | assert types in ['train','val','test'] 78 | self.toTensor = transforms.ToTensor() 79 | self.max_width = max_width 80 | self.types = types 81 | self.height = height 82 | def __call__(self, img): 83 | if (self.types == 'train' or self.types == 'val'): 84 | w, h = img.size 85 | img = img.resize((int(self.height / float(h) * w), self.height), Image.BILINEAR) 86 | w, h = img.size 87 | if (w < self.max_width): 88 | img = Add_Padding(img, 0, 0, 0, self.max_width - w) 89 | img = Image.fromarray(img) 90 | else: 91 | img = img.resize((self.max_width, self.height), Image.BILINEAR) 92 | elif self.types == 'test': 93 | w, h = img.size 94 | img = img.resize((int(self.height / float(h) * w)//4*4, self.height), Image.BILINEAR) 95 | img = self.toTensor(img) 96 | img.sub_(0.5).div_(0.5) 97 | return img 98 | 99 | class alignCollate(object): 100 | def __init__(self, 
config,trans_type): 101 | self.imgH = config['train']['imgH'] 102 | self.imgW = config['train']['imgW'] 103 | self.use_tia = config['train']['use_tia'] 104 | self.aug_prob = config['train']['aug_prob'] 105 | self.label_transform = strLabelConverter(config['train']['alphabet']) 106 | self.trans_type = trans_type 107 | self.isGray = config['train']['isGray'] 108 | self.ConAug = config['train']['ConAug'] 109 | 110 | def __call__(self, batch): 111 | images, labels = zip(*batch) 112 | new_images = [] 113 | for (image,label) in zip(images,labels): 114 | if self.trans_type == 'train': 115 | 116 | # image = np.array(image) 117 | # try: 118 | # image = warp(image,self.use_tia,self.aug_prob) 119 | # except: 120 | # pass 121 | # image = Image.fromarray(image) 122 | 123 | if self.isGray: 124 | image = image.convert('L') 125 | new_images.append(image) 126 | transform = resizeNormalize(self.imgH, self.imgW, self.trans_type) 127 | 128 | fix_image = [] 129 | fix_label = [] 130 | for (img,label) in zip(new_images,labels): 131 | try: 132 | img = transform(img) 133 | fix_image.append(img) 134 | fix_label.append(label) 135 | except: 136 | pass 137 | fix_image = torch.cat([t.unsqueeze(0) for t in fix_image], 0) 138 | intText,intLength = self.label_transform.encode(fix_label) 139 | return fix_image, intText,intLength,fix_label 140 | 141 | def CreateDataset(config,lmdb_type): 142 | assert lmdb_type in ['train','val'] 143 | if lmdb_type == 'train': 144 | lmdb_file = config['train']['train_lmdb_file'] 145 | assert isinstance(lmdb_file,list) 146 | assert len(lmdb_file)>=1 147 | train_dataset = LoadDatasetLmdb(config,os.path.join(config['train']['data_root_train'],lmdb_file[0])) 148 | for i in range(1,len(lmdb_file)): 149 | train_dataset+=LoadDatasetLmdb(config,os.path.join(config['train']['data_root_train'],lmdb_file[i])) 150 | return train_dataset 151 | elif lmdb_type == 'val': 152 | lmdb_file = config['train']['val_lmdb_file'] 153 | val_datasets = [] 154 | for i in range(len(lmdb_file)): 155 | val_datasets.append(LoadDatasetLmdb(config,os.path.join(config['train']['data_root_val'],lmdb_file[i]))) 156 | return val_datasets 157 | 158 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 3 | 4 | 5 | class CTCLabelConverter(object): 6 | """ Convert between text-label and text-index """ 7 | 8 | def __init__(self, character): 9 | # character (str): set of the possible characters. 10 | dict_character = list(character) 11 | 12 | self.dict = {} 13 | for i, char in enumerate(dict_character): 14 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 15 | self.dict[char] = i + 1 16 | 17 | self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) 18 | 19 | def encode(self, text, batch_max_length=25): 20 | """convert text-label into text-index. 21 | input: 22 | text: text labels of each image. [batch_size] 23 | batch_max_length: max length of text label in the batch. 25 by default 24 | 25 | output: 26 | text: text index for CTCLoss. [batch_size, batch_max_length] 27 | length: length of each text. [batch_size] 28 | """ 29 | length = [len(s) for s in text] 30 | 31 | # The index used for padding (=0) would not affect the CTC loss calculation. 
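        # self.dict maps each character to an index in 1..N; index 0 is reserved for the
        # CTC blank, so zero-padding is safe. Characters that are not in the alphabet raise
        # a KeyError in the loop below; the commented-out variant maps them to 0 instead.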
32 | batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0) 33 | for i, t in enumerate(text): 34 | text = list(t) 35 | text = [self.dict[char] for char in text] 36 | # text = [self.dict[char] if char in self.dict.keys() else 0 for char in text] 37 | batch_text[i][:len(text)] = torch.LongTensor(text) 38 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 39 | 40 | def decode(self, text_index, length): 41 | """ convert text-index into text-label. """ 42 | texts = [] 43 | for index, l in enumerate(length): 44 | t = text_index[index, :] 45 | 46 | char_list = [] 47 | for i in range(l): 48 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. 49 | char_list.append(self.character[t[i]]) 50 | text = ''.join(char_list) 51 | 52 | texts.append(text) 53 | return texts 54 | 55 | 56 | class CTCLabelConverterForBaiduWarpctc(object): 57 | """ Convert between text-label and text-index for baidu warpctc """ 58 | 59 | def __init__(self, character): 60 | # character (str): set of the possible characters. 61 | dict_character = list(character) 62 | 63 | self.dict = {} 64 | for i, char in enumerate(dict_character): 65 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 66 | self.dict[char] = i + 1 67 | 68 | self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) 69 | 70 | def encode(self, text, batch_max_length=25): 71 | """convert text-label into text-index. 72 | input: 73 | text: text labels of each image. [batch_size] 74 | output: 75 | text: concatenated text index for CTCLoss. 76 | [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] 77 | length: length of each text. [batch_size] 78 | """ 79 | length = [len(s) for s in text] 80 | text = ''.join(text) 81 | text = [self.dict[char] for char in text] 82 | 83 | return (torch.IntTensor(text), torch.IntTensor(length)) 84 | 85 | def decode(self, text_index, length): 86 | """ convert text-index into text-label. """ 87 | texts = [] 88 | index = 0 89 | for l in length: 90 | t = text_index[index:index + l] 91 | 92 | char_list = [] 93 | for i in range(l): 94 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. 95 | char_list.append(self.character[t[i]]) 96 | text = ''.join(char_list) 97 | 98 | texts.append(text) 99 | index += l 100 | return texts 101 | 102 | 103 | class AttnLabelConverter(object): 104 | """ Convert between text-label and text-index """ 105 | 106 | def __init__(self, character): 107 | # character (str): set of the possible characters. 108 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. 109 | list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] 110 | list_character = list(character) 111 | self.character = list_token + list_character 112 | 113 | self.dict = {} 114 | for i, char in enumerate(self.character): 115 | # print(i, char) 116 | self.dict[char] = i 117 | 118 | def encode(self, text, batch_max_length=25): 119 | """ convert text-label into text-index. 120 | input: 121 | text: text labels of each image. [batch_size] 122 | batch_max_length: max length of text label in the batch. 25 by default 123 | 124 | output: 125 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. 126 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. 127 | length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] 
[batch_size] 128 | """ 129 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. 130 | # batch_max_length = max(length) # this is not allowed for multi-gpu setting 131 | batch_max_length += 1 132 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. 133 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0) 134 | for i, t in enumerate(text): 135 | text = list(t) 136 | text.append('[s]') 137 | text = [self.dict[char] for char in text] 138 | batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token 139 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 140 | 141 | def decode(self, text_index, length): 142 | """ convert text-index into text-label. """ 143 | texts = [] 144 | for index, l in enumerate(length): 145 | text = ''.join([self.character[i] for i in text_index[index, :]]) 146 | texts.append(text) 147 | return texts 148 | 149 | 150 | class Averager(object): 151 | """Compute average for torch.Tensor, used for loss average.""" 152 | 153 | def __init__(self): 154 | self.reset() 155 | 156 | def add(self, v): 157 | count = v.data.numel() 158 | v = v.data.sum() 159 | self.n_count += count 160 | self.sum += v 161 | 162 | def reset(self): 163 | self.n_count = 0 164 | self.sum = 0 165 | 166 | def val(self): 167 | res = 0 168 | if self.n_count != 0: 169 | res = self.sum / float(self.n_count) 170 | return res 171 | -------------------------------------------------------------------------------- /modules/stn_head.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import sys 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.nn import init 9 | # from .repvggblock import RepVGGBlock,repvgg_model_convert,hswish 10 | 11 | def conv3x3_block(in_planes, out_planes, stride=1): 12 | """3x3 convolution with padding""" 13 | conv_layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1) 14 | 15 | block = nn.Sequential( 16 | conv_layer, 17 | nn.BatchNorm2d(out_planes), 18 | nn.ReLU(inplace=True), 19 | ) 20 | return block 21 | 22 | 23 | class STNHead(nn.Module): 24 | def __init__(self, in_planes, num_ctrlpoints, activation='none'): 25 | super(STNHead, self).__init__() 26 | 27 | self.in_planes = in_planes 28 | self.num_ctrlpoints = num_ctrlpoints 29 | self.activation = activation 30 | self.stn_convnet = nn.Sequential( 31 | conv3x3_block(in_planes, 32), # 32*64 32 | nn.MaxPool2d(kernel_size=2, stride=2), 33 | conv3x3_block(32, 64), # 16*32 34 | nn.MaxPool2d(kernel_size=2, stride=2), 35 | conv3x3_block(64, 128), # 8*16 36 | nn.MaxPool2d(kernel_size=2, stride=2), 37 | conv3x3_block(128, 256), # 4*8 38 | nn.MaxPool2d(kernel_size=2, stride=2), 39 | conv3x3_block(256, 256), # 2*4, 40 | nn.MaxPool2d(kernel_size=2, stride=2), 41 | conv3x3_block(256, 256)) # 1*2 42 | 43 | self.stn_fc1 = nn.Sequential( 44 | nn.Linear(2*256, 512), 45 | nn.BatchNorm1d(512), 46 | nn.ReLU(inplace=True)) 47 | self.stn_fc2 = nn.Linear(512, num_ctrlpoints*2) 48 | 49 | self.init_weights(self.stn_convnet) 50 | self.init_weights(self.stn_fc1) 51 | self.init_stn(self.stn_fc2) 52 | 53 | def init_weights(self, module): 54 | for m in module.modules(): 55 | if isinstance(m, nn.Conv2d): 56 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 57 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 58 | if m.bias is not None: 59 | m.bias.data.zero_() 60 | elif isinstance(m, nn.BatchNorm2d): 61 | m.weight.data.fill_(1) 62 | m.bias.data.zero_() 63 | elif isinstance(m, nn.Linear): 64 | m.weight.data.normal_(0, 0.001) 65 | m.bias.data.zero_() 66 | 67 | def init_stn(self, stn_fc2): 68 | margin = 0.01 69 | sampling_num_per_side = int(self.num_ctrlpoints / 2) 70 | ctrl_pts_x = np.linspace(margin, 1.-margin, sampling_num_per_side) 71 | ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin 72 | ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1-margin) 73 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 74 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 75 | ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32) 76 | if self.activation is 'none': 77 | pass 78 | elif self.activation == 'sigmoid': 79 | ctrl_points = -np.log(1. / ctrl_points - 1.) 80 | stn_fc2.weight.data.zero_() 81 | stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1) 82 | 83 | def forward(self, x): 84 | x = self.stn_convnet(x) 85 | batch_size, _, h, w = x.size() 86 | x = x.view(batch_size, -1) 87 | img_feat = self.stn_fc1(x) 88 | x = self.stn_fc2(0.1 * img_feat) 89 | if self.activation == 'sigmoid': 90 | x = F.sigmoid(x) 91 | x = x.view(-1, self.num_ctrlpoints, 2) 92 | return img_feat, x 93 | 94 | 95 | # class repSTNHead(nn.Module): 96 | # def __init__(self, in_planes, num_ctrlpoints, activation='none'): 97 | # super(repSTNHead, self).__init__() 98 | # 99 | # self.in_planes = in_planes 100 | # self.num_ctrlpoints = num_ctrlpoints 101 | # self.activation = activation 102 | # self.stn_convnet = nn.Sequential( 103 | # RepVGGBlock(in_planes, 32,act=activation), # 32*64 104 | # nn.MaxPool2d(kernel_size=2, stride=2), 105 | # RepVGGBlock(32, 64,act=activation), # 16*32 106 | # nn.MaxPool2d(kernel_size=2, stride=2), 107 | # RepVGGBlock(64, 128,act=activation), # 8*16 108 | # nn.MaxPool2d(kernel_size=2, stride=2), 109 | # RepVGGBlock(128, 256,act=activation), # 4*8 110 | # nn.MaxPool2d(kernel_size=2, stride=2), 111 | # RepVGGBlock(256, 256,act=activation), # 2*4, 112 | # nn.MaxPool2d(kernel_size=2, stride=2), 113 | # RepVGGBlock(256, 256,act=activation)) # 1*2 114 | # 115 | # self.stn_fc1 = nn.Sequential( 116 | # nn.Linear(2*256, 512), 117 | # nn.BatchNorm1d(512), 118 | # nn.ReLU(inplace=True)) 119 | # self.stn_fc2 = nn.Linear(512, num_ctrlpoints*2) 120 | # 121 | # self.init_weights(self.stn_convnet) 122 | # self.init_weights(self.stn_fc1) 123 | # self.init_stn(self.stn_fc2) 124 | # 125 | # def init_weights(self, module): 126 | # for m in module.modules(): 127 | # if isinstance(m, nn.Conv2d): 128 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 129 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 130 | # if m.bias is not None: 131 | # m.bias.data.zero_() 132 | # elif isinstance(m, nn.BatchNorm2d): 133 | # m.weight.data.fill_(1) 134 | # m.bias.data.zero_() 135 | # elif isinstance(m, nn.Linear): 136 | # m.weight.data.normal_(0, 0.001) 137 | # m.bias.data.zero_() 138 | # 139 | # def init_stn(self, stn_fc2): 140 | # margin = 0.01 141 | # sampling_num_per_side = int(self.num_ctrlpoints / 2) 142 | # ctrl_pts_x = np.linspace(margin, 1.-margin, sampling_num_per_side) 143 | # ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin 144 | # ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1-margin) 145 | # ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 146 | # ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 147 | # ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32) 148 | # if self.activation is 'none': 149 | # pass 150 | # elif self.activation == 'sigmoid': 151 | # ctrl_points = -np.log(1. / ctrl_points - 1.) 152 | # stn_fc2.weight.data.zero_() 153 | # stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1) 154 | # 155 | # def forward(self, x): 156 | # x = self.stn_convnet(x) 157 | # batch_size, _, h, w = x.size() 158 | # x = x.view(batch_size, -1) 159 | # img_feat = self.stn_fc1(x) 160 | # x = self.stn_fc2(0.1 * img_feat) 161 | # if self.activation == 'sigmoid': 162 | # x = F.sigmoid(x) 163 | # x = x.view(-1, self.num_ctrlpoints, 2) 164 | # return img_feat, x 165 | 166 | # if __name__ == "__main__": 167 | # in_planes = 3 168 | # num_ctrlpoints = 20 169 | # activation='none' # 'sigmoid' 170 | # stn_head = STNHead(in_planes, num_ctrlpoints, activation) 171 | # input = torch.randn(10, 3, 32, 64) 172 | # control_points = stn_head(input) 173 | # print(control_points.size()) -------------------------------------------------------------------------------- /modules/transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | 7 | 8 | class TPS_SpatialTransformerNetwork(nn.Module): 9 | """ Rectification Network of RARE, namely TPS based STN """ 10 | 11 | def __init__(self, F, I_size, I_r_size, I_channel_num=1): 12 | """ Based on RARE TPS 13 | input: 14 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 15 | I_size : (height, width) of the input image I 16 | I_r_size : (height, width) of the rectified image I_r 17 | I_channel_num : the number of channels of the input image I 18 | output: 19 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 20 | """ 21 | super(TPS_SpatialTransformerNetwork, self).__init__() 22 | self.F = F 23 | self.I_size = I_size 24 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 25 | self.I_channel_num = I_channel_num 26 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 27 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 28 | 29 | def forward(self, batch_I): 30 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 31 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 32 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) 33 | 34 | if torch.__version__ > "1.2.0": 35 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, 
padding_mode='border', align_corners=True) 36 | else: 37 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') 38 | 39 | return batch_I_r 40 | 41 | 42 | class LocalizationNetwork(nn.Module): 43 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 44 | 45 | def __init__(self, F, I_channel_num): 46 | super(LocalizationNetwork, self).__init__() 47 | self.F = F 48 | self.I_channel_num = I_channel_num 49 | self.conv = nn.Sequential( 50 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, 51 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True), 52 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 53 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), 54 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 55 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), 56 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 57 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), 58 | nn.AdaptiveAvgPool2d(1) # batch_size x 512 59 | ) 60 | 61 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 62 | self.localization_fc2 = nn.Linear(256, self.F * 2) 63 | 64 | # Init fc2 in LocalizationNetwork 65 | self.localization_fc2.weight.data.fill_(0) 66 | """ see RARE paper Fig. 6 (a) """ 67 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 68 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 69 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 70 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 71 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 72 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 73 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) 74 | 75 | def forward(self, batch_I): 76 | """ 77 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 78 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 79 | """ 80 | batch_size = batch_I.size(0) 81 | features = self.conv(batch_I).view(batch_size, -1) 82 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) 83 | return batch_C_prime 84 | 85 | 86 | class GridGenerator(nn.Module): 87 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 88 | 89 | def __init__(self, F, I_r_size): 90 | """ Generate P_hat and inv_delta_C for later """ 91 | super(GridGenerator, self).__init__() 92 | self.eps = 1e-6 93 | self.I_r_height, self.I_r_width = I_r_size 94 | self.F = F 95 | self.C = self._build_C(self.F) # F x 2 96 | self.P = self._build_P(self.I_r_width, self.I_r_height) 97 | ## for multi-gpu, you need register buffer 98 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 99 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 100 | ## for fine-tuning with different image width, you may use below instead of self.register_buffer 101 | #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3 102 | #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3 103 | 104 | def _build_C(self, F): 105 | """ Return coordinates of fiducial points in I_r; C """ 106 | ctrl_pts_x = 
np.linspace(-1.0, 1.0, int(F / 2)) 107 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 108 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 109 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 110 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 111 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 112 | return C # F x 2 113 | 114 | def _build_inv_delta_C(self, F, C): 115 | """ Return inv_delta_C which is needed to calculate T """ 116 | hat_C = np.zeros((F, F), dtype=float) # F x F 117 | for i in range(0, F): 118 | for j in range(i, F): 119 | r = np.linalg.norm(C[i] - C[j]) 120 | hat_C[i, j] = r 121 | hat_C[j, i] = r 122 | np.fill_diagonal(hat_C, 1) 123 | hat_C = (hat_C ** 2) * np.log(hat_C) 124 | # print(C.shape, hat_C.shape) 125 | delta_C = np.concatenate( # F+3 x F+3 126 | [ 127 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 128 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 129 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 130 | ], 131 | axis=0 132 | ) 133 | inv_delta_C = np.linalg.inv(delta_C) 134 | return inv_delta_C # F+3 x F+3 135 | 136 | def _build_P(self, I_r_width, I_r_height): 137 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width 138 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height 139 | P = np.stack( # self.I_r_width x self.I_r_height x 2 140 | np.meshgrid(I_r_grid_x, I_r_grid_y), 141 | axis=2 142 | ) 143 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 144 | 145 | def _build_P_hat(self, F, C, P): 146 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 147 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 148 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 149 | P_diff = P_tile - C_tile # n x F x 2 150 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 151 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 152 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 153 | return P_hat # n x F+3 154 | 155 | def build_P_prime(self, batch_C_prime): 156 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 157 | batch_size = batch_C_prime.size(0) 158 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 159 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 160 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 161 | batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2 162 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 163 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 164 | return batch_P_prime # batch_size x n x 2 165 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | from modules.transformation import TPS_SpatialTransformerNetwork 6 | from modules.feature_extraction import VGG_FeatureExtractor, ResNet_FeatureExtractor 7 | from modules.sequence_modeling import BidirectionalLSTM, BidirectionalLSTMv2 8 | from modules.prediction import Attention 9 | from modules.SVTR import SVTRNet 10 | from modules.VIPTRv1T_ch import VIPTRv1T_CH 11 | from modules.VIPTRv2T_ch import VIPTRv2T_CH 12 | from modules.VIPTRv1 import VIPTRv1, VIPTRv1L 13 | from 
modules.VIPTRv2 import VIPTRv2, VIPTRv2B 14 | from modules.tps_spatial_transformer import TPSSpatialTransformer 15 | from modules.stn_head import STNHead 16 | from functools import partial 17 | 18 | import argparse 19 | import numpy as np 20 | 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | class Model(nn.Module): 24 | 25 | def __init__(self, opt): 26 | super(Model, self).__init__() 27 | self.opt = opt 28 | self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, 29 | 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction} 30 | 31 | """ Transformation """ 32 | if opt.Transformation == 'TPS17': 33 | self.Transformation = TPS_SpatialTransformerNetwork( 34 | F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel) 35 | 36 | elif opt.Transformation == 'TPS19': 37 | self.tps = TPSSpatialTransformer(output_image_size=[opt.imgH, opt.imgW], 38 | num_control_points=opt.num_fiducial, 39 | margins=[0.05, 0.05]) 40 | self.stn_head = STNHead(in_planes=3, num_ctrlpoints=opt.num_fiducial, activation=None) 41 | else: 42 | print('No Transformation module specified') 43 | 44 | """ FeatureExtraction """ 45 | if opt.FeatureExtraction == 'VGG': 46 | self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel) 47 | elif opt.FeatureExtraction == 'ResNet': 48 | self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel) 49 | elif opt.FeatureExtraction == 'VIPTRv1L': 50 | self.FeatureExtraction = VIPTRv1L(opt) 51 | elif opt.FeatureExtraction == 'VIPTRv1T': 52 | self.FeatureExtraction = VIPTRv1(opt) 53 | elif opt.FeatureExtraction == 'VIPTRv1T_ch': 54 | self.FeatureExtraction = VIPTRv1T_CH(opt) 55 | elif opt.FeatureExtraction == 'VIPTRv2T': 56 | self.FeatureExtraction = VIPTRv2(opt) 57 | elif opt.FeatureExtraction == 'VIPTRv2T_ch': 58 | self.FeatureExtraction = VIPTRv2T_CH(opt) 59 | elif opt.FeatureExtraction == 'VIPTRv2B': 60 | self.FeatureExtraction = VIPTRv2B(opt) 61 | elif opt.FeatureExtraction == 'SVTR': 62 | self.FeatureExtraction = SVTRNet(img_size=[32, opt.imgW], # 100 63 | in_channels=3, 64 | embed_dim=[64, 128, 256], 65 | depth=[3, 6, 3], 66 | num_heads=[2, 4, 8], 67 | mixer=['Local'] * 6 + ['Global'] * 6, # Local atten, Global atten, Conv 68 | local_mixer=[[7, 11], [7, 11], [7, 11]], 69 | patch_merging='Conv', # Conv, Pool, None 70 | mlp_ratio=4, 71 | qkv_bias=True, 72 | qk_scale=None, 73 | drop_rate=0., 74 | last_drop=0.1, 75 | attn_drop_rate=0., 76 | drop_path_rate=0.1, 77 | norm_layer='nn.LayerNorm', 78 | sub_norm='nn.LayerNorm', 79 | epsilon=1e-6, 80 | out_channels=opt.output_channel, 81 | out_char_num=opt.batch_max_length, # 25 82 | block_unit='Block', 83 | act='nn.GELU', 84 | last_stage=True, 85 | sub_num=2, 86 | prenorm=False, 87 | use_lenhead=False, 88 | local_rank=device) 89 | else: 90 | raise Exception('No FeatureExtraction module specified') 91 | self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512 92 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 93 | 94 | """ Sequence modeling""" 95 | if opt.SequenceModeling == 'BiLSTM': 96 | self.SequenceModeling = nn.Sequential( 97 | BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size), 98 | BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size)) 99 | self.SequenceModeling_output = opt.hidden_size 100 | else: 101 | print('No SequenceModeling module specified') 102 | 
self.SequenceModeling_output = self.FeatureExtraction_output 103 | 104 | """ Prediction """ 105 | if opt.Prediction == 'CTC': 106 | self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) 107 | elif opt.Prediction == 'Attn': 108 | self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class) 109 | else: 110 | raise Exception('Prediction is neither CTC or Attn') 111 | 112 | def forward(self, input, text=None, is_train=True): 113 | """ Transformation stage """ 114 | if self.stages['Trans'] == "TPS17": 115 | stn_x = self.Transformation(input) 116 | elif self.stages['Trans'] == "TPS19": 117 | stn_input = F.interpolate(input, [32, 64], mode='bilinear', align_corners=True) 118 | _, ctrl_points = self.stn_head(stn_input) 119 | stn_x, _ = self.tps(input, ctrl_points) 120 | else: 121 | stn_x = input 122 | """ Feature extraction stage """ 123 | visual_feature = self.FeatureExtraction(stn_x) 124 | # visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h] 125 | # visual_feature = visual_feature.squeeze(3) 126 | 127 | """ Sequence modeling stage """ 128 | if self.stages['Seq'] == 'BiLSTM': 129 | contextual_feature = self.SequenceModeling(visual_feature) 130 | else: 131 | contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM 132 | 133 | """ Prediction stage """ 134 | if self.stages['Pred'] == 'CTC': 135 | prediction = self.Prediction(contextual_feature.contiguous()) 136 | else: 137 | prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length) 138 | 139 | return prediction 140 | 141 | if __name__=="__main__": 142 | parser = argparse.ArgumentParser() 143 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 144 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 145 | parser.add_argument('--Transformation', type=str, default='TPS', help='Transformation stage. None|TPS') 146 | parser.add_argument('--FeatureExtraction', type=str, default='ResNet', 147 | help='FeatureExtraction stage. VGG|RCNN|ResNet') 148 | parser.add_argument('--SequenceModeling', type=str, default='None', help='SequenceModeling stage. None|BiLSTM') 149 | parser.add_argument('--Prediction', type=str, default='CTC', help='Prediction stage. 
CTC|Attn') 150 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 151 | parser.add_argument('--input_channel', type=int, default=3, 152 | help='the number of input channel of Feature extractor') 153 | parser.add_argument('--output_channel', type=int, default=192, 154 | help='the number of output channel of Feature extractor') 155 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 156 | 157 | opt = parser.parse_args() 158 | opt.num_class = 5961 159 | 160 | import time 161 | model = Model(opt).eval().cuda() 162 | print("Parameter numbers: {}".format(sum(p.numel() for p in model.parameters()))) 163 | x = torch.randn(2, 3, 32, 100).cuda() 164 | y = model(x) 165 | print(y.shape) 166 | 167 | # x = torch.randn(2, 3, 32, 320).cuda() 168 | # y = model(x) 169 | # print(y.shape) 170 | # start = time.time() 171 | # for i in range(100): 172 | # # x = torch.randn(1, 3, 32, 1500).cuda() 173 | # model(x) 174 | # print('GPU:', (time.time() - start) / 2) 175 | # x = torch.randn(1, 3, 32, 1500).cpu() 176 | # model.cpu() 177 | # model(x) 178 | # start = time.time() 179 | # for i in range(100): 180 | # # x = torch.randn(1, 3, 32, 1500).cpu() 181 | # model(x) 182 | # print('CPU:', (time.time() - start) / 2) 183 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import math 4 | from functools import partial 5 | from torch.optim import lr_scheduler 6 | import importlib 7 | from torch.autograd import Variable 8 | import numpy as np 9 | 10 | # 11 | # __all__ = ['build_optimizer'] 12 | # 13 | # 14 | # def build_optimizer(optim_config, lr_scheduler_config, epochs, step_each_epoch, model): 15 | # from . 
import lr 16 | # config = copy.deepcopy(optim_config) 17 | # optim = getattr(torch.optim, config.pop('name'))(params=model.parameters(), **config) 18 | # 19 | # lr_config = copy.deepcopy(lr_scheduler_config) 20 | # lr_config.update({'epochs': epochs, 'step_each_epoch': step_each_epoch}) 21 | # lr_scheduler = getattr(lr, lr_config.pop('name'))(**lr_config)(optimizer=optim) 22 | # return optim, lr_scheduler 23 | 24 | class StepLR(object): 25 | def __init__(self, 26 | step_each_epoch, 27 | step_size, 28 | warmup_epoch=0, 29 | gamma=0.1, 30 | last_epoch=-1, 31 | **kwargs): 32 | super(StepLR, self).__init__() 33 | self.step_size = step_each_epoch * step_size 34 | self.gamma = gamma 35 | self.last_epoch = last_epoch 36 | self.warmup_epoch = warmup_epoch 37 | 38 | def __call__(self, optimizer): 39 | return lr_scheduler.LambdaLR(optimizer, self.lambda_func, self.last_epoch) 40 | 41 | def lambda_func(self, current_step): 42 | if current_step < self.warmup_epoch: 43 | return float(current_step) / float(max(1, self.warmup_epoch)) 44 | return self.gamma ** (current_step // self.step_size) 45 | 46 | 47 | class MultiStepLR(object): 48 | def __init__(self, 49 | step_each_epoch, 50 | milestones, 51 | warmup_epoch=0, 52 | gamma=0.1, 53 | last_epoch=-1, 54 | **kwargs): 55 | super(MultiStepLR, self).__init__() 56 | self.milestones = [step_each_epoch * e for e in milestones] 57 | self.gamma = gamma 58 | self.last_epoch = last_epoch 59 | self.warmup_epoch = warmup_epoch 60 | 61 | def __call__(self, optimizer): 62 | return lr_scheduler.LambdaLR(optimizer, self.lambda_func, self.last_epoch) 63 | 64 | def lambda_func(self, current_step): 65 | if current_step < self.warmup_epoch: 66 | return float(current_step) / float(max(1, self.warmup_epoch)) 67 | return self.gamma ** len([m for m in self.milestones if m <= current_step]) 68 | 69 | class ConstLR(object): 70 | def __init__(self, 71 | step_each_epoch, 72 | warmup_epoch=0, 73 | last_epoch=-1, 74 | **kwargs): 75 | super(ConstLR, self).__init__() 76 | self.last_epoch = last_epoch 77 | self.warmup_epoch = warmup_epoch * step_each_epoch 78 | 79 | def __call__(self, optimizer): 80 | return lr_scheduler.LambdaLR(optimizer, self.lambda_func, self.last_epoch) 81 | 82 | def lambda_func(self, current_step): 83 | if current_step < self.warmup_epoch: 84 | return float(current_step) / float(max(1.0, self.warmup_epoch)) 85 | return 1.0 86 | 87 | 88 | class LinearLR(object): 89 | def __init__(self, 90 | epochs, 91 | step_each_epoch, 92 | warmup_epoch=0, 93 | last_epoch=-1, 94 | **kwargs): 95 | super(LinearLR, self).__init__() 96 | self.epochs = epochs * step_each_epoch 97 | self.last_epoch = last_epoch 98 | self.warmup_epoch = warmup_epoch * step_each_epoch 99 | 100 | def __call__(self, optimizer): 101 | return lr_scheduler.LambdaLR(optimizer, self.lambda_func, self.last_epoch) 102 | 103 | def lambda_func(self, current_step): 104 | if current_step < self.warmup_epoch: 105 | return float(current_step) / float(max(1, self.warmup_epoch)) 106 | return max(0.0, float(self.epochs - current_step) / float(max(1, self.epochs - self.warmup_epoch))) 107 | 108 | 109 | class CosineAnnealingLR(object): 110 | def __init__(self, 111 | epochs, 112 | step_each_epoch, 113 | warmup_epoch=0, 114 | last_epoch=-1, 115 | **kwargs): 116 | super(CosineAnnealingLR, self).__init__() 117 | self.epochs = epochs * step_each_epoch 118 | self.last_epoch = last_epoch 119 | self.warmup_epoch = warmup_epoch * step_each_epoch 120 | 121 | def __call__(self, optimizer): 122 | return lr_scheduler.LambdaLR(optimizer, 
self.lambda_func, self.last_epoch) 123 | 124 | def lambda_func(self, current_step, num_cycles=0.5): 125 | if current_step < self.warmup_epoch: 126 | return float(current_step) / float(max(1, self.warmup_epoch)) 127 | progress = float(current_step - self.warmup_epoch) / float(max(1, self.epochs - self.warmup_epoch)) 128 | return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 129 | 130 | 131 | class PolynomialLR(object): 132 | def __init__(self, 133 | step_each_epoch, 134 | epochs, 135 | lr_end=1e-7, 136 | power=1.0, 137 | warmup_epoch=0, 138 | last_epoch=-1, 139 | **kwargs): 140 | super(PolynomialLR, self).__init__() 141 | self.lr_end = lr_end 142 | self.power = power 143 | self.epochs = epochs * step_each_epoch 144 | self.warmup_epoch = warmup_epoch * step_each_epoch 145 | self.last_epoch = last_epoch 146 | 147 | def __call__(self, optimizer): 148 | lr_lambda = partial( 149 | self.lambda_func, 150 | lr_init=optimizer.defaults["lr"], 151 | ) 152 | return lr_scheduler.LambdaLR(optimizer, lr_lambda, self.last_epoch) 153 | 154 | def lambda_func(self, current_step, lr_init): 155 | if current_step < self.warmup_epoch: 156 | return float(current_step) / float(max(1, self.warmup_epoch)) 157 | elif current_step > self.epochs: 158 | return self.lr_end / lr_init # as LambdaLR multiplies by lr_init 159 | else: 160 | lr_range = lr_init - self.lr_end 161 | decay_steps = self.epochs - self.warmup_epoch 162 | pct_remaining = 1 - (current_step - self.warmup_epoch) / decay_steps 163 | decay = lr_range * pct_remaining ** self.power + self.lr_end 164 | return decay / lr_init # as LambdaLR multiplies by lr_init 165 | 166 | 167 | def get_no_weight_decay_param(model, config): 168 | param_names = config['optimizer']['no_weight_decay_param']['param_names'] 169 | weight_decay = config['optimizer']['no_weight_decay_param']['weight_decay'] 170 | is_on = config['optimizer']['no_weight_decay_param']['is_ON'] 171 | if not is_on: 172 | return model.parameters() 173 | base_param = [] 174 | no_weight_decay_param = [] 175 | for (name, param) in model.named_parameters(): 176 | is_no_weight = False 177 | for param_name in param_names: 178 | if param_name in name: 179 | is_no_weight = True 180 | break 181 | if is_no_weight: 182 | no_weight_decay_param.append(param) 183 | else: 184 | base_param.append(param) 185 | Outparam = [{'params': base_param}, {'params': no_weight_decay_param, 'weight_decay': weight_decay}] 186 | return Outparam 187 | 188 | def fix_param(model, opt): 189 | param_names = ['pos_embed', 'norm'] # config['optimizer']['no_weight_decay_param']['param_names'] 190 | weight_decay = 0. 
# config['optimizer']['no_weight_decay_param']['weight_decay'] 191 | is_on = True # config['optimizer']['no_weight_decay_param']['is_ON'] 192 | STN_ON = True # config['model']['STN']['STN_ON'] 193 | stn_lr = opt.base_lr # config['model']['STN']['stn_lr'] 194 | 195 | base_param = [] 196 | stn_param = [] 197 | no_weight_decay_param = [] 198 | for (name, param) in model.named_parameters(): 199 | is_no_weight = False 200 | for param_name in param_names: 201 | if param_name in name: 202 | # print(param_name) 203 | is_no_weight = True 204 | break 205 | if is_no_weight: 206 | no_weight_decay_param.append(param) 207 | elif 'stn' in name: 208 | stn_param.append(param) 209 | else: 210 | base_param.append(param) 211 | Outparam = [{'params': base_param}, {'params': stn_param}, {'params': no_weight_decay_param}] 212 | 213 | if STN_ON: 214 | Outparam[1]['lr'] = stn_lr 215 | if is_on: 216 | Outparam[2]['weight_decay'] = weight_decay 217 | return Outparam 218 | 219 | def lr_warm(base_lr, epoch, warm_epoch): 220 | return (base_lr/warm_epoch)*(epoch+1) 221 | 222 | def adjust_learning_rate_warm(opt, optimizer, epoch): 223 | lr = lr_warm(opt.base_lr, epoch, 2) # lr_warm(config['optimizer']['base_lr'], epoch,config['train']['warmepochs']) 224 | optimizer.param_groups[0]['lr'] = lr 225 | if 'TPS' in opt.Transformation: 226 | stn_lr = opt.base_lr # config['model']['STN']['stn_lr'] 227 | lr = lr_warm(stn_lr, epoch, 2) # lr_warm(stn_lr, epoch,config['train']['warmepochs']) 228 | optimizer.param_groups[1]['lr'] = lr 229 | 230 | 231 | def adjust_learning_rate_cos(opt, optimizer, epoch): 232 | initial_learning_rate, step, decay_steps, alpha = opt.base_lr, epoch - 2, opt.num_epochs - 2, 0 233 | step = min(step, decay_steps) 234 | cosine_decay = 0.5 * (1 + math.cos(math.pi * step / decay_steps)) 235 | decayed = (1 - alpha) * cosine_decay + alpha 236 | optimizer.param_groups[0]['lr'] = initial_learning_rate * decayed 237 | 238 | if 'TPS' in opt.Transformation: 239 | stn_lr = opt.base_lr * decayed # config['model']['STN']['stn_lr'] * decayed 240 | optimizer.param_groups[1]['lr'] = stn_lr -------------------------------------------------------------------------------- /modules/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class VGG_FeatureExtractor(nn.Module): 6 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ 7 | 8 | def __init__(self, input_channel, output_channel=512): 9 | super(VGG_FeatureExtractor, self).__init__() 10 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 11 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 12 | self.ConvNet = nn.Sequential( 13 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 14 | nn.MaxPool2d(2, 2), # 64x16x50 15 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), 16 | nn.MaxPool2d(2, 2), # 128x8x25 17 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25 18 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), 19 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 20 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), 21 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 22 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), 23 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), 24 | 
nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 25 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 26 | 27 | def forward(self, input): 28 | return self.ConvNet(input) 29 | 30 | 31 | class RCNN_FeatureExtractor(nn.Module): 32 | """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ 33 | 34 | def __init__(self, input_channel, output_channel=512): 35 | super(RCNN_FeatureExtractor, self).__init__() 36 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 37 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 38 | self.ConvNet = nn.Sequential( 39 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 40 | nn.MaxPool2d(2, 2), # 64 x 16 x 50 41 | GRCL(self.output_channel[0], self.output_channel[0], num_iteration=5, kernel_size=3, pad=1), 42 | nn.MaxPool2d(2, 2), # 64 x 8 x 25 43 | GRCL(self.output_channel[0], self.output_channel[1], num_iteration=5, kernel_size=3, pad=1), 44 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26 45 | GRCL(self.output_channel[1], self.output_channel[2], num_iteration=5, kernel_size=3, pad=1), 46 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27 47 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False), 48 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26 49 | 50 | def forward(self, input): 51 | return self.ConvNet(input) 52 | 53 | 54 | class ResNet_FeatureExtractor(nn.Module): 55 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 56 | 57 | def __init__(self, input_channel, output_channel=512): 58 | super(ResNet_FeatureExtractor, self).__init__() 59 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) 60 | 61 | def forward(self, input): 62 | return self.ConvNet(input) 63 | 64 | 65 | # For Gated RCNN 66 | class GRCL(nn.Module): 67 | 68 | def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad): 69 | super(GRCL, self).__init__() 70 | self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False) 71 | self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False) 72 | self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False) 73 | self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False) 74 | 75 | self.BN_x_init = nn.BatchNorm2d(output_channel) 76 | 77 | self.num_iteration = num_iteration 78 | self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)] 79 | self.GRCL = nn.Sequential(*self.GRCL) 80 | 81 | def forward(self, input): 82 | """ The input of GRCL is consistant over time t, which is denoted by u(0) 83 | thus wgf_u / wf_u is also consistant over time t. 
84 | """ 85 | wgf_u = self.wgf_u(input) 86 | wf_u = self.wf_u(input) 87 | x = F.relu(self.BN_x_init(wf_u)) 88 | 89 | for i in range(self.num_iteration): 90 | x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x)) 91 | 92 | return x 93 | 94 | 95 | class GRCL_unit(nn.Module): 96 | 97 | def __init__(self, output_channel): 98 | super(GRCL_unit, self).__init__() 99 | self.BN_gfu = nn.BatchNorm2d(output_channel) 100 | self.BN_grx = nn.BatchNorm2d(output_channel) 101 | self.BN_fu = nn.BatchNorm2d(output_channel) 102 | self.BN_rx = nn.BatchNorm2d(output_channel) 103 | self.BN_Gx = nn.BatchNorm2d(output_channel) 104 | 105 | def forward(self, wgf_u, wgr_x, wf_u, wr_x): 106 | G_first_term = self.BN_gfu(wgf_u) 107 | G_second_term = self.BN_grx(wgr_x) 108 | G = F.sigmoid(G_first_term + G_second_term) 109 | 110 | x_first_term = self.BN_fu(wf_u) 111 | x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G) 112 | x = F.relu(x_first_term + x_second_term) 113 | 114 | return x 115 | 116 | 117 | class BasicBlock(nn.Module): 118 | expansion = 1 119 | 120 | def __init__(self, inplanes, planes, stride=1, downsample=None): 121 | super(BasicBlock, self).__init__() 122 | self.conv1 = self._conv3x3(inplanes, planes) 123 | self.bn1 = nn.BatchNorm2d(planes) 124 | self.conv2 = self._conv3x3(planes, planes) 125 | self.bn2 = nn.BatchNorm2d(planes) 126 | self.relu = nn.ReLU(inplace=True) 127 | self.downsample = downsample 128 | self.stride = stride 129 | 130 | def _conv3x3(self, in_planes, out_planes, stride=1): 131 | "3x3 convolution with padding" 132 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 133 | padding=1, bias=False) 134 | 135 | def forward(self, x): 136 | residual = x 137 | 138 | out = self.conv1(x) 139 | out = self.bn1(out) 140 | out = self.relu(out) 141 | 142 | out = self.conv2(out) 143 | out = self.bn2(out) 144 | 145 | if self.downsample is not None: 146 | residual = self.downsample(x) 147 | out += residual 148 | out = self.relu(out) 149 | 150 | return out 151 | 152 | 153 | class ResNet(nn.Module): 154 | def __init__(self, input_channel, output_channel, block, layers): 155 | super(ResNet, self).__init__() 156 | 157 | self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] 158 | 159 | self.inplanes = int(output_channel / 8) 160 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), 161 | kernel_size=3, stride=1, padding=1, bias=False) 162 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 163 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, 164 | kernel_size=3, stride=1, padding=1, bias=False) 165 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 166 | self.relu = nn.ReLU(inplace=True) 167 | 168 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 169 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 170 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ 171 | 0], kernel_size=3, stride=1, padding=1, bias=False) 172 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 173 | 174 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 175 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) 176 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ 177 | 1], kernel_size=3, stride=1, padding=1, bias=False) 178 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 179 | 180 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 
181 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) 182 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ 183 | 2], kernel_size=3, stride=1, padding=1, bias=False) 184 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 185 | 186 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) 187 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 188 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) 189 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 190 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 191 | 3], kernel_size=2, stride=1, padding=0, bias=False) 192 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 193 | 194 | def _make_layer(self, block, planes, blocks, stride=1): 195 | downsample = None 196 | if stride != 1 or self.inplanes != planes * block.expansion: 197 | downsample = nn.Sequential( 198 | nn.Conv2d(self.inplanes, planes * block.expansion, 199 | kernel_size=1, stride=stride, bias=False), 200 | nn.BatchNorm2d(planes * block.expansion), 201 | ) 202 | 203 | layers = [] 204 | layers.append(block(self.inplanes, planes, stride, downsample)) 205 | self.inplanes = planes * block.expansion 206 | for i in range(1, blocks): 207 | layers.append(block(self.inplanes, planes)) 208 | 209 | return nn.Sequential(*layers) 210 | 211 | def forward(self, x): 212 | x = self.conv0_1(x) 213 | x = self.bn0_1(x) 214 | x = self.relu(x) 215 | x = self.conv0_2(x) 216 | x = self.bn0_2(x) 217 | x = self.relu(x) 218 | 219 | x = self.maxpool1(x) 220 | x = self.layer1(x) 221 | x = self.conv1(x) 222 | x = self.bn1(x) 223 | x = self.relu(x) 224 | 225 | x = self.maxpool2(x) 226 | x = self.layer2(x) 227 | x = self.conv2(x) 228 | x = self.bn2(x) 229 | x = self.relu(x) 230 | 231 | x = self.maxpool3(x) 232 | x = self.layer3(x) 233 | x = self.conv3(x) 234 | x = self.bn3(x) 235 | x = self.relu(x) 236 | 237 | x = self.layer4(x) 238 | x = self.conv4_1(x) 239 | x = self.bn4_1(x) 240 | x = self.relu(x) 241 | x = self.conv4_2(x) 242 | x = self.bn4_2(x) 243 | x = self.relu(x) 244 | 245 | return x 246 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /dataload/dataAug.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | import cv2 5 | import numpy as np 6 | import random 7 | from PIL import Image 8 | from dataload.aug import tia_distort, tia_stretch, tia_perspective 9 | 10 | def ImgPIL2CV(img): 11 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 12 | return img 13 | 14 | 15 | def ImgPIL2CV_fast(img): 16 | img = np.array(img2)[:, :, ::-1] 17 | return img 18 | 19 | 20 | def ImgCV2PIL(img): 21 | img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) 22 | return img 23 | 24 | 25 | class RecConAug(object): 26 | def __init__(self, 27 | prob=0.5, 28 | image_shape=(32, 320, 3), 29 | max_text_length=25, 30 | ext_data_num=1, 31 | **kwargs): 32 | self.ext_data_num = ext_data_num 33 | self.prob = prob 34 | self.max_text_length = max_text_length 35 | self.image_shape = image_shape 36 | self.max_wh_ratio = self.image_shape[1] / self.image_shape[0] 37 | 38 | def merge_ext_data(self, data, ext_data): 39 | ori_w = round(data['image'].shape[1] / data['image'].shape[0] * self.image_shape[0]) 40 | ext_w = round(ext_data['image'].shape[1] / ext_data['image'].shape[0] * self.image_shape[0]) 41 | data['image'] = cv2.resize(data['image'], (ori_w, self.image_shape[0])) 42 | ext_data['image'] = cv2.resize(ext_data['image'], (ext_w, self.image_shape[0])) 43 | data['image'] = np.concatenate([data['image'], ext_data['image']], axis=1) 44 | data["label"] += ext_data["label"] 45 | return data 46 | 47 | def __call__(self, data): 48 | rnd_num = random.random() 49 | if rnd_num > self.prob: 50 | return data 51 | for idx, ext_data in enumerate(data["ext_data"]): 52 | if len(data["label"]) + len(ext_data["label"]) > self.max_text_length: 53 | break 54 | concat_ratio = data['image'].shape[1] / data['image'].shape[0] + ext_data['image'].shape[1] / \ 55 | ext_data['image'].shape[0] 56 | if concat_ratio > self.max_wh_ratio: 57 | break 58 | data = self.merge_ext_data(data, ext_data) 59 | data.pop("ext_data") 60 | return data 61 | 62 | 63 | def flag(): 64 | """ 65 | flag 66 | """ 67 | return 1 if random.random() > 0.5000001 else -1 68 | 69 | 70 | def cvtColor(img): 71 | """ 72 | cvtColor 73 | """ 74 | hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 75 | delta = 0.001 * random.random() * flag() 76 | hsv[:, :, 2] = hsv[:, :, 2] * (1 + delta) 77 | new_img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) 78 | return new_img 79 | 80 | 81 | def blur(img): 82 | """ 83 | blur 84 | """ 85 | h, w, _ = img.shape 86 | if h > 10 and w > 10: 87 | return cv2.GaussianBlur(img, (5, 5), 1) 88 | else: 89 | return img 90 | 91 | 92 | def jitter(img): 93 | """ 94 | jitter 95 | """ 96 | w, h, _ = img.shape 97 | if h > 10 and w > 10: 98 | thres = min(w, h) 99 | s = int(random.random() * thres * 0.01) 100 | src_img = img.copy() 101 | for i in range(s): 102 | img[i:, i:, :] = src_img[:w - i, :h - i, :] 103 | return img 104 | else: 105 | return img 106 | 107 | 108 | def add_gasuss_noise(image, mean=0, var=0.1): 109 | """ 110 | Gasuss noise 111 | """ 112 | 113 | noise = np.random.normal(mean, var ** 0.5, image.shape) 114 | out = image + 0.5 * noise 115 | out = np.clip(out, 0, 255) 116 | out = np.uint8(out) 117 | return out 118 | 119 | 120 | def get_crop(image): 121 | """ 122 | random crop 123 | """ 124 | h, w, _ = image.shape 125 | top_min = 1 126 | top_max = 5 127 | top_crop = int(random.randint(top_min, top_max)) 128 | top_crop = 
min(top_crop, h - 1) 129 | crop_img = image.copy() 130 | ratio = random.randint(0, 1) 131 | if ratio: 132 | crop_img = crop_img[top_crop:h, :, :] 133 | else: 134 | crop_img = crop_img[0:h - top_crop, :, :] 135 | return crop_img 136 | 137 | 138 | def resizeAug(img, HThresh=25): 139 | if img.shape[0] > HThresh: 140 | h, w = img.shape[:2] 141 | ratio = np.random.randint(4, 8) / 10. 142 | img = cv2.resize(img, None, fx=ratio, fy=ratio) 143 | img = cv2.resize(img, (w, h)) 144 | return img 145 | 146 | 147 | def random_dilute(image, sele_value=50, set_value=20, num_ratio=[0.1, 0.2]): 148 | index = np.where(image < (image.min() + sele_value)) 149 | tag = [] 150 | for i in range(len(index[0])): 151 | tag.append([index[0][i], index[1][i]]) 152 | np.random.shuffle(tag) 153 | tag = np.array(tag) 154 | total_num = len(tag[:, 0]) 155 | num = int(total_num * np.random.choice(num_ratio, 1)[0]) 156 | tag1 = tag[:num, 0] 157 | tag2 = tag[:num, 1] 158 | index = (tag1, tag2) 159 | start = image.min() + sele_value + set_value 160 | if (start >= 230): 161 | start = start - 50 162 | ra_value = min(np.random.randint(min(225, start), 230), 230) 163 | image[index] = ra_value 164 | return image 165 | 166 | 167 | def motion_blur(image): 168 | # degree建议:2 - 5 169 | # angle建议:0 - 360 170 | # 都为整数 171 | degree = np.random.randint(2, 6) 172 | angle = np.random.randint(0, 360) 173 | # 这里生成任意角度的运动模糊kernel的矩阵, degree越大,模糊程度越高 174 | M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1) 175 | motion_blur_kernel = np.diag(np.ones(degree)) 176 | motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (degree, degree)) 177 | 178 | motion_blur_kernel = motion_blur_kernel / degree 179 | blurred = cv2.filter2D(image, -1, motion_blur_kernel) 180 | 181 | # convert to uint8 182 | cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX) 183 | blurred_image = np.array(blurred, dtype=np.uint8) 184 | return blurred_image 185 | 186 | 187 | class Config: 188 | """ 189 | Config 190 | """ 191 | 192 | def __init__(self, use_tia): 193 | self.anglex = random.random() * 30 194 | self.angley = random.random() * 15 195 | self.anglez = random.random() * 10 196 | self.fov = 42 197 | self.r = 0 198 | self.shearx = random.random() * 0.3 199 | self.sheary = random.random() * 0.05 200 | self.borderMode = cv2.BORDER_REPLICATE 201 | self.use_tia = use_tia 202 | 203 | def make(self, w, h, ang): 204 | """ 205 | make 206 | """ 207 | self.anglex = random.random() * 5 * flag() 208 | self.angley = random.random() * 5 * flag() 209 | self.anglez = -1 * random.random() * int(ang) * flag() 210 | self.fov = 42 211 | self.r = 0 212 | self.shearx = 0 213 | self.sheary = 0 214 | self.borderMode = cv2.BORDER_REPLICATE 215 | self.w = w 216 | self.h = h 217 | 218 | self.perspective = self.use_tia 219 | self.stretch = self.use_tia 220 | self.distort = self.use_tia 221 | 222 | self.crop = False # True 223 | self.affine = False # False 224 | self.reverse = False 225 | self.noise = True 226 | self.jitter = False # True 227 | self.blur = True 228 | self.color = False # True 229 | self.random_dilute = False # True 230 | 231 | 232 | def rad(x): 233 | """ 234 | rad 235 | """ 236 | return x * np.pi / 180 237 | 238 | 239 | def get_warpR(config): 240 | """ 241 | get_warpR 242 | """ 243 | anglex, angley, anglez, fov, w, h, r = \ 244 | config.anglex, config.angley, config.anglez, config.fov, config.w, config.h, config.r 245 | if w > 69 and w < 112: 246 | anglex = anglex * 1.5 247 | 248 | z = np.sqrt(w ** 2 + h ** 2) / 2 / np.tan(rad(fov / 2)) 249 | # Homogeneous 
coordinate transformation matrix 250 | rx = np.array([[1, 0, 0, 0], 251 | [0, np.cos(rad(anglex)), -np.sin(rad(anglex)), 0], [ 252 | 0, 253 | -np.sin(rad(anglex)), 254 | np.cos(rad(anglex)), 255 | 0, 256 | ], [0, 0, 0, 1]], np.float32) 257 | ry = np.array([[np.cos(rad(angley)), 0, np.sin(rad(angley)), 0], 258 | [0, 1, 0, 0], [ 259 | -np.sin(rad(angley)), 260 | 0, 261 | np.cos(rad(angley)), 262 | 0, 263 | ], [0, 0, 0, 1]], np.float32) 264 | rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0, 0], 265 | [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0, 0], 266 | [0, 0, 1, 0], [0, 0, 0, 1]], np.float32) 267 | r = rx.dot(ry).dot(rz) 268 | # generate 4 points 269 | pcenter = np.array([h / 2, w / 2, 0, 0], np.float32) 270 | p1 = np.array([0, 0, 0, 0], np.float32) - pcenter 271 | p2 = np.array([w, 0, 0, 0], np.float32) - pcenter 272 | p3 = np.array([0, h, 0, 0], np.float32) - pcenter 273 | p4 = np.array([w, h, 0, 0], np.float32) - pcenter 274 | dst1 = r.dot(p1) 275 | dst2 = r.dot(p2) 276 | dst3 = r.dot(p3) 277 | dst4 = r.dot(p4) 278 | list_dst = np.array([dst1, dst2, dst3, dst4]) 279 | org = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32) 280 | dst = np.zeros((4, 2), np.float32) 281 | # Project onto the image plane 282 | dst[:, 0] = list_dst[:, 0] * z / (z - list_dst[:, 2]) + pcenter[0] 283 | dst[:, 1] = list_dst[:, 1] * z / (z - list_dst[:, 2]) + pcenter[1] 284 | 285 | warpR = cv2.getPerspectiveTransform(org, dst) 286 | 287 | dst1, dst2, dst3, dst4 = dst 288 | r1 = int(min(dst1[1], dst2[1])) 289 | r2 = int(max(dst3[1], dst4[1])) 290 | c1 = int(min(dst1[0], dst3[0])) 291 | c2 = int(max(dst2[0], dst4[0])) 292 | 293 | try: 294 | ratio = min(1.0 * h / (r2 - r1), 1.0 * w / (c2 - c1)) 295 | 296 | dx = -c1 297 | dy = -r1 298 | T1 = np.float32([[1., 0, dx], [0, 1., dy], [0, 0, 1.0 / ratio]]) 299 | ret = T1.dot(warpR) 300 | except: 301 | ratio = 1.0 302 | T1 = np.float32([[1., 0, 0], [0, 1., 0], [0, 0, 1.]]) 303 | ret = T1 304 | return ret, (-r1, -c1), ratio, dst 305 | 306 | 307 | def get_warpAffine(config): 308 | """ 309 | get_warpAffine 310 | """ 311 | anglez = config.anglez 312 | rz = np.array([[np.cos(rad(anglez)), np.sin(rad(anglez)), 0], 313 | [-np.sin(rad(anglez)), np.cos(rad(anglez)), 0]], np.float32) 314 | return rz 315 | 316 | 317 | def warp(img, use_tia=True, prob=0.4, angz=2): 318 | """ 319 | warp 320 | """ 321 | h, w, _ = img.shape 322 | config = Config(use_tia=use_tia) 323 | config.make(w, h, angz) 324 | new_img = img 325 | 326 | if config.distort: 327 | img_height, img_width = img.shape[0:2] 328 | if random.random() <= (prob - 0.1) and img_height >= 32 and img_width >= 32: 329 | new_img = tia_distort(new_img, 10) 330 | 331 | if config.stretch: 332 | img_height, img_width = img.shape[0:2] 333 | if random.random() <= prob and img_height >= 32 and img_width >= 32: 334 | new_img = tia_stretch(new_img, 10) 335 | 336 | if config.perspective: 337 | if random.random() <= prob: 338 | new_img = tia_perspective(new_img) 339 | 340 | if config.crop: 341 | img_height, img_width = img.shape[0:2] 342 | if random.random() <= prob and img_height >= 32 and img_width >= 32: 343 | new_img = get_crop(new_img) 344 | 345 | if config.blur: 346 | bprob = prob / 3. 
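# Note (annotation, not in the original file): roughly a third of `prob` is allotted to each
# degradation below (Gaussian blur, resize-based quality drop, motion blur). Each `if`/`elif`
# re-draws random.random(), so the three branches are not selected with exactly bprob each.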
347 | if random.random() <= bprob: 348 | new_img = blur(new_img) 349 | elif random.random() > bprob and random.random() <= 2 * bprob: 350 | new_img = resizeAug(new_img) 351 | elif random.random() > 2 * bprob and random.random() <= 3 * bprob: 352 | new_img = motion_blur(new_img) 353 | 354 | if config.random_dilute: 355 | if random.random() <= prob: 356 | new_img = random_dilute(new_img) 357 | 358 | if config.color: 359 | if random.random() <= prob: 360 | new_img = cvtColor(new_img) 361 | if config.jitter: 362 | new_img = jitter(new_img) 363 | if config.noise: 364 | if random.random() <= prob: 365 | new_img = add_gasuss_noise(new_img) 366 | if config.reverse: 367 | if random.random() <= prob: 368 | new_img = 255 - new_img 369 | 370 | return new_img 371 | 372 | -------------------------------------------------------------------------------- /test_benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import string 4 | import argparse 5 | import re 6 | 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.utils.data 10 | import torch.nn.functional as F 11 | import numpy as np 12 | from nltk.metrics.distance import edit_distance 13 | 14 | from utils import CTCLabelConverter, AttnLabelConverter, Averager 15 | from dataset import hierarchical_dataset, AlignCollate 16 | from model import Model 17 | from collections import OrderedDict 18 | 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | 21 | 22 | def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False): 23 | """ evaluation with 6 benchmark evaluation datasets """ 24 | 25 | # # To easily compute the total accuracy of our paper. 26 | eval_data_list = ['IIIT5k_3000', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80'] 27 | 28 | if calculate_infer_time: 29 | evaluation_batch_size = 1 # batch_size should be 1 to calculate the GPU inference time per image. 
30 | else: 31 | evaluation_batch_size = opt.batch_size 32 | 33 | list_accuracy = [] 34 | total_forward_time = 0 35 | total_evaluation_data_number = 0 36 | total_correct_number = 0 37 | log = open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a') 38 | dashed_line = '-' * 80 39 | print(dashed_line) 40 | log.write(dashed_line + '\n') 41 | for eval_data in eval_data_list: 42 | eval_data_path = os.path.join(opt.eval_data, eval_data) 43 | AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 44 | eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt) 45 | evaluation_loader = torch.utils.data.DataLoader( 46 | eval_data, batch_size=evaluation_batch_size, 47 | shuffle=False, 48 | num_workers=int(opt.workers), 49 | collate_fn=AlignCollate_evaluation, pin_memory=True) 50 | 51 | _, accuracy_by_best_model, norm_ED_by_best_model, _, _, _, infer_time, length_of_data = validation( 52 | model, criterion, evaluation_loader, converter, opt) 53 | list_accuracy.append(f'{accuracy_by_best_model:0.3f}') 54 | total_forward_time += infer_time 55 | total_evaluation_data_number += len(eval_data) 56 | total_correct_number += accuracy_by_best_model * length_of_data 57 | log.write(eval_data_log) 58 | print(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}') 59 | log.write(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}\n') 60 | print(dashed_line) 61 | log.write(dashed_line + '\n') 62 | print(total_forward_time) 63 | print(total_evaluation_data_number) 64 | averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000 65 | total_accuracy = total_correct_number / total_evaluation_data_number 66 | params_num = sum([np.prod(p.size()) for p in model.parameters()]) 67 | 68 | evaluation_log = 'accuracy: ' 69 | for name, accuracy in zip(eval_data_list, list_accuracy): 70 | evaluation_log += f'{name}: {accuracy}\t' 71 | evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t' 72 | evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num / 1e6:0.3f}' 73 | print(evaluation_log) 74 | log.write(evaluation_log + '\n') 75 | log.close() 76 | 77 | return None 78 | 79 | 80 | def validation(model, criterion, evaluation_loader, converter, opt): 81 | """ validation or evaluation """ 82 | n_correct = 0 83 | norm_ED = 0 84 | length_of_data = 0 85 | infer_time = 0 86 | valid_loss_avg = Averager() 87 | 88 | for i, (image_tensors, labels) in enumerate(evaluation_loader): 89 | batch_size = image_tensors.size(0) 90 | length_of_data = length_of_data + batch_size 91 | image = image_tensors.to(device) 92 | # For max length prediction 93 | length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) 94 | text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) 95 | 96 | text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length) 97 | 98 | start_time = time.time() 99 | if 'CTC' in opt.Prediction: 100 | preds = model(image, text_for_pred) 101 | 102 | # Calculate evaluation loss for CTC deocder. 
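            # Shape note for the CTC branch below: the model returns preds of shape
            # (N, T, C); torch.nn.CTCLoss expects log-probabilities of shape
            # (T, N, C) together with per-sample input and target lengths, hence the
            # log_softmax(2).permute(1, 0, 2) call and preds_size = [T] * batch_size.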
103 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 104 | # permute 'preds' to use CTCloss format 105 | if opt.baiduCTC: 106 | cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) / batch_size 107 | else: 108 | cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) 109 | 110 | # Select max probabilty (greedy decoding) then decode index to character 111 | if opt.baiduCTC: 112 | _, preds_index = preds.max(2) 113 | preds_index = preds_index.view(-1) 114 | else: 115 | _, preds_index = preds.max(2) 116 | preds_str = converter.decode(preds_index.data, preds_size.data) 117 | 118 | else: 119 | preds = model(image, text_for_pred, is_train=False) 120 | 121 | preds = preds[:, :text_for_loss.shape[1] - 1, :] 122 | target = text_for_loss[:, 1:] # without [GO] Symbol 123 | cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1)) 124 | 125 | # select max probabilty (greedy decoding) then decode index to character 126 | _, preds_index = preds.max(2) 127 | preds_str = converter.decode(preds_index, length_for_pred) 128 | labels = converter.decode(text_for_loss[:, 1:], length_for_loss) 129 | 130 | forward_time = time.time() - start_time 131 | infer_time += forward_time 132 | valid_loss_avg.add(cost) 133 | 134 | # calculate accuracy & confidence score 135 | preds_prob = F.softmax(preds, dim=2) 136 | preds_max_prob, _ = preds_prob.max(dim=2) 137 | confidence_score_list = [] 138 | for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob): 139 | if 'Attn' in opt.Prediction: 140 | gt = gt[:gt.find('[s]')] 141 | pred_EOS = pred.find('[s]') 142 | pred = pred[:pred_EOS] # prune after "end of sentence" token ([s]) 143 | pred_max_prob = pred_max_prob[:pred_EOS] 144 | 145 | # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting. 
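            # Worked example of the normalization below (--sensitive together with
            # --data_filtering_off): prediction and ground truth are lowercased and
            # every character outside [0-9a-z] is removed, so 'Hello!' vs 'hello'
            # counts as correct ('hello' == 'hello'), while 'Café-42' vs 'cafe42'
            # does not ('caf42' != 'cafe42').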
146 | if opt.sensitive and opt.data_filtering_off: 147 | pred = pred.lower() 148 | gt = gt.lower() 149 | alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz' 150 | out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]' 151 | pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred) 152 | gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt) 153 | 154 | if pred == gt: 155 | n_correct += 1 156 | 157 | # ICDAR2019 Normalized Edit Distance 158 | if len(gt) == 0 or len(pred) == 0: 159 | norm_ED += 0 160 | elif len(gt) > len(pred): 161 | norm_ED += 1 - edit_distance(pred, gt) / len(gt) 162 | else: 163 | norm_ED += 1 - edit_distance(pred, gt) / len(pred) 164 | 165 | # calculate confidence score (= multiply of pred_max_prob) 166 | try: 167 | confidence_score = pred_max_prob.cumprod(dim=0)[-1] 168 | except: 169 | confidence_score = 0 # for empty pred case, when prune after "end of sentence" token ([s]) 170 | confidence_score_list.append(confidence_score) 171 | # print(pred, gt, pred==gt, confidence_score) 172 | 173 | accuracy = n_correct / float(length_of_data) * 100 174 | norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance 175 | 176 | return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data 177 | 178 | 179 | def test(opt): 180 | """ model configuration """ 181 | if 'CTC' in opt.Prediction: 182 | converter = CTCLabelConverter(opt.character) 183 | else: 184 | converter = AttnLabelConverter(opt.character) 185 | opt.num_class = len(converter.character) 186 | 187 | if opt.rgb: 188 | opt.input_channel = 3 189 | model = Model(opt) 190 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 191 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 192 | opt.SequenceModeling, opt.Prediction) 193 | model = torch.nn.DataParallel(model).to(device) 194 | 195 | # load model 196 | print('loading pretrained model from %s' % opt.saved_model) 197 | state_dict = torch.load(opt.saved_model, map_location=device) 198 | 199 | model.load_state_dict(state_dict, strict=False) 200 | 201 | opt.exp_name = '_'.join(opt.saved_model.split('/')[1:]) 202 | # print(model) 203 | 204 | """ keep evaluation model and result logs """ 205 | os.makedirs(f'./result/{opt.exp_name}', exist_ok=True) 206 | # os.system(f'cp {opt.saved_model} ./result/{opt.exp_name}/') 207 | 208 | """ setup loss """ 209 | if 'CTC' in opt.Prediction: 210 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 211 | else: 212 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 213 | 214 | """ evaluation """ 215 | model.eval() 216 | with torch.no_grad(): 217 | if opt.benchmark_all_eval: # evaluation with 6 benchmark evaluation datasets 218 | benchmark_all_eval(model, criterion, converter, opt) 219 | else: 220 | log = open(f'./result/{opt.exp_name}/log_evaluation.txt', 'a') 221 | AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 222 | eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt) 223 | evaluation_loader = torch.utils.data.DataLoader( 224 | eval_data, batch_size=opt.batch_size, 225 | shuffle=False, 226 | num_workers=int(opt.workers), 227 | collate_fn=AlignCollate_evaluation, pin_memory=True) 228 | _, accuracy_by_best_model, _, _, _, _, _, _ = validation( 229 | model, criterion, evaluation_loader, converter, 
opt) 230 | log.write(eval_data_log) 231 | print(f'{accuracy_by_best_model:0.3f}') 232 | log.write(f'{accuracy_by_best_model:0.3f}\n') 233 | log.close() 234 | 235 | 236 | if __name__ == '__main__': 237 | parser = argparse.ArgumentParser() 238 | parser.add_argument('--eval_data', required=False, help='path to evaluation dataset') 239 | parser.add_argument('--benchmark_all_eval', action='store_true', help='evaluate 6 benchmark evaluation datasets') 240 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 241 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 242 | parser.add_argument('--saved_model', required=False, help="path to saved_model to evaluation") 243 | """ Data processing """ 244 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 245 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 246 | parser.add_argument('--imgW', type=int, default=96, help='the width of the input image') 247 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 248 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') 249 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 250 | parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') 251 | parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') 252 | parser.add_argument('--baiduCTC', action='store_true', help='for data_filtering_off mode') 253 | """ Model Architecture """ 254 | parser.add_argument('--Transformation', type=str, required=False, help='Transformation stage. None|TPS') 255 | parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet') 256 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') 257 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn') 258 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 259 | parser.add_argument('--input_channel', type=int, default=3, help='the number of input channel of Feature extractor') 260 | parser.add_argument('--output_channel', type=int, default=384, 261 | help='the number of output channel of Feature extractor') 262 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 263 | 264 | opt = parser.parse_args() 265 | opt.rgb = True 266 | opt.eval_data = "text-recognition/data_lmdb_release/evaluation/" 267 | opt.saved_model = "saved_models/VIPTRv1-L_en/best_accuracy.pth" 268 | # opt.character += "ABCDEFGHIJKLMNOPQRSTUVWXYZ" 269 | """ vocab / character number configuration """ 270 | if opt.sensitive: 271 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 
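        # string.printable is digits + ascii_letters + punctuation + whitespace;
        # slicing off the last 6 characters drops the whitespace block
        # (' \t\n\r\x0b\x0c'), leaving the 94 visible characters of the ASTER
        # setting:
        #     import string
        #     assert len(string.printable[:-6]) == 94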
272 | 273 | cudnn.benchmark = True 274 | cudnn.deterministic = True 275 | opt.num_gpu = torch.cuda.device_count() 276 | 277 | test(opt) 278 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import six 5 | import math 6 | import lmdb 7 | import torch 8 | 9 | from natsort import natsorted 10 | from PIL import Image 11 | import numpy as np 12 | from torch.utils.data import Dataset, ConcatDataset, Subset 13 | from torch._utils import _accumulate 14 | from dataload.dataAug import * 15 | import torchvision.transforms as transforms 16 | 17 | class Batch_Balanced_Dataset(object): 18 | 19 | def __init__(self, opt): 20 | """ 21 | Modulate the data ratio in the batch. 22 | For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5", 23 | the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST. 24 | """ 25 | log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') 26 | dashed_line = '-' * 80 27 | print(dashed_line) 28 | log.write(dashed_line + '\n') 29 | print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}') 30 | log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n') 31 | assert len(opt.select_data) == len(opt.batch_ratio) 32 | 33 | _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, use_type='train', use_aug=True, aug_prob=0.4) 34 | self.data_loader_list = [] 35 | self.dataloader_iter_list = [] 36 | batch_size_list = [] 37 | Total_batch_size = 0 38 | for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio): 39 | _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1) 40 | print(dashed_line) 41 | log.write(dashed_line + '\n') 42 | _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d]) 43 | total_number_dataset = len(_dataset) 44 | log.write(_dataset_log) 45 | 46 | """ 47 | The total number of data can be modified with opt.total_data_usage_ratio. 48 | ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage. 49 | See 4.2 section in our paper. 
50 | """ 51 | number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio)) 52 | dataset_split = [number_dataset, total_number_dataset - number_dataset] 53 | indices = range(total_number_dataset) 54 | _dataset, _ = [Subset(_dataset, indices[offset - length:offset]) 55 | for offset, length in zip(_accumulate(dataset_split), dataset_split)] 56 | selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n' 57 | selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}' 58 | print(selected_d_log) 59 | log.write(selected_d_log + '\n') 60 | batch_size_list.append(str(_batch_size)) 61 | Total_batch_size += _batch_size 62 | 63 | _data_loader = torch.utils.data.DataLoader( 64 | _dataset, batch_size=_batch_size, 65 | shuffle=True, 66 | num_workers=int(opt.workers), 67 | collate_fn=_AlignCollate, pin_memory=True) 68 | self.data_loader_list.append(_data_loader) 69 | self.dataloader_iter_list.append(iter(_data_loader)) 70 | 71 | Total_batch_size_log = f'{dashed_line}\n' 72 | batch_size_sum = '+'.join(batch_size_list) 73 | Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n' 74 | Total_batch_size_log += f'{dashed_line}' 75 | opt.batch_size = Total_batch_size 76 | 77 | print(Total_batch_size_log) 78 | log.write(Total_batch_size_log + '\n') 79 | log.close() 80 | 81 | def get_batch(self): 82 | balanced_batch_images = [] 83 | balanced_batch_texts = [] 84 | 85 | for i, data_loader_iter in enumerate(self.dataloader_iter_list): 86 | try: 87 | image, text = next(data_loader_iter) 88 | balanced_batch_images.append(image) 89 | balanced_batch_texts += text 90 | except StopIteration: 91 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) 92 | image, text = next(self.dataloader_iter_list[i]) 93 | balanced_batch_images.append(image) 94 | balanced_batch_texts += text 95 | except ValueError: 96 | pass 97 | # print(type(balanced_batch_images)) 98 | balanced_batch_images = torch.cat(balanced_batch_images, 0) 99 | 100 | return balanced_batch_images, balanced_batch_texts 101 | 102 | 103 | def hierarchical_dataset(root, opt, select_data='/'): 104 | """ select_data='/' contains all sub-directory of root directory """ 105 | dataset_list = [] 106 | dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}' 107 | print(dataset_log) 108 | dataset_log += '\n' 109 | for dirpath, dirnames, filenames in os.walk(root+'/'): 110 | if not dirnames: 111 | select_flag = False 112 | for selected_d in select_data: 113 | if selected_d in dirpath: 114 | select_flag = True 115 | break 116 | 117 | if select_flag: 118 | dataset = LmdbDataset(dirpath, opt) 119 | sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}' 120 | print(sub_dataset_log) 121 | dataset_log += f'{sub_dataset_log}\n' 122 | dataset_list.append(dataset) 123 | 124 | concatenated_dataset = ConcatDataset(dataset_list) 125 | 126 | return concatenated_dataset, dataset_log 127 | 128 | 129 | class LmdbDataset(Dataset): 130 | 131 | def __init__(self, root, opt): 132 | 133 | self.root = root 134 | self.opt = opt 135 | self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) 136 | if not self.env: 137 | print('cannot create lmdb from %s' % (root)) 138 | sys.exit(0) 139 | 140 | with self.env.begin(write=False) as txn: 141 | nSamples = 
int(txn.get('num-samples'.encode())) 142 | self.nSamples = nSamples 143 | 144 | if self.opt.data_filtering_off: 145 | # for fast check or benchmark evaluation with no filtering 146 | self.filtered_index_list = [index + 1 for index in range(self.nSamples)] 147 | else: 148 | """ Filtering part 149 | If you want to evaluate IC15-2077 & CUTE datasets which have special character labels, 150 | use --data_filtering_off and only evaluate on alphabets and digits. 151 | see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L190-L192 152 | 153 | And if you want to evaluate them with the model trained with --sensitive option, 154 | use --sensitive and --data_filtering_off, 155 | see https://github.com/clovaai/deep-text-recognition-benchmark/blob/dff844874dbe9e0ec8c5a52a7bd08c7f20afe704/test.py#L137-L144 156 | """ 157 | self.filtered_index_list = [] 158 | for index in range(self.nSamples): 159 | index += 1 # lmdb starts with 1 160 | label_key = 'label-%09d'.encode() % index 161 | label = txn.get(label_key).decode('utf-8') 162 | 163 | if len(label) > self.opt.batch_max_length: 164 | # print(f'The length of the label is longer than max_length: length 165 | # {len(label)}, {label} in dataset {self.root}') 166 | continue 167 | 168 | # By default, images containing characters which are not in opt.character are filtered. 169 | # You can add [UNK] token to `opt.character` in utils.py instead of this filtering. 170 | out_of_char = f'[^{self.opt.character}]' 171 | if re.search(out_of_char, label.lower()): 172 | continue 173 | 174 | self.filtered_index_list.append(index) 175 | 176 | self.nSamples = len(self.filtered_index_list) 177 | 178 | def __len__(self): 179 | return self.nSamples 180 | 181 | def __getitem__(self, index): 182 | assert index <= len(self), 'index range error' 183 | index = self.filtered_index_list[index] 184 | 185 | with self.env.begin(write=False) as txn: 186 | label_key = 'label-%09d'.encode() % index 187 | label = txn.get(label_key).decode('utf-8') 188 | img_key = 'image-%09d'.encode() % index 189 | imgbuf = txn.get(img_key) 190 | 191 | buf = six.BytesIO() 192 | buf.write(imgbuf) 193 | buf.seek(0) 194 | try: 195 | if self.opt.rgb: 196 | img = Image.open(buf).convert('RGB') # for color image 197 | else: 198 | img = Image.open(buf).convert('L') 199 | 200 | except IOError: 201 | print(f'Corrupted image for {index}') 202 | # make dummy image and dummy label for corrupted image. 
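            # Image.new() with no fill colour gives an all-black placeholder of size
            # (imgW, imgH); the '[dummy_label]' set below is then reduced by the
            # out_of_char filter at the end of __getitem__ (to 'dummylabel' with the
            # default --character set).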
203 | if self.opt.rgb: 204 | img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) 205 | else: 206 | img = Image.new('L', (self.opt.imgW, self.opt.imgH)) 207 | label = '[dummy_label]' 208 | 209 | if not self.opt.sensitive: 210 | label = label.lower() 211 | 212 | # We only train and evaluate on alphanumerics (or pre-defined character set in train.py) 213 | out_of_char = f'[^{self.opt.character}]' 214 | label = re.sub(out_of_char, '', label) 215 | 216 | return (img, label) 217 | 218 | 219 | class RawDataset(Dataset): 220 | 221 | def __init__(self, root, opt): 222 | self.opt = opt 223 | self.image_path_list = [] 224 | for dirpath, dirnames, filenames in os.walk(root): 225 | for name in filenames: 226 | _, ext = os.path.splitext(name) 227 | ext = ext.lower() 228 | if ext == '.jpg' or ext == '.jpeg' or ext == '.png': 229 | self.image_path_list.append(os.path.join(dirpath, name)) 230 | 231 | self.image_path_list = natsorted(self.image_path_list) 232 | self.nSamples = len(self.image_path_list) 233 | 234 | def __len__(self): 235 | return self.nSamples 236 | 237 | def __getitem__(self, index): 238 | 239 | try: 240 | if self.opt.rgb: 241 | img = Image.open(self.image_path_list[index]).convert('RGB') # for color image 242 | else: 243 | img = Image.open(self.image_path_list[index]).convert('L') 244 | 245 | except IOError: 246 | print(f'Corrupted image for {index}') 247 | # make dummy image and dummy label for corrupted image. 248 | if self.opt.rgb: 249 | img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) 250 | else: 251 | img = Image.new('L', (self.opt.imgW, self.opt.imgH)) 252 | 253 | return (img, self.image_path_list[index]) 254 | 255 | 256 | class ResizeNormalize(object): 257 | 258 | def __init__(self, size, interpolation=Image.BICUBIC): 259 | self.size = size 260 | self.interpolation = interpolation 261 | self.toTensor = transforms.ToTensor() 262 | 263 | def __call__(self, img): 264 | img = img.resize(self.size, self.interpolation) 265 | img = self.toTensor(img) 266 | img.sub_(0.5).div_(0.5) 267 | return img 268 | 269 | 270 | class NormalizePAD(object): 271 | 272 | def __init__(self, max_size, PAD_type='right'): 273 | self.toTensor = transforms.ToTensor() 274 | self.max_size = max_size 275 | self.max_width_half = math.floor(max_size[2] / 2) 276 | self.PAD_type = PAD_type 277 | 278 | def __call__(self, img): 279 | img = self.toTensor(img) 280 | img.sub_(0.5).div_(0.5) 281 | c, h, w = img.size() 282 | Pad_img = torch.FloatTensor(*self.max_size).fill_(0) 283 | Pad_img[:, :, :w] = img # right pad 284 | if self.max_size[2] != w: # add border Pad 285 | Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w) 286 | 287 | return Pad_img 288 | 289 | 290 | class AlignCollate(object): 291 | 292 | def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False, use_type='valid', use_aug=False, aug_prob=0.4): 293 | self.imgH = imgH 294 | self.imgW = imgW 295 | self.keep_ratio_with_pad = keep_ratio_with_pad 296 | self.use_aug = use_aug 297 | self.aug_prob = aug_prob 298 | self.use_type = use_type 299 | print(use_type) 300 | print(use_aug) 301 | def __call__(self, batch): 302 | batch = filter(lambda x: x is not None, batch) 303 | images, labels = zip(*batch) 304 | new_images = [] 305 | for (image, label) in zip(images, labels): 306 | if self.use_type == 'train': 307 | image = np.array(image) 308 | try: 309 | image = warp(image, self.use_aug, self.aug_prob) 310 | except: 311 | pass 312 | image = Image.fromarray(image) 313 | new_images.append(image) 314 | 315 | if 
self.keep_ratio_with_pad: # same concept with 'Rosetta' paper 316 | resized_max_w = self.imgW 317 | input_channel = 3 if new_images[0].mode == 'RGB' else 1 318 | transform = NormalizePAD((input_channel, self.imgH, resized_max_w)) 319 | 320 | resized_images = [] 321 | for image in new_images: 322 | w, h = image.size 323 | ratio = w / float(h) 324 | if math.ceil(self.imgH * ratio) > self.imgW: 325 | resized_w = self.imgW 326 | else: 327 | resized_w = math.ceil(self.imgH * ratio) 328 | 329 | resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC) 330 | resized_images.append(transform(resized_image)) 331 | # resized_image.save('./image_test/%d_test.jpg' % w) 332 | 333 | image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0) 334 | 335 | else: 336 | transform = ResizeNormalize((self.imgW, self.imgH)) 337 | image_tensors = [transform(image) for image in new_images] 338 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) 339 | 340 | return image_tensors, labels 341 | 342 | 343 | def tensor2im(image_tensor, imtype=np.uint8): 344 | image_numpy = image_tensor.cpu().float().numpy() 345 | if image_numpy.shape[0] == 1: 346 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 347 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 348 | return image_numpy.astype(imtype) 349 | 350 | 351 | def save_image(image_numpy, image_path): 352 | image_pil = Image.fromarray(image_numpy) 353 | image_pil.save(image_path) 354 | -------------------------------------------------------------------------------- /train_benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import string 6 | import argparse 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn.init as init 11 | import torch.optim as optim 12 | import torch.utils.data 13 | import numpy as np 14 | 15 | from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager 16 | from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset 17 | from model import Model 18 | from test import validation 19 | from modules.rec_sar_loss import SARLoss 20 | from modules.dctc_loss import DCTC 21 | from optimizer import CosineAnnealingLR 22 | 23 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" # 可修改 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | 27 | def train(opt): 28 | """ dataset preparation """ 29 | if not opt.data_filtering_off: 30 | print('Filtering the images containing characters which are not in opt.character') 31 | print('Filtering the images whose label is longer than opt.batch_max_length') 32 | # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 33 | 34 | opt.select_data = opt.select_data.split('-') 35 | opt.batch_ratio = opt.batch_ratio.split('-') 36 | train_dataset = Batch_Balanced_Dataset(opt) 37 | 38 | log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') 39 | AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 40 | valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt) 41 | valid_loader = torch.utils.data.DataLoader( 42 | valid_dataset, batch_size=opt.batch_size, 43 | shuffle=True, # 'True' to check training progress with validation function. 
44 | num_workers=int(opt.workers), 45 | collate_fn=AlignCollate_valid, pin_memory=True) 46 | log.write(valid_dataset_log) 47 | print('-' * 80) 48 | log.write('-' * 80 + '\n') 49 | log.close() 50 | 51 | """ model configuration """ 52 | if 'CTC' in opt.Prediction: 53 | if opt.baiduCTC: 54 | converter = CTCLabelConverterForBaiduWarpctc(opt.character) 55 | else: 56 | converter = CTCLabelConverter(opt.character) 57 | else: 58 | converter = AttnLabelConverter(opt.character) 59 | 60 | opt.num_class = len(converter.character) 61 | 62 | if opt.rgb: 63 | opt.input_channel = 3 64 | model = Model(opt) 65 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 66 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 67 | opt.SequenceModeling, opt.Prediction) 68 | 69 | # weight initialization 70 | # for name, param in model.named_parameters(): 71 | # if 'localization_fc2' in name: 72 | # print(f'Skip {name} as it is already initialized') 73 | # continue 74 | # try: 75 | # if 'bias' in name: 76 | # init.constant_(param, 0.0) 77 | # elif 'weight' in name: 78 | # init.kaiming_normal_(param) 79 | # except Exception as e: # for batchnorm. 80 | # if 'weight' in name: 81 | # param.data.fill_(1) 82 | # continue 83 | 84 | # data parallel for multi-GPU 85 | model = torch.nn.DataParallel(model).to(device) 86 | model.train() 87 | if opt.saved_model != '': 88 | print(f'loading pretrained model from {opt.saved_model}') 89 | if opt.FT: 90 | model.load_state_dict(torch.load(opt.saved_model), strict=False) 91 | else: 92 | model.load_state_dict(torch.load(opt.saved_model)) 93 | # print("Model:") # 打印模型结构 94 | # print(model) 95 | 96 | """ setup loss """ 97 | if 'CTC' == opt.Prediction: 98 | if opt.baiduCTC: 99 | # need to install warpctc. see our guideline. 
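            # Note on the two CTC paths used in the training loop further down:
            # torch.nn.CTCLoss is fed log_softmax outputs permuted to (T, N, C),
            # while the warpctc_pytorch criterion is fed the raw network outputs and
            # the resulting cost is divided by batch_size to obtain a per-sample
            # average. (Elsewhere in this script, the Chinese marker '# 可修改'
            # means 'adjust as needed' and '# 打印模型结构' means 'print the model
            # structure'.)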
100 | from warpctc_pytorch import CTCLoss 101 | criterion = CTCLoss() 102 | else: 103 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 104 | elif 'DCTC' in opt.Prediction: 105 | criterion = DCTC(use_il=False, alpha=0.01) 106 | else: 107 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 108 | 109 | # loss averager 110 | loss_avg = Averager() 111 | 112 | # filter that only require gradient decent 113 | filtered_parameters = [] 114 | params_num = [] 115 | for p in filter(lambda p: p.requires_grad, model.parameters()): 116 | filtered_parameters.append(p) 117 | params_num.append(np.prod(p.size())) 118 | print('Trainable params num : ', sum(params_num)) 119 | # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())] 120 | 121 | # setup optimizer 122 | if opt.adam: 123 | # optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999)) # , weight_decay=3.0e-05 124 | optimizer = optim.AdamW(filtered_parameters, lr=opt.lr, 125 | betas=(opt.beta1, 0.999), 126 | eps=8.e-8, 127 | weight_decay=0.05, 128 | amsgrad=False) 129 | else: 130 | optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps) 131 | print("Optimizer:") 132 | print(optimizer) 133 | 134 | if opt.lr_scheduler: 135 | lr_scheduler = CosineAnnealingLR( 136 | epochs=opt.num_epochs, 137 | step_each_epoch=opt.one_epoch_steps, 138 | warmup_epoch=2, 139 | last_epoch=-1)(optimizer=optimizer) 140 | else: 141 | lr_scheduler = None 142 | 143 | """ final options """ 144 | # print(opt) 145 | with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file: 146 | opt_log = '------------ Options -------------\n' 147 | args = vars(opt) 148 | for k, v in args.items(): 149 | if k != "character": 150 | opt_log += f'{str(k)}: {str(v)}\n' 151 | opt_log += '---------------------------------------\n' 152 | print(opt_log) 153 | opt_file.write(opt_log) 154 | 155 | """ start training """ 156 | start_iter = 0 157 | if opt.saved_model != '': 158 | try: 159 | start_iter = int(opt.saved_model.split('_')[-1].split('.')[0]) 160 | print(f'continue to train, start_iter: {start_iter}') 161 | except: 162 | pass 163 | 164 | start_time = time.time() 165 | best_accuracy = -1 166 | best_norm_ED = -1 167 | iteration = start_iter 168 | 169 | while(True): 170 | # train part 171 | image_tensors, labels = train_dataset.get_batch() 172 | image = image_tensors.to(device) 173 | text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) 174 | batch_size = image.size(0) 175 | 176 | if 'CTC' == opt.Prediction: 177 | preds = model(image, text) 178 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 179 | if opt.baiduCTC: 180 | preds = preds.permute(1, 0, 2) # to use CTCLoss format 181 | cost = criterion(preds, text, preds_size, length) / batch_size 182 | else: 183 | preds = preds.log_softmax(2).permute(1, 0, 2) 184 | cost = criterion(preds, text, preds_size, length) 185 | elif 'DCTC' in opt.Prediction: 186 | preds = model(image, text) 187 | preds = preds.permute(1, 0, 2) 188 | targets_dict = { 189 | 'targets': text, 190 | 'target_lengths': length 191 | } 192 | cost = criterion(logits=preds, targets_dict=targets_dict) 193 | else: 194 | preds = model(image, text[:, :-1]) # align with Attention.forward 195 | target = text[:, 1:] # without [GO] Symbol 196 | cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 197 | 198 | model.zero_grad() 199 | cost.backward() 200 | 
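        # The usual update order is kept here: model.zero_grad() -> cost.backward()
        # -> clip_grad_norm_ -> optimizer.step(); clip_grad_norm_ rescales all
        # gradients in place whenever their combined L2 norm exceeds opt.grad_clip
        # (default 5).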
torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) 201 | optimizer.step() 202 | 203 | loss_avg.add(cost) 204 | 205 | if opt.lr_scheduler: 206 | lr_scheduler.step() 207 | 208 | # validation part 209 | if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 210 | elapsed_time = time.time() - start_time 211 | # for log 212 | with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log: 213 | model.eval() 214 | with torch.no_grad(): 215 | valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation( 216 | model, criterion, valid_loader, converter, opt) 217 | model.train() 218 | 219 | # training loss and validation loss 220 | loss_log = f'[{iteration + 1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss: 0.5f}, Elapsed_time: {elapsed_time:0.5f}' 221 | loss_avg.reset() 222 | 223 | current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}' 224 | 225 | # keep best accuracy model (on valid dataset) 226 | if current_accuracy > best_accuracy: 227 | best_accuracy = current_accuracy 228 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth') 229 | if current_norm_ED > best_norm_ED: 230 | best_norm_ED = current_norm_ED 231 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth') 232 | best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}' 233 | 234 | loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}' 235 | print(loss_model_log) 236 | log.write(loss_model_log + '\n') 237 | 238 | # show some predicted results 239 | dashed_line = '-' * 80 240 | head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' 241 | predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' 242 | for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]): 243 | if 'Attn' in opt.Prediction: 244 | gt = gt[:gt.find('[s]')] 245 | pred = pred[:pred.find('[s]')] 246 | 247 | predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n' 248 | predicted_result_log += f'{dashed_line}' 249 | print(predicted_result_log) 250 | log.write(predicted_result_log + '\n') 251 | 252 | # save model per 1e+5 iter. 
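        # Checkpointing below actually triggers every 5e+4 (50,000) iterations, not
        # 1e+5; the file is written as iter_{iteration + 1}.pth under
        # ./saved_models/{opt.exp_name}/.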
253 | if (iteration + 1) % 5e+4 == 0: 254 | torch.save( 255 | model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration + 1}.pth') 256 | 257 | if (iteration + 1) == opt.num_iter: 258 | print('end the training') 259 | sys.exit() 260 | iteration += 1 261 | 262 | 263 | if __name__ == '__main__': 264 | train_dir = "/home/chengxf/cqc/text-recognition/data_lmdb_release/training/" # 可修改 265 | valid_dir = "/home/chengxf/cqc/text-recognition/data_lmdb_release/validation/" # 可修改 266 | 267 | parser = argparse.ArgumentParser() 268 | parser.add_argument('--exp_name', type=str, default='exp', help='Where to store logs and models') # 可修改 269 | parser.add_argument('--train_data', type=str, default=train_dir, help='path to training dataset') 270 | parser.add_argument('--valid_data', type=str, default=valid_dir, help='path to validation dataset') 271 | parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting') 272 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=8) 273 | parser.add_argument('--batch_size', type=int, default=512, help='input batch size') # 192 # 可修改 274 | parser.add_argument('--num_iter', type=int, default=600000, help='number of iterations to train for') # 可修改 275 | parser.add_argument('--valInterval', type=int, default=500, help='Interval between each validation') # 可修改 276 | parser.add_argument('--saved_model', default='', # ./saved_models/eng_benchmark19/best_accuracy.pth', 277 | help="path to model to continue training") # ./saved_models/exp_230719_yzaug/best_accuracy.pth 278 | parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning') 279 | parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)') 280 | parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta') 281 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9') 282 | parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95') 283 | parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8') 284 | parser.add_argument('--lr_scheduler', action='store_true', help='whether to set lr_scheduler') 285 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. 
default=5') 286 | parser.add_argument('--baiduCTC', action='store_true', help='for data_filtering_off mode') 287 | """ Data processing """ 288 | parser.add_argument('--select_data', type=str, default='MJ-ST', # 'valid_random_lmdb', # 可修改 289 | help='select training data (default is MJ-ST, which means MJ and ST used as training data)') 290 | parser.add_argument('--batch_ratio', type=str, default='0.5-0.5', 291 | help='assign ratio for each selected data in the batch') 292 | parser.add_argument('--total_data_usage_ratio', type=str, default='1.0', 293 | help='total data usage ratio, this ratio is multiplied to total number of data.') 294 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 295 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 296 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 297 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 298 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') 299 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 300 | parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') 301 | parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') 302 | """ Model Architecture """ 303 | parser.add_argument('--Transformation', type=str, default='TPS', help='Transformation stage. None|TPS') 304 | parser.add_argument('--FeatureExtraction', type=str, default='ResNet18', 305 | help='FeatureExtraction stage. VGG|RCNN|ResNet') 306 | parser.add_argument('--SequenceModeling', type=str, default='BiLSTM', help='SequenceModeling stage. None|BiLSTM') 307 | parser.add_argument('--Prediction', type=str, default='CTC', help='Prediction stage. 
CTC|Attn') 308 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 309 | parser.add_argument('--input_channel', type=int, default=3, 310 | help='the number of input channel of Feature extractor') 311 | parser.add_argument('--output_channel', type=int, default=256, 312 | help='the number of output channel of Feature extractor') 313 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 314 | 315 | opt = parser.parse_args() 316 | opt.rgb = True 317 | opt.lr_scheduler = True 318 | 319 | if not opt.exp_name: 320 | opt.exp_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 321 | opt.exp_name += f'-Seed{opt.manualSeed}' 322 | # print(opt.exp_name) 323 | 324 | os.makedirs(f'./saved_models/{opt.exp_name}', exist_ok=True) 325 | opt.train_dataset_num = 8919241 + 5522807 326 | opt.one_epoch_steps = int(opt.train_dataset_num / opt.batch_size) 327 | opt.num_epochs = int(opt.num_iter / opt.one_epoch_steps) + 1 328 | """ vocab / character number configuration """ 329 | # opt.character += "ABCDEFGHIJKLMNOPQRSTUVWXYZ" 330 | # opt.character = '' 331 | # opt.character = string.printable[:-6] 332 | # with open("/home/chengxf/cqc/text-recognition/word_dict_cn_full.txt", "r", encoding='utf-8') as fin: # 可修改 333 | # all_characters = fin.readlines() 334 | # for ch in all_characters: 335 | # opt.character += ch.strip('\n') 336 | 337 | """ Seed and GPU setting """ 338 | print("Random Seed: ", opt.manualSeed) 339 | random.seed(opt.manualSeed) 340 | np.random.seed(opt.manualSeed) 341 | torch.manual_seed(opt.manualSeed) 342 | torch.cuda.manual_seed(opt.manualSeed) 343 | 344 | cudnn.benchmark = True 345 | cudnn.deterministic = True 346 | opt.num_gpu = torch.cuda.device_count() 347 | print('device count', opt.num_gpu) 348 | if opt.num_gpu > 1: 349 | print('------ Use multi-GPU setting ------') 350 | print('if you stuck too long time with multi-GPU setting, try to set --workers 0') 351 | # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1 352 | opt.workers = opt.workers * opt.num_gpu 353 | opt.batch_size = opt.batch_size * opt.num_gpu 354 | 355 | """ previous version 356 | print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size) 357 | opt.batch_size = opt.batch_size * opt.num_gpu 358 | print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.') 359 | If you dont care about it, just commnet out these line.) 
360 | opt.num_iter = int(opt.num_iter / opt.num_gpu) 361 | """ 362 | print('begin') 363 | train(opt) 364 | -------------------------------------------------------------------------------- /modules/SVTR.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | import numpy as np 7 | import time 8 | import os 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | def truncated_normal_(tensor, mean=0, std=0.02): 13 | with torch.no_grad(): 14 | size = tensor.shape 15 | tmp = tensor.new_empty(size + (4,)).normal_() 16 | valid = (tmp < 2) & (tmp > -2) 17 | ind = valid.max(-1, keepdim=True)[1] 18 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 19 | tensor.data.mul_(std).add_(mean) 20 | return tensor 21 | 22 | 23 | def drop_path(x, local_rank, drop_prob=0., training=False): 24 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 25 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 26 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... 27 | """ 28 | if drop_prob == 0. or not training: 29 | return x 30 | keep_prob = torch.tensor(1 - drop_prob).to(local_rank) 31 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 32 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype).to(local_rank) 33 | random_tensor = torch.floor(random_tensor) # binarize 34 | output = x.divide(keep_prob) * random_tensor 35 | return output 36 | 37 | 38 | class ConvBNLayer(nn.Module): 39 | def __init__(self, 40 | in_channels, 41 | out_channels, 42 | kernel_size=3, 43 | stride=1, 44 | padding=0, 45 | bias_attr=False, 46 | groups=1, 47 | act=nn.GELU): 48 | super().__init__() 49 | self.conv = nn.Conv2d( 50 | in_channels=in_channels, 51 | out_channels=out_channels, 52 | kernel_size=kernel_size, 53 | stride=stride, 54 | padding=padding, 55 | groups=groups, 56 | bias=bias_attr) 57 | self.norm = nn.BatchNorm2d(out_channels) 58 | self.act = act() 59 | 60 | def forward(self, inputs): 61 | out = self.conv(inputs) 62 | out = self.norm(out) 63 | out = self.act(out) 64 | return out 65 | 66 | 67 | class DropPath(nn.Module): 68 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
69 | """ 70 | 71 | def __init__(self, local_rank, drop_prob=None): 72 | super(DropPath, self).__init__() 73 | self.drop_prob = drop_prob 74 | self.local_rank = local_rank 75 | 76 | def forward(self, x): 77 | return drop_path(x, self.local_rank, self.drop_prob, self.training) 78 | 79 | 80 | class Identity(nn.Module): 81 | def __init__(self): 82 | super(Identity, self).__init__() 83 | 84 | def forward(self, input): 85 | return input 86 | 87 | 88 | class Mlp(nn.Module): 89 | def __init__(self, 90 | in_features, 91 | hidden_features=None, 92 | out_features=None, 93 | act_layer=nn.GELU, 94 | drop=0.): 95 | super().__init__() 96 | out_features = out_features or in_features 97 | hidden_features = hidden_features or in_features 98 | self.fc1 = nn.Linear(in_features, hidden_features) 99 | self.act = act_layer() 100 | self.fc2 = nn.Linear(hidden_features, out_features) 101 | self.drop = nn.Dropout(drop) 102 | 103 | def forward(self, x): 104 | x = self.fc1(x) 105 | x = self.act(x) 106 | x = self.drop(x) 107 | x = self.fc2(x) 108 | x = self.drop(x) 109 | return x 110 | 111 | 112 | class ConvMixer(nn.Module): 113 | def __init__( 114 | self, 115 | dim, 116 | num_heads=8, 117 | HW=[8, 25], 118 | local_k=[3, 3], ): 119 | super().__init__() 120 | self.HW = HW 121 | self.dim = dim 122 | self.local_mixer = nn.Conv2d( 123 | dim, 124 | dim, 125 | local_k, 126 | 1, [local_k[0] // 2, local_k[1] // 2], 127 | groups=num_heads, 128 | ) 129 | 130 | def forward(self, x): 131 | h = self.HW[0] 132 | w = self.HW[1] 133 | x = x.permute([0, 2, 1]).reshape([-1, self.dim, h, w]).contiguous() 134 | x = self.local_mixer(x) 135 | x = x.flatten(2).permute([0, 2, 1]).contiguous() 136 | return x 137 | 138 | 139 | class Attention(nn.Module): 140 | def __init__(self, 141 | dim, 142 | local_rank=-1, 143 | num_heads=8, 144 | mixer='Global', 145 | HW=None, 146 | local_k=[7, 11], 147 | qkv_bias=False, 148 | qk_scale=None, 149 | attn_drop=0., 150 | proj_drop=0.): 151 | super().__init__() 152 | self.num_heads = num_heads 153 | head_dim = dim // num_heads 154 | self.scale = qk_scale or head_dim ** -0.5 155 | 156 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 157 | self.attn_drop = nn.Dropout(attn_drop) 158 | self.proj = nn.Linear(dim, dim) 159 | self.proj_drop = nn.Dropout(proj_drop) 160 | self.HW = HW 161 | if HW is not None: 162 | H = HW[0] 163 | W = HW[1] 164 | self.N = H * W 165 | self.C = dim 166 | if mixer == 'Local' and HW is not None: 167 | hk = local_k[0] 168 | wk = local_k[1] 169 | mask = torch.ones([H * W, H + hk - 1, W + wk - 1]).float().to(local_rank) 170 | for h in range(0, H): 171 | for w in range(0, W): 172 | mask[h * W + w, h:h + hk, w:w + wk] = 0. 
173 | mask_paddle = mask[:, hk // 2:H + hk // 2, wk // 2:W + wk // 2].flatten(1) 174 | mask_inf = torch.full([H * W, H * W], -np.inf).float().to(local_rank) ### 造成nan 175 | mask = torch.where(mask_paddle < 1, mask_paddle, mask_inf) 176 | self.mask = mask.unsqueeze(0).unsqueeze(1).contiguous() 177 | self.mixer = mixer 178 | 179 | def forward(self, x): 180 | if self.HW is not None: 181 | N = self.N 182 | C = self.C 183 | else: 184 | _, N, C = x.shape 185 | qkv = self.qkv(x).reshape((-1, N, 3, self.num_heads, C // self.num_heads)).permute((2, 0, 3, 1, 4)).contiguous() 186 | q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] 187 | 188 | attn = (q.matmul(k.permute((0, 1, 3, 2)))).contiguous() 189 | if self.mixer == 'Local': 190 | attn += self.mask 191 | attn = nn.functional.softmax(attn, -1) 192 | attn = self.attn_drop(attn) 193 | 194 | x = (attn.matmul(v)).permute((0, 2, 1, 3)).reshape((-1, N, C)).contiguous() 195 | x = self.proj(x) 196 | x = self.proj_drop(x) 197 | return x 198 | 199 | 200 | class Block(nn.Module): 201 | def __init__(self, 202 | dim, 203 | num_heads, 204 | local_rank=-1, 205 | mixer='Global', 206 | local_mixer=[7, 11], 207 | HW=None, 208 | mlp_ratio=4., 209 | qkv_bias=False, 210 | qk_scale=None, 211 | drop=0., 212 | attn_drop=0., 213 | drop_path=0., 214 | act_layer=nn.GELU, 215 | norm_layer='nn.LayerNorm', 216 | epsilon=1e-6, 217 | prenorm=False): 218 | super().__init__() 219 | if isinstance(norm_layer, str): 220 | self.norm1 = eval(norm_layer)(dim, eps=epsilon) 221 | else: 222 | self.norm1 = norm_layer(dim) 223 | if mixer == 'Global' or mixer == 'Local': 224 | self.mixer = Attention( 225 | dim, 226 | local_rank=local_rank, 227 | num_heads=num_heads, 228 | mixer=mixer, 229 | HW=HW, 230 | local_k=local_mixer, 231 | qkv_bias=qkv_bias, 232 | qk_scale=qk_scale, 233 | attn_drop=attn_drop, 234 | proj_drop=drop) 235 | elif mixer == 'Conv': 236 | self.mixer = ConvMixer( 237 | dim, num_heads=num_heads, HW=HW, local_k=local_mixer) 238 | else: 239 | raise TypeError("The mixer must be one of [Global, Local, Conv]") 240 | 241 | self.drop_path = DropPath(local_rank, drop_path) if drop_path > 0. 
else Identity() 242 | if isinstance(norm_layer, str): 243 | self.norm2 = eval(norm_layer)(dim, eps=epsilon) 244 | else: 245 | self.norm2 = norm_layer(dim) 246 | mlp_hidden_dim = int(dim * mlp_ratio) 247 | self.mlp_ratio = mlp_ratio 248 | self.mlp = Mlp(in_features=dim, 249 | hidden_features=mlp_hidden_dim, 250 | act_layer=act_layer, 251 | drop=drop) 252 | self.prenorm = prenorm 253 | 254 | def forward(self, x): 255 | if self.prenorm: 256 | x = self.norm1(x + self.drop_path(self.mixer(x))) 257 | x = self.norm2(x + self.drop_path(self.mlp(x))) 258 | else: 259 | x = x + self.drop_path(self.mixer(self.norm1(x))) 260 | x = x + self.drop_path(self.mlp(self.norm2(x))) 261 | return x 262 | 263 | 264 | class PatchEmbed(nn.Module): 265 | """ Image to Patch Embedding 266 | """ 267 | 268 | def __init__(self, 269 | img_size=[32, 100], 270 | in_channels=3, 271 | embed_dim=768, 272 | sub_num=2, 273 | patch_size=[4, 4], 274 | mode='pope'): 275 | super().__init__() 276 | num_patches = (img_size[1] // (2 ** sub_num)) * \ 277 | (img_size[0] // (2 ** sub_num)) 278 | self.img_size = img_size 279 | self.num_patches = num_patches 280 | self.embed_dim = embed_dim 281 | self.norm = None 282 | if mode == 'pope': 283 | if sub_num == 2: 284 | self.proj = nn.Sequential( 285 | ConvBNLayer( 286 | in_channels=in_channels, 287 | out_channels=embed_dim // 2, 288 | kernel_size=3, 289 | stride=2, 290 | padding=1, 291 | act=nn.GELU, 292 | bias_attr=False), 293 | ConvBNLayer( 294 | in_channels=embed_dim // 2, 295 | out_channels=embed_dim, 296 | kernel_size=3, 297 | stride=2, 298 | padding=1, 299 | act=nn.GELU, 300 | bias_attr=False)) 301 | if sub_num == 3: 302 | self.proj = nn.Sequential( 303 | ConvBNLayer( 304 | in_channels=in_channels, 305 | out_channels=embed_dim // 4, 306 | kernel_size=3, 307 | stride=2, 308 | padding=1, 309 | act=nn.GELU, 310 | bias_attr=False), 311 | ConvBNLayer( 312 | in_channels=embed_dim // 4, 313 | out_channels=embed_dim // 2, 314 | kernel_size=3, 315 | stride=2, 316 | padding=1, 317 | act=nn.GELU, 318 | bias_attr=False), 319 | ConvBNLayer( 320 | in_channels=embed_dim // 2, 321 | out_channels=embed_dim, 322 | kernel_size=3, 323 | stride=2, 324 | padding=1, 325 | act=nn.GELU, 326 | bias_attr=False)) 327 | elif mode == 'linear': 328 | self.proj = nn.Conv2d( 329 | 1, embed_dim, kernel_size=patch_size, stride=patch_size) 330 | self.num_patches = img_size[0] // patch_size[0] * img_size[ 331 | 1] // patch_size[1] 332 | 333 | def forward(self, x): 334 | B, C, H, W = x.shape 335 | assert H == self.img_size[0] and W == self.img_size[1], \ 336 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
337 | x = self.proj(x).flatten(2).permute((0, 2, 1)).contiguous() 338 | return x 339 | 340 | 341 | class SubSample(nn.Module): 342 | def __init__(self, 343 | in_channels, 344 | out_channels, 345 | types='Pool', 346 | stride=[2, 1], 347 | sub_norm='nn.LayerNorm', 348 | act=None): 349 | super().__init__() 350 | self.types = types 351 | if types == 'Pool': 352 | self.avgpool = nn.AvgPool2d( 353 | kernel_size=[3, 5], stride=stride, padding=[1, 2]) 354 | self.maxpool = nn.MaxPool2d( 355 | kernel_size=[3, 5], stride=stride, padding=[1, 2]) 356 | self.proj = nn.Linear(in_channels, out_channels) 357 | else: 358 | self.conv = nn.Conv2d( 359 | in_channels, 360 | out_channels, 361 | kernel_size=3, 362 | stride=stride, 363 | padding=1) 364 | 365 | self.norm = eval(sub_norm)(out_channels) 366 | if act is not None: 367 | self.act = act() 368 | else: 369 | self.act = None 370 | 371 | def forward(self, x): 372 | 373 | if self.types == 'Pool': 374 | x1 = self.avgpool(x) 375 | x2 = self.maxpool(x) 376 | x = (x1 + x2) * 0.5 377 | out = self.proj(x.flatten(2).permute((0, 2, 1))).contiguous() 378 | else: 379 | x = self.conv(x) 380 | out = x.flatten(2).permute((0, 2, 1)).contiguous() 381 | out = self.norm(out) 382 | if self.act is not None: 383 | out = self.act(out) 384 | 385 | return out 386 | 387 | 388 | class SVTRNet(nn.Module): 389 | def __init__( 390 | self, 391 | img_size=[32, 100], 392 | in_channels=3, 393 | embed_dim=[64, 128, 256], 394 | depth=[3, 6, 3], 395 | num_heads=[2, 4, 8], 396 | mixer=['Local'] * 6 + ['Global'] * 6, # Local atten, Global atten, Conv 397 | local_mixer=[[7, 11], [7, 11], [7, 11]], 398 | patch_merging='Conv', # Conv, Pool, None 399 | mlp_ratio=4, 400 | qkv_bias=True, 401 | qk_scale=None, 402 | drop_rate=0., 403 | last_drop=0.1, 404 | attn_drop_rate=0., 405 | drop_path_rate=0.1, 406 | norm_layer='nn.LayerNorm', 407 | sub_norm='nn.LayerNorm', 408 | epsilon=1e-6, 409 | out_channels=192, 410 | out_char_num=25, 411 | block_unit='Block', 412 | act='nn.GELU', 413 | last_stage=True, 414 | sub_num=2, 415 | prenorm=False, 416 | use_lenhead=False, 417 | local_rank=-1, 418 | **kwargs): 419 | super().__init__() 420 | self.img_size = img_size 421 | self.embed_dim = embed_dim 422 | self.out_channels = out_channels 423 | self.prenorm = prenorm 424 | patch_merging = None if patch_merging != 'Conv' and patch_merging != 'Pool' else patch_merging 425 | self.patch_embed = PatchEmbed( 426 | img_size=img_size, 427 | in_channels=in_channels, 428 | embed_dim=embed_dim[0], 429 | sub_num=sub_num) 430 | num_patches = self.patch_embed.num_patches 431 | self.HW = [img_size[0] // (2 ** sub_num), img_size[1] // (2 ** sub_num)] 432 | self.pos_embed = nn.Parameter(torch.zeros([1, num_patches, embed_dim[0]])).to(local_rank) 433 | self.pos_drop = nn.Dropout(p=drop_rate) 434 | Block_unit = eval(block_unit) 435 | 436 | dpr = np.linspace(0, drop_path_rate, sum(depth)) 437 | self.blocks1 = nn.ModuleList([ 438 | Block_unit( 439 | dim=embed_dim[0], 440 | num_heads=num_heads[0], 441 | local_rank=local_rank, 442 | mixer=mixer[0:depth[0]][i], 443 | HW=self.HW, 444 | local_mixer=local_mixer[0], 445 | mlp_ratio=mlp_ratio, 446 | qkv_bias=qkv_bias, 447 | qk_scale=qk_scale, 448 | drop=drop_rate, 449 | act_layer=eval(act), 450 | attn_drop=attn_drop_rate, 451 | drop_path=dpr[0:depth[0]][i], 452 | norm_layer=norm_layer, 453 | epsilon=epsilon, 454 | prenorm=prenorm) for i in range(depth[0]) 455 | ]) 456 | if patch_merging is not None: 457 | self.sub_sample1 = SubSample( 458 | embed_dim[0], 459 | embed_dim[1], 460 | sub_norm=sub_norm, 
461 | stride=[2, 1], 462 | types=patch_merging) 463 | HW = [self.HW[0] // 2, self.HW[1]] 464 | else: 465 | HW = self.HW 466 | self.patch_merging = patch_merging 467 | self.blocks2 = nn.ModuleList([ 468 | Block_unit( 469 | dim=embed_dim[1], 470 | num_heads=num_heads[1], 471 | local_rank=local_rank, 472 | mixer=mixer[depth[0]:depth[0] + depth[1]][i], 473 | HW=HW, 474 | local_mixer=local_mixer[1], 475 | mlp_ratio=mlp_ratio, 476 | qkv_bias=qkv_bias, 477 | qk_scale=qk_scale, 478 | drop=drop_rate, 479 | act_layer=eval(act), 480 | attn_drop=attn_drop_rate, 481 | drop_path=dpr[depth[0]:depth[0] + depth[1]][i], 482 | norm_layer=norm_layer, 483 | epsilon=epsilon, 484 | prenorm=prenorm) for i in range(depth[1]) 485 | ]) 486 | if patch_merging is not None: 487 | self.sub_sample2 = SubSample( 488 | embed_dim[1], 489 | embed_dim[2], 490 | sub_norm=sub_norm, 491 | stride=[2, 1], 492 | types=patch_merging) 493 | HW = [self.HW[0] // 4, self.HW[1]] 494 | else: 495 | HW = self.HW 496 | self.blocks3 = nn.ModuleList([ 497 | Block_unit( 498 | dim=embed_dim[2], 499 | num_heads=num_heads[2], 500 | local_rank=local_rank, 501 | mixer=mixer[depth[0] + depth[1]:][i], 502 | HW=HW, 503 | local_mixer=local_mixer[2], 504 | mlp_ratio=mlp_ratio, 505 | qkv_bias=qkv_bias, 506 | qk_scale=qk_scale, 507 | drop=drop_rate, 508 | act_layer=eval(act), 509 | attn_drop=attn_drop_rate, 510 | drop_path=dpr[depth[0] + depth[1]:][i], 511 | norm_layer=norm_layer, 512 | epsilon=epsilon, 513 | prenorm=prenorm) for i in range(depth[2]) 514 | ]) 515 | self.last_stage = last_stage 516 | if last_stage: 517 | self.avg_pool = nn.AdaptiveAvgPool2d([1, out_char_num]) 518 | self.last_conv = nn.Conv2d( 519 | in_channels=embed_dim[2], 520 | out_channels=self.out_channels, 521 | kernel_size=1, 522 | stride=1, 523 | padding=0, 524 | bias=False) 525 | self.hardswish = nn.Hardswish() 526 | self.dropout = nn.Dropout(p=last_drop) 527 | if not prenorm: 528 | self.norm = eval(norm_layer)(embed_dim[-1], eps=epsilon) 529 | self.use_lenhead = use_lenhead 530 | if use_lenhead: 531 | self.len_conv = nn.Linear(embed_dim[2], self.out_channels) 532 | self.hardswish_len = nn.Hardswish() 533 | self.dropout_len = nn.Dropout(p=last_drop) 534 | 535 | # init.kaiming_uniform_(self.pos_embed) 536 | self.pos_embed.data = truncated_normal_(self.pos_embed.data) 537 | self.apply(self._init_weights) 538 | 539 | def _init_weights(self, m): 540 | if isinstance(m, nn.Linear): 541 | # init.kaiming_uniform_(m.weight) 542 | m.weight.data = truncated_normal_(m.weight.data) 543 | if isinstance(m, nn.Linear) and m.bias is not None: 544 | init.zeros_(m.bias) 545 | elif isinstance(m, nn.LayerNorm): 546 | init.zeros_(m.bias) 547 | init.ones_(m.weight) 548 | elif isinstance(m, nn.Conv2d): 549 | init.kaiming_uniform_(m.weight) 550 | 551 | def forward_features(self, x): 552 | x = self.patch_embed(x) 553 | x = x + self.pos_embed 554 | x = self.pos_drop(x) 555 | for blk in self.blocks1: 556 | x = blk(x) 557 | if self.patch_merging is not None: 558 | x = self.sub_sample1( 559 | x.permute([0, 2, 1]).reshape( 560 | [-1, self.embed_dim[0], self.HW[0], self.HW[1]])).contiguous() 561 | for blk in self.blocks2: 562 | x = blk(x) 563 | if self.patch_merging is not None: 564 | x = self.sub_sample2( 565 | x.permute([0, 2, 1]).reshape( 566 | [-1, self.embed_dim[1], self.HW[0] // 2, self.HW[1]])).contiguous() 567 | for blk in self.blocks3: 568 | x = blk(x) 569 | if not self.prenorm: 570 | x = self.norm(x) 571 | return x 572 | 573 | def forward(self, x): 574 | x = self.forward_features(x) 575 | if 
self.use_lenhead: 576 | len_x = self.len_conv(x.mean(1)) 577 | len_x = self.dropout_len(self.hardswish_len(len_x)) 578 | if self.last_stage: 579 | if self.patch_merging is not None: 580 | h = self.HW[0] // 4 581 | else: 582 | h = self.HW[0] 583 | x = self.avg_pool( 584 | x.permute([0, 2, 1]).reshape([-1, self.embed_dim[2], h, self.HW[1]])).contiguous() 585 | x = self.last_conv(x) 586 | x = self.hardswish(x) 587 | x = self.dropout(x) # bchw 588 | x = x.permute(0, 3, 1, 2).contiguous() # bwch 589 | x = x.squeeze(3) 590 | if self.use_lenhead: 591 | return x, len_x 592 | return x 593 | 594 | if __name__ == '__main__': 595 | os.environ["CUDA_VISIBLE_DEVICES"] = "1" # 可修改 596 | model = SVTRNet(img_size=[32, 100], 597 | in_channels=3, 598 | embed_dim=[64, 128, 256], 599 | depth=[3, 6, 3], 600 | num_heads=[2, 4, 8], 601 | mixer=['Local'] * 6 + ['Global'] * 602 | 6, # Local atten, Global atten, Conv 603 | local_mixer=[[7, 11], [7, 11], [7, 11]], 604 | patch_merging='Conv', # Conv, Pool, None 605 | mlp_ratio=4, 606 | qkv_bias=True, 607 | qk_scale=None, 608 | drop_rate=0., 609 | last_drop=0.1, 610 | attn_drop_rate=0., 611 | drop_path_rate=0.1, 612 | norm_layer='nn.LayerNorm', 613 | sub_norm='nn.LayerNorm', 614 | epsilon=1e-6, 615 | out_channels=192, 616 | out_char_num=25, 617 | block_unit='Block', 618 | act='nn.GELU', 619 | last_stage=True, 620 | sub_num=2, 621 | prenorm=False, 622 | use_lenhead=False, 623 | local_rank=device).to(device) 624 | # model = SVTRNet(img_size=[32, 100], 625 | # in_channels=3, 626 | # embed_dim=[192, 256, 512], 627 | # depth=[3, 9, 9], 628 | # num_heads=[6, 8, 16], 629 | # mixer=['Local'] * 10 + ['Global'] * 630 | # 11, # Local atten, Global atten, Conv 631 | # local_mixer=[[7, 11], [7, 11], [7, 11]], 632 | # patch_merging='Conv', # Conv, Pool, None 633 | # mlp_ratio=4, 634 | # qkv_bias=True, 635 | # qk_scale=None, 636 | # drop_rate=0., 637 | # last_drop=0.1, 638 | # attn_drop_rate=0., 639 | # drop_path_rate=0.1, 640 | # norm_layer='nn.LayerNorm', 641 | # sub_norm='nn.LayerNorm', 642 | # epsilon=1e-6, 643 | # out_channels=384, 644 | # out_char_num=25, 645 | # block_unit='Block', 646 | # act='nn.GELU', 647 | # last_stage=True, 648 | # sub_num=2, 649 | # prenorm=False, 650 | # use_lenhead=False, 651 | # local_rank=device).to(device) 652 | model.eval() 653 | with torch.no_grad(): 654 | start = time.time() 655 | for i in range(5): 656 | # print(model(a).shape) 657 | a = torch.randn(1, 3, 32, 100).to(device) 658 | y = model(a) 659 | infer_time = time.time() - start 660 | print(infer_time / 5) 661 | # print(model(a).shape) 662 | print("Parameter numbers: {}".format(sum(p.numel() for p in model.parameters()))) 663 | from thop import profile 664 | 665 | input = torch.randn(1, 3, 32, 100) 666 | flops, params = profile(model, inputs=(input,)) 667 | print('GFLOPs:', flops/1e9) 668 | # a = torch.randn(1, 3, 32, 100).to(device) 669 | # print(model(a).shape) 670 | # 671 | # print("Parameter numbers: {}".format(sum(p.numel() for p in model.parameters()))) 672 | --------------------------------------------------------------------------------