├── src ├── __init__.py └── baidudataset.py ├── modules ├── __init__.py ├── optimizer │ ├── __init__.py │ └── ranger.py ├── sequence_modeling.py ├── prediction.py ├── resnet_aster.py ├── transformation.py ├── bert.py ├── resnet_fpn.py ├── feature_extraction.py ├── SRN_Resnet.py └── SRN_modules.py ├── demo_image ├── SRN.png ├── demo_1.png ├── demo_10.jpg ├── demo_2.jpg ├── demo_3.png ├── demo_4.png ├── demo_5.png ├── demo_6.png ├── demo_7.png ├── demo_8.jpg └── demo_9.jpg ├── dataset ├── VAL.py └── alphabet.py ├── README.md ├── .gitignore ├── model.py ├── demo.py ├── utils.py ├── test.py ├── dataset.py └── train.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /modules/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /demo_image/SRN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjun2hao/SRN.pytorch/HEAD/demo_image/SRN.png -------------------------------------------------------------------------------- /demo_image/demo_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjun2hao/SRN.pytorch/HEAD/demo_image/demo_1.png -------------------------------------------------------------------------------- /demo_image/demo_10.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjun2hao/SRN.pytorch/HEAD/demo_image/demo_10.jpg -------------------------------------------------------------------------------- /demo_image/demo_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjun2hao/SRN.pytorch/HEAD/demo_image/demo_2.jpg -------------------------------------------------------------------------------- /demo_image/demo_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjun2hao/SRN.pytorch/HEAD/demo_image/demo_3.png -------------------------------------------------------------------------------- /demo_image/demo_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjun2hao/SRN.pytorch/HEAD/demo_image/demo_4.png -------------------------------------------------------------------------------- /demo_image/demo_5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjun2hao/SRN.pytorch/HEAD/demo_image/demo_5.png -------------------------------------------------------------------------------- /demo_image/demo_6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjun2hao/SRN.pytorch/HEAD/demo_image/demo_6.png -------------------------------------------------------------------------------- /demo_image/demo_7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjun2hao/SRN.pytorch/HEAD/demo_image/demo_7.png 
-------------------------------------------------------------------------------- /demo_image/demo_8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjun2hao/SRN.pytorch/HEAD/demo_image/demo_8.jpg -------------------------------------------------------------------------------- /demo_image/demo_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjun2hao/SRN.pytorch/HEAD/demo_image/demo_9.jpg -------------------------------------------------------------------------------- /dataset/VAL.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | 4 | path = '/home/yangna/deepblue/OCR/data/Baidu/train_images/' 5 | tpath = './dataset/BAIDU/images/' 6 | with open('./dataset/BAIDU/small_train.txt') as f: 7 | datas = f.readlines() 8 | 9 | for data in datas: 10 | name = data.rstrip().split('\t')[0] 11 | src = os.path.join(path, name) 12 | target = os.path.join(tpath, name) 13 | os.system('cp {} {}'.format(src, target)) -------------------------------------------------------------------------------- /dataset/alphabet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | # datas = glob.glob(r'/home/yangna/deepblue/OCR/EAST2/ICDAR_2015/ch4_training_localization_transcription_gt/*.txt') 5 | path = './dataset/BAIDU/train.list' 6 | alphabet = [] 7 | 8 | with open(path) as f: 9 | perdata = f.readlines() 10 | 11 | for pd in perdata: 12 | pd = pd.rstrip().split('\t') 13 | alphabet += pd[-1] 14 | 15 | temp = ''.join(set(alphabet)) 16 | 17 | with open('./dataset/BAIDU/baidu_alphabet.txt', 'w') as outf: 18 | outf.write(temp) -------------------------------------------------------------------------------- /modules/sequence_modeling.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BidirectionalLSTM(nn.Module): 5 | 6 | def __init__(self, input_size, hidden_size, output_size): 7 | super(BidirectionalLSTM, self).__init__() 8 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 9 | self.linear = nn.Linear(hidden_size * 2, output_size) 10 | 11 | def forward(self, input): 12 | """ 13 | input : visual feature [batch_size x T x input_size] 14 | output : contextual feature [batch_size x T x output_size] 15 | """ 16 | self.rnn.flatten_parameters() 17 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 18 | output = self.linear(recurrent) # batch_size x T x output_size 19 | return output 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards Accurate Scene Text Recognition with Semantic Reasoning Networks 2 | 3 | Unofficial PyTorch implementation of the [paper](https://arxiv.org/abs/2003.12294), which integrates a parallel visual attention module, a global semantic reasoning module, and a visual-semantic fusion decoder. The semantic reasoning network (SRN) can be trained end-to-end. 4 | 5 | At present, the accuracy reported in the paper has not been reproduced.
The code borrows heavily from [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark). 6 | 7 | **Model** 8 | 9 | 10 | **Results** 11 | | IIIT5k_3000 | SVT | IC03_860 | IC03_867 | IC13_857 | IC13_1015 | IC15_1811 | IC15_2077 | SVTP | CUTE80 | 12 | | ----------- | ------| ---------| ---------| ---------| --------- | ----------| --------- | ---- | ------ | 13 | | 84.600 | 83.617| 92.907 | 92.849 | 90.315 | 88.177 | 71.010 | 68.064 | 71.008 | 68.641 | 14 | 15 | **total_accuracy: 80.597** 16 | 17 | --- 18 | 19 | **Features** 20 | - predicts all characters in a single parallel pass 21 | - DistributedDataParallel training 22 | 23 | 24 | 25 | 26 | --- 27 | ## Requirements 28 | PyTorch >= 1.1.0 29 | 30 | 31 | ## Test 32 | 1. Download the evaluation data from [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark) 33 | 34 | 2. Download the pretrained model from [Baidu](https://pan.baidu.com/s/1E5xeajIl_fvtrGWyrE9CeA), password: d2qn 35 | 36 | 3. Test on the evaluation data: 37 | ```bash 38 | python test.py --eval_data path-to-data --saved_model path-to-model 39 | ``` 40 | 41 | --- 42 | 43 | ## Train 44 | 1. Download the training data from [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark) 45 | 46 | 2. Train from scratch: 47 | ```bash 48 | python train.py --train_data path-to-train-data --valid_data path-to-valid-data 49 | ``` 50 | 51 | ## References 52 | 1. [bert_ocr.pytorch](https://github.com/chenjun2hao/Bert_OCR.pytorch) 53 | 2. [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark) 54 | 3. [2D Attentional Irregular Scene Text Recognizer](https://arxiv.org/pdf/1906.05708.pdf) 55 | 4. [Towards Accurate Scene Text Recognition with Semantic Reasoning Networks](https://arxiv.org/abs/2003.12294) 56 | 57 | ## Differences from the original paper 58 | - uses ResNet 1D features instead of ResNet-FPN 2D features 59 | - uses element-wise addition instead of a gated unit in the visual-semantic fusion decoder 60 | 61 | ## Other 62 | It is difficult to reach the accuracy reported in the paper; contributions and shared results are welcome. -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/saved_models/* 2 | **/data_lmdb_release/* 3 | **/image_release/* 4 | **/result/* 5 | **/model/* 6 | *.mdb 7 | *.pth 8 | *.tar 9 | *.sh 10 | *.txt 11 | *.ipynb 12 | *.zip 13 | *.eps 14 | *.pdf 15 | 16 | ### Linux ### 17 | *~ 18 | 19 | # temporary files which can be created if a process still has a handle open of a deleted file 20 | .fuse_hidden* 21 | 22 | # KDE directory preferences 23 | .directory 24 | 25 | # Linux trash folder which might appear on any partition or disk 26 | .Trash-* 27 | 28 | # .nfs files are created when an open file is removed but is still being accessed 29 | .nfs* 30 | 31 | ### OSX ### 32 | # General 33 | .DS_Store 34 | .AppleDouble 35 | .LSOverride 36 | 37 | # Icon must end with two \r 38 | Icon 39 | 40 | # Thumbnails 41 | ._* 42 | 43 | # Files that might appear in the root of a volume 44 | .DocumentRevisions-V100 45 | .fseventsd 46 | .Spotlight-V100 47 | .TemporaryItems 48 | .Trashes 49 | .VolumeIcon.icns 50 | .com.apple.timemachine.donotpresent 51 | 52 | # Directories potentially created on remote AFP share 53 | .AppleDB 54 | .AppleDesktop 55 | Network Trash Folder 56 | Temporary Items 57 | .apdisk 58 | 59 | ### Python ### 60 | # Byte-compiled / optimized / DLL files 61 | __pycache__/ 62 | *.py[cod]
63 | *$py.class 64 | 65 | # C extensions 66 | *.so 67 | 68 | # Distribution / packaging 69 | .Python 70 | build/ 71 | develop-eggs/ 72 | dist/ 73 | downloads/ 74 | eggs/ 75 | .eggs/ 76 | lib/ 77 | lib64/ 78 | parts/ 79 | sdist/ 80 | var/ 81 | wheels/ 82 | *.egg-info/ 83 | .installed.cfg 84 | *.egg 85 | MANIFEST 86 | 87 | # PyInstaller 88 | # Usually these files are written by a python script from a template 89 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 90 | *.manifest 91 | *.spec 92 | 93 | # Installer logs 94 | pip-log.txt 95 | pip-delete-this-directory.txt 96 | 97 | # Unit test / coverage reports 98 | htmlcov/ 99 | .tox/ 100 | .coverage 101 | .coverage.* 102 | .cache 103 | nosetests.xml 104 | coverage.xml 105 | *.cover 106 | .hypothesis/ 107 | .pytest_cache/ 108 | 109 | # Translations 110 | *.mo 111 | *.pot 112 | 113 | # Django stuff: 114 | *.log 115 | local_settings.py 116 | db.sqlite3 117 | 118 | # Flask stuff: 119 | instance/ 120 | .webassets-cache 121 | 122 | # Scrapy stuff: 123 | .scrapy 124 | 125 | # Sphinx documentation 126 | docs/_build/ 127 | 128 | # PyBuilder 129 | target/ 130 | 131 | # Jupyter Notebook 132 | .ipynb_checkpoints 133 | 134 | # IPython 135 | profile_default/ 136 | ipython_config.py 137 | 138 | # pyenv 139 | .python-version 140 | 141 | # celery beat schedule file 142 | celerybeat-schedule 143 | 144 | # SageMath parsed files 145 | *.sage.py 146 | 147 | # Environments 148 | .env 149 | .venv 150 | env/ 151 | venv/ 152 | ENV/ 153 | env.bak/ 154 | venv.bak/ 155 | 156 | # Spyder project settings 157 | .spyderproject 158 | .spyproject 159 | 160 | # Rope project settings 161 | .ropeproject 162 | 163 | # mkdocs documentation 164 | /site 165 | 166 | # mypy 167 | .mypy_cache/ 168 | .dmypy.json 169 | dmypy.json 170 | 171 | ### Python Patch ### 172 | .venv/ 173 | 174 | ### Python.VirtualEnv Stack ### 175 | # Virtualenv 176 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 177 | [Bb]in 178 | [Ii]nclude 179 | [Ll]ib 180 | [Ll]ib64 181 | [Ll]ocal 182 | [Ss]cripts 183 | pyvenv.cfg 184 | pip-selfcheck.json 185 | 186 | ### Windows ### 187 | # Windows thumbnail cache files 188 | Thumbs.db 189 | ehthumbs.db 190 | ehthumbs_vista.db 191 | 192 | # Dump file 193 | *.stackdump 194 | 195 | # Folder config file 196 | [Dd]esktop.ini 197 | 198 | # Recycle Bin used on file shares 199 | $RECYCLE.BIN/ 200 | 201 | # Windows Installer files 202 | *.cab 203 | *.msi 204 | *.msix 205 | *.msm 206 | *.msp 207 | 208 | # Windows shortcuts 209 | *.lnk 210 | 211 | .idea/ 212 | .vscode/ -------------------------------------------------------------------------------- /modules/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Attention(nn.Module): 7 | 8 | def __init__(self, input_size, hidden_size, num_classes): 9 | super(Attention, self).__init__() 10 | self.attention_cell = AttentionCell(input_size, hidden_size, num_classes) 11 | self.hidden_size = hidden_size 12 | self.num_classes = num_classes 13 | self.generator = nn.Linear(hidden_size, num_classes) 14 | 15 | def _char_to_onehot(self, input_char, onehot_dim=38): 16 | input_char = input_char.unsqueeze(1) 17 | batch_size = input_char.size(0) 18 | one_hot = torch.cuda.FloatTensor(batch_size, onehot_dim).zero_() 19 | one_hot = one_hot.scatter_(1, input_char, 1) 20 | return one_hot 21 | 22 | def forward(self, batch_H, text, is_train=True, batch_max_length=25): 23 | """ 24 | 
input: 25 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes] 26 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. 27 | output: probability distribution at each step [batch_size x num_steps x num_classes] 28 | """ 29 | batch_size = batch_H.size(0) 30 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. 31 | 32 | output_hiddens = torch.cuda.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0) 33 | hidden = (torch.cuda.FloatTensor(batch_size, self.hidden_size).fill_(0), 34 | torch.cuda.FloatTensor(batch_size, self.hidden_size).fill_(0)) 35 | 36 | if is_train: 37 | for i in range(num_steps): 38 | # one-hot vectors for a i-th char. in a batch 39 | char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes) 40 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) 41 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 42 | output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) 43 | probs = self.generator(output_hiddens) 44 | 45 | else: 46 | targets = torch.cuda.LongTensor(batch_size).fill_(0) # [GO] token 47 | probs = torch.cuda.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0) 48 | 49 | for i in range(num_steps): 50 | char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) 51 | hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots) 52 | probs_step = self.generator(hidden[0]) 53 | probs[:, i, :] = probs_step 54 | _, next_input = probs_step.max(1) 55 | targets = next_input 56 | 57 | return probs # batch_size x num_steps x num_classes 58 | 59 | 60 | class AttentionCell(nn.Module): 61 | 62 | def __init__(self, input_size, hidden_size, num_embeddings): 63 | super(AttentionCell, self).__init__() 64 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 65 | self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias 66 | self.score = nn.Linear(hidden_size, 1, bias=False) 67 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 68 | self.hidden_size = hidden_size 69 | 70 | def forward(self, prev_hidden, batch_H, char_onehots): 71 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 72 | batch_H_proj = self.i2h(batch_H) 73 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 74 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 75 | 76 | alpha = F.softmax(e, dim=1) 77 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel 78 | concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) 79 | cur_hidden = self.rnn(concat_context, prev_hidden) 80 | return cur_hidden, alpha 81 | -------------------------------------------------------------------------------- /modules/resnet_aster.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | import sys 6 | import math 7 | 8 | # from config import get_args 9 | # global_args = get_args(sys.argv[1:]) 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | """3x3 convolution with padding""" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | 18 | def conv1x1(in_planes, out_planes, stride=1): 19 | """1x1 
convolution""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 21 | 22 | 23 | def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000): 24 | # [n_position] 25 | positions = torch.arange(0, n_position)#.cuda() 26 | # [feat_dim] 27 | dim_range = torch.arange(0, feat_dim)#.cuda() 28 | dim_range = torch.pow(wave_length, 2 * (dim_range // 2) / feat_dim) 29 | # [n_position, feat_dim] 30 | angles = positions.unsqueeze(1) / dim_range.unsqueeze(0) 31 | angles = angles.float() 32 | angles[:, 0::2] = torch.sin(angles[:, 0::2]) 33 | angles[:, 1::2] = torch.cos(angles[:, 1::2]) 34 | return angles 35 | 36 | 37 | class AsterBlock(nn.Module): 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None): 40 | super(AsterBlock, self).__init__() 41 | self.conv1 = conv1x1(inplanes, planes, stride) 42 | self.bn1 = nn.BatchNorm2d(planes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.conv2 = conv3x3(planes, planes) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | residual = x 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | 57 | if self.downsample is not None: 58 | residual = self.downsample(x) 59 | out += residual 60 | out = self.relu(out) 61 | return out 62 | 63 | 64 | class ResNet_ASTER(nn.Module): 65 | """For aster or crnn 66 | borrowed from: https://github.com/ayumiymk/aster.pytorch 67 | """ 68 | def __init__(self, in_channels=1, out_channel=512, n_group=1): 69 | super(ResNet_ASTER, self).__init__() 70 | self.n_group = n_group 71 | 72 | in_channels = in_channels 73 | self.layer0 = nn.Sequential( 74 | nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=1, padding=1, bias=False), 75 | nn.BatchNorm2d(32), 76 | nn.ReLU(inplace=True)) 77 | 78 | self.inplanes = 32 79 | self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50] 80 | self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25] 81 | self.layer3 = self._make_layer(128, 6, [2, 2]) # [4, 25] 82 | self.layer4 = self._make_layer(256, 6, [1, ]) # [2, 25] 83 | self.layer5 = self._make_layer(out_channel, 3, [1, 1]) # [1, 25] 84 | 85 | for m in self.modules(): 86 | if isinstance(m, nn.Conv2d): 87 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 88 | elif isinstance(m, nn.BatchNorm2d): 89 | nn.init.constant_(m.weight, 1) 90 | nn.init.constant_(m.bias, 0) 91 | 92 | def _make_layer(self, planes, blocks, stride): 93 | downsample = None 94 | if stride != [1, 1] or self.inplanes != planes: 95 | downsample = nn.Sequential( 96 | conv1x1(self.inplanes, planes, stride), 97 | nn.BatchNorm2d(planes)) 98 | 99 | layers = [] 100 | layers.append(AsterBlock(self.inplanes, planes, stride, downsample)) 101 | self.inplanes = planes 102 | for _ in range(1, blocks): 103 | layers.append(AsterBlock(self.inplanes, planes)) 104 | return nn.Sequential(*layers) 105 | 106 | def forward(self, x): 107 | 108 | x0 = self.layer0(x) 109 | x1 = self.layer1(x0) 110 | x2 = self.layer2(x1) 111 | x3 = self.layer3(x2) 112 | x4 = self.layer4(x3) 113 | x5 = self.layer5(x4) 114 | 115 | return x5 116 | 117 | 118 | def numel(model): 119 | return sum(p.numel() for p in model.parameters()) 120 | 121 | if __name__ == "__main__": 122 | x = torch.randn(3, 1, 64, 256) 123 | net = ResNet_ASTER() 124 | encoder_feat = net(x) 125 | print(encoder_feat.size()) # 3*512*h/4*w/4 126 | 127 | num_params = numel(net) 128 | print(f'Number of parameters: {num_params}') 
-------------------------------------------------------------------------------- /modules/optimizer/ranger.py: -------------------------------------------------------------------------------- 1 | #Ranger deep learning optimizer - RAdam + Lookahead combined. 2 | #https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 3 | 4 | import math 5 | import torch 6 | from torch.optim.optimizer import Optimizer, required 7 | import itertools as it 8 | #from torch.optim import Optimizer 9 | #credit - Lookahead implementation from LonePatient - https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py 10 | #credit2 - RAdam code by https://github.com/LiyuanLucasLiu/RAdam/blob/master/radam.py 11 | #changes 8/31/19 - fix references to *self*.N_sma_threshold; 12 | #changed eps to 1e-5 as better default than 1e-8. 13 | 14 | class Ranger(Optimizer): 15 | 16 | def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95,0.999), eps=1e-5, weight_decay=0): 17 | #parameter checks 18 | if not 0.0 <= alpha <= 1.0: 19 | raise ValueError(f'Invalid slow update rate: {alpha}') 20 | if not 1 <= k: 21 | raise ValueError(f'Invalid lookahead steps: {k}') 22 | if not lr > 0: 23 | raise ValueError(f'Invalid Learning Rate: {lr}') 24 | if not eps > 0: 25 | raise ValueError(f'Invalid eps: {eps}') 26 | 27 | #parameter comments: 28 | # beta1 (momentum) of .95 seems to work better than .90... 29 | #N_sma_threshold of 5 seems better in testing than 4. 30 | #In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 31 | 32 | #prep defaults and init torch.optim base 33 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 34 | super().__init__(params,defaults) 35 | 36 | #adjustable threshold 37 | self.N_sma_threshhold = N_sma_threshhold 38 | 39 | #now we can get to work... 40 | for group in self.param_groups: 41 | group["step_counter"] = 0 42 | #print("group step counter init") 43 | 44 | #look ahead params 45 | self.alpha = alpha 46 | self.k = k 47 | 48 | #radam buffer for state 49 | self.radam_buffer = [[None,None,None] for ind in range(10)] 50 | 51 | #lookahead weights 52 | self.slow_weights = [[p.clone().detach() for p in group['params']] 53 | for group in self.param_groups] 54 | 55 | #don't use grad for lookahead weights 56 | for w in it.chain(*self.slow_weights): 57 | w.requires_grad = False 58 | 59 | def __setstate__(self, state): 60 | print("set state called") 61 | super(Ranger, self).__setstate__(state) 62 | 63 | 64 | def step(self, closure=None): 65 | loss = None 66 | #note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure. 67 | #Uncomment if you need to use the actual closure... 
68 | 69 | #if closure is not None: 70 | #loss = closure() 71 | 72 | #------------ radam 73 | for group in self.param_groups: 74 | 75 | for p in group['params']: 76 | if p.grad is None: 77 | continue 78 | grad = p.grad.data.float() 79 | if grad.is_sparse: 80 | raise RuntimeError('RAdam does not support sparse gradients') 81 | 82 | p_data_fp32 = p.data.float() 83 | 84 | state = self.state[p] 85 | 86 | if len(state) == 0: 87 | state['step'] = 0 88 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 89 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 90 | else: 91 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 92 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 93 | 94 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 95 | beta1, beta2 = group['betas'] 96 | 97 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 98 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 99 | 100 | state['step'] += 1 101 | buffered = self.radam_buffer[int(state['step'] % 10)] 102 | if state['step'] == buffered[0]: 103 | N_sma, step_size = buffered[1], buffered[2] 104 | else: 105 | buffered[0] = state['step'] 106 | beta2_t = beta2 ** state['step'] 107 | N_sma_max = 2 / (1 - beta2) - 1 108 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 109 | buffered[1] = N_sma 110 | if N_sma > self.N_sma_threshhold: 111 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 112 | else: 113 | step_size = 1.0 / (1 - beta1 ** state['step']) 114 | buffered[2] = step_size 115 | 116 | if group['weight_decay'] != 0: 117 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 118 | 119 | if N_sma > self.N_sma_threshhold: 120 | denom = exp_avg_sq.sqrt().add_(group['eps']) 121 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 122 | else: 123 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 124 | 125 | p.data.copy_(p_data_fp32) 126 | 127 | 128 | #---------------- end radam step 129 | 130 | #look ahead tracking and updating if latest batch = k 131 | for group,slow_weights in zip(self.param_groups,self.slow_weights): 132 | group['step_counter'] += 1 133 | if group['step_counter'] % self.k != 0: 134 | continue 135 | for p,q in zip(group['params'],slow_weights): 136 | if p.grad is None: 137 | continue 138 | q.data.add_(self.alpha,p.data - q.data) 139 | p.data.copy_(q.data) 140 | 141 | 142 | 143 | return loss -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import torch.nn as nn 18 | 19 | from modules.transformation import TPS_SpatialTransformerNetwork 20 | from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor 21 | from modules.sequence_modeling import BidirectionalLSTM 22 | from modules.prediction import Attention 23 | from modules.resnet_aster import ResNet_ASTER 24 | 25 | from modules.bert import Bert_Ocr 26 | from modules.bert import Config 27 | 28 | from modules.SRN_modules import Transforme_Encoder, SRN_Decoder, Torch_transformer_encoder 29 | from modules.resnet_fpn import ResNet_FPN 30 | 31 | 32 | class Model(nn.Module): 33 | 34 | def __init__(self, opt): 35 | super(Model, self).__init__() 36 | self.opt = opt 37 | self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, 38 | 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction} 39 | 40 | """ Transformation """ 41 | if opt.Transformation == 'TPS': 42 | self.Transformation = TPS_SpatialTransformerNetwork( 43 | F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel) 44 | else: 45 | print('No Transformation module specified') 46 | 47 | """ FeatureExtraction """ 48 | if opt.FeatureExtraction == 'VGG': 49 | self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel) 50 | elif opt.FeatureExtraction == 'RCNN': 51 | self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel) 52 | elif opt.FeatureExtraction == 'ResNet': 53 | self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel) 54 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 55 | elif opt.FeatureExtraction == 'AsterRes': 56 | self.FeatureExtraction = ResNet_ASTER(opt.input_channel, opt.output_channel) 57 | elif opt.FeatureExtraction == 'ResnetFpn': 58 | self.FeatureExtraction = ResNet_FPN() 59 | else: 60 | raise Exception('No FeatureExtraction module specified') 61 | self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512 62 | 63 | 64 | """ Sequence modeling""" 65 | if opt.SequenceModeling == 'BiLSTM': 66 | self.SequenceModeling = nn.Sequential( 67 | BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size), 68 | BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size)) 69 | self.SequenceModeling_output = opt.hidden_size 70 | elif opt.SequenceModeling == 'Bert': 71 | cfg = Config() 72 | cfg.dim = opt.output_channel; cfg.dim_c = opt.output_channel # 降维减少计算量 73 | cfg.p_dim = opt.position_dim # 一张图片cnn编码之后的特征序列长度 74 | cfg.max_vocab_size = opt.batch_max_length + 1 # 一张图片中最多的文字个数, +1 for EOS 75 | cfg.len_alphabet = opt.alphabet_size # 文字的类别个数 76 | self.SequenceModeling = Bert_Ocr(cfg) 77 | elif opt.SequenceModeling == 'SRN': 78 | self.SequenceModeling = Transforme_Encoder(n_layers=2, n_position=opt.position_dim) 79 | # self.SequenceModeling = Torch_transformer_encoder(n_layers=2, n_position=opt.position_dim) 80 | self.SequenceModeling_output = 512 81 | else: 82 | print('No SequenceModeling module specified') 83 | self.SequenceModeling_output = self.FeatureExtraction_output 84 | 85 | """ Prediction """ 86 | if opt.Prediction == 'CTC': 87 | self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class) 88 | elif opt.Prediction == 'Attn': 89 | self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class) 90 | elif opt.Prediction == 'Bert_pred': 91 | pass 92 | elif opt.Prediction == 'SRN': 
93 | self.Prediction = SRN_Decoder(n_position=opt.position_dim, N_max_character=opt.batch_max_character + 1, n_class=opt.alphabet_size) 94 | else: 95 | raise Exception('Prediction is neither CTC or Attn') 96 | 97 | def forward(self, input, text, is_train=True): 98 | """ Transformation stage """ 99 | if not self.stages['Trans'] == "None": 100 | input = self.Transformation(input) 101 | 102 | 103 | """ Feature extraction stage """ 104 | visual_feature = self.FeatureExtraction(input) 105 | # if self.stages['Feat'] == 'AsterRes' or self.stages['Feat'] == 'ResnetFpn': 106 | if self.stages['Feat'] == 'AsterRes' or self.stages['Feat'] == 'ResnetFpn': 107 | b, c, h, w = visual_feature.shape 108 | visual_feature = visual_feature.permute(0, 1, 3, 2) 109 | visual_feature = visual_feature.contiguous().view(b, c, -1) 110 | visual_feature = visual_feature.permute(0, 2, 1) # batch, seq, feature 111 | else: 112 | visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h] 113 | visual_feature = visual_feature.squeeze(3) 114 | 115 | 116 | """ Sequence modeling stage """ 117 | if self.stages['Seq'] == 'BiLSTM': 118 | contextual_feature = self.SequenceModeling(visual_feature) 119 | elif self.stages['Seq'] == 'Bert': 120 | pad_mask = text 121 | contextual_feature = self.SequenceModeling(visual_feature, pad_mask) 122 | elif self.stages['Seq'] == 'SRN': 123 | contextual_feature = self.SequenceModeling(visual_feature, src_mask=None)[0] 124 | else: 125 | contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM 126 | 127 | 128 | """ Prediction stage """ 129 | if self.stages['Pred'] == 'CTC': 130 | prediction = self.Prediction(contextual_feature.contiguous()) 131 | elif self.stages['Pred'] == 'Bert_pred': 132 | prediction = contextual_feature 133 | elif self.stages['Pred'] == 'SRN': 134 | prediction = self.Prediction(contextual_feature) 135 | else: 136 | prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length) 137 | 138 | return prediction 139 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import string 2 | import argparse 3 | 4 | import torch 5 | import torch.backends.cudnn as cudnn 6 | import torch.utils.data 7 | 8 | from utils import CTCLabelConverter, AttnLabelConverter, TransformerConverter, SRNConverter 9 | from dataset import RawDataset, AlignCollate 10 | from model import Model 11 | 12 | 13 | def demo(opt): 14 | """ model configuration """ 15 | if 'CTC' in opt.Prediction: 16 | converter = CTCLabelConverter(opt.character) 17 | elif 'Bert' in opt.Prediction: 18 | converter = TransformerConverter(opt.character, opt.batch_max_length) 19 | elif 'SRN' in opt.Prediction: 20 | converter = SRNConverter(opt.character, PAD=36) 21 | else: 22 | converter = AttnLabelConverter(opt.character) 23 | opt.num_class = len(converter.character) 24 | opt.alphabet_size = len(opt.character) + 2 # +2 for [UNK]+[EOS] 25 | 26 | if opt.rgb: 27 | opt.input_channel = 3 28 | model = Model(opt) 29 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 30 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 31 | opt.SequenceModeling, opt.Prediction) 32 | 33 | model = torch.nn.DataParallel(model) 34 | if torch.cuda.is_available(): 35 | model = model.cuda() 36 | 37 | 
# load model 38 | print('loading pretrained model from %s' % opt.saved_model) 39 | model.load_state_dict(torch.load(opt.saved_model)) 40 | 41 | # prepare data. two demo images from https://github.com/bgshih/crnn#run-demo 42 | AlignCollate_demo = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 43 | demo_data = RawDataset(root=opt.image_folder, opt=opt) # use RawDataset 44 | demo_loader = torch.utils.data.DataLoader( 45 | demo_data, batch_size=opt.batch_size, 46 | shuffle=False, 47 | num_workers=int(opt.workers), 48 | collate_fn=AlignCollate_demo, pin_memory=True) 49 | 50 | # predict 51 | model.eval() 52 | for image_tensors, image_path_list in demo_loader: 53 | batch_size = image_tensors.size(0) 54 | with torch.no_grad(): 55 | image = image_tensors.cuda() 56 | # For max length prediction 57 | length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size) 58 | text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0) 59 | 60 | if 'CTC' in opt.Prediction: 61 | preds = model(image, text_for_pred).log_softmax(2) 62 | 63 | # Select max probabilty (greedy decoding) then decode index to character 64 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 65 | _, preds_index = preds.permute(1, 0, 2).max(2) 66 | preds_index = preds_index.transpose(1, 0).contiguous().view(-1) 67 | preds_str = converter.decode(preds_index.data, preds_size.data) 68 | 69 | elif 'Bert' in opt.Prediction: 70 | with torch.no_grad(): 71 | pad_mask = None 72 | preds = model(image, pad_mask) 73 | 74 | # select max probabilty (greedy decoding) then decode index to character 75 | _, preds_index = preds[1].max(2) 76 | length_for_pred = torch.cuda.IntTensor([preds_index.size(-1)] * batch_size) 77 | preds_str = converter.decode(preds_index, length_for_pred) 78 | 79 | elif 'SRN' in opt.Prediction: 80 | with torch.no_grad(): 81 | preds = model(image, None) 82 | 83 | # select max probabilty (greedy decoding) then decode index to character 84 | _, preds_index = preds[2].max(2) 85 | length_for_pred = torch.cuda.IntTensor([preds_index.size(-1)] * batch_size) 86 | preds_str = converter.decode(preds_index, length_for_pred) 87 | 88 | else: 89 | preds = model(image, text_for_pred, is_train=False) 90 | 91 | # select max probabilty (greedy decoding) then decode index to character 92 | _, preds_index = preds.max(2) 93 | preds_str = converter.decode(preds_index, length_for_pred) 94 | 95 | print('-' * 80) 96 | print('image_path\tpredicted_labels') 97 | print('-' * 80) 98 | for img_name, pred in zip(image_path_list, preds_str): 99 | if 'Attn' in opt.Prediction: 100 | pred = pred[:pred.find('[s]')] # prune after "end of sentence" token ([s]) 101 | 102 | print(f'{img_name}\t{pred}') 103 | 104 | 105 | if __name__ == '__main__': 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument('--image_folder', default='demo_image/', help='path to image_folder which contains text images') 108 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 109 | parser.add_argument('--batch_size', type=int, default=64, help='input batch size') 110 | parser.add_argument('--saved_model', default='./saved_models/None-ResNet-SRN-SRN-Seed666/iter_300000.pth', help="path to saved_model to evaluation") 111 | """ Data processing """ 112 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 113 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 114 | parser.add_argument('--imgW', 
type=int, default=100, help='the width of the input image') 115 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 116 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz$', help='character label') 117 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 118 | parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') 119 | """ Model Architecture """ 120 | parser.add_argument('--Transformation', type=str, default='None', help='Transformation stage. None|TPS') 121 | parser.add_argument('--FeatureExtraction', type=str, default='ResNet', help='FeatureExtraction stage. VGG|RCNN|ResNet|AsterRes') 122 | parser.add_argument('--SequenceModeling', type=str, default='SRN', help='SequenceModeling stage. None|BiLSTM|Bert|SRN') 123 | parser.add_argument('--Prediction', type=str, default='SRN', help='Prediction stage. CTC|Attn|Bert_pred|SRN') 124 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 125 | parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') 126 | parser.add_argument('--output_channel', type=int, default=512, 127 | help='the number of output channel of Feature extractor') 128 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 129 | parser.add_argument('--position_dim', type=int, default=26, help='the length sequence out from cnn encoder,resnet:65;resnetfpn:256') 130 | opt = parser.parse_args() 131 | 132 | """ vocab / character number configuration """ 133 | if opt.sensitive: 134 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 
135 | 136 | cudnn.benchmark = True 137 | cudnn.deterministic = True 138 | opt.num_gpu = torch.cuda.device_count() 139 | 140 | demo(opt) 141 | -------------------------------------------------------------------------------- /modules/transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class TPS_SpatialTransformerNetwork(nn.Module): 8 | """ Rectification Network of RARE, namely TPS based STN """ 9 | 10 | def __init__(self, F, I_size, I_r_size, I_channel_num=1): 11 | """ Based on RARE TPS 12 | input: 13 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 14 | I_size : (height, width) of the input image I 15 | I_r_size : (height, width) of the rectified image I_r 16 | I_channel_num : the number of channels of the input image I 17 | output: 18 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 19 | """ 20 | super(TPS_SpatialTransformerNetwork, self).__init__() 21 | self.F = F 22 | self.I_size = I_size 23 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 24 | self.I_channel_num = I_channel_num 25 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 26 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 27 | 28 | def forward(self, batch_I): 29 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 30 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 31 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) 32 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') 33 | 34 | return batch_I_r 35 | 36 | 37 | class LocalizationNetwork(nn.Module): 38 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 39 | 40 | def __init__(self, F, I_channel_num): 41 | super(LocalizationNetwork, self).__init__() 42 | self.F = F 43 | self.I_channel_num = I_channel_num 44 | self.conv = nn.Sequential( 45 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, 46 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True), 47 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 48 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), 49 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 50 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), 51 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 52 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), 53 | nn.AdaptiveAvgPool2d(1) # batch_size x 512 54 | ) 55 | 56 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 57 | self.localization_fc2 = nn.Linear(256, self.F * 2) 58 | 59 | # Init fc2 in LocalizationNetwork 60 | self.localization_fc2.weight.data.fill_(0) 61 | """ see RARE paper Fig. 
6 (a) """ 62 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 63 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 64 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 65 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 66 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 67 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 68 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) 69 | 70 | def forward(self, batch_I): 71 | """ 72 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 73 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 74 | """ 75 | batch_size = batch_I.size(0) 76 | features = self.conv(batch_I).view(batch_size, -1) 77 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) 78 | return batch_C_prime 79 | 80 | 81 | class GridGenerator(nn.Module): 82 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 83 | 84 | def __init__(self, F, I_r_size): 85 | """ Generate P_hat and inv_delta_C for later """ 86 | super(GridGenerator, self).__init__() 87 | self.eps = 1e-6 88 | self.I_r_height, self.I_r_width = I_r_size 89 | self.F = F 90 | self.C = self._build_C(self.F) # F x 2 91 | self.P = self._build_P(self.I_r_width, self.I_r_height) 92 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 93 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 94 | 95 | def _build_C(self, F): 96 | """ Return coordinates of fiducial points in I_r; C """ 97 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 98 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 99 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 100 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 101 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 102 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 103 | return C # F x 2 104 | 105 | def _build_inv_delta_C(self, F, C): 106 | """ Return inv_delta_C which is needed to calculate T """ 107 | hat_C = np.zeros((F, F), dtype=float) # F x F 108 | for i in range(0, F): 109 | for j in range(i, F): 110 | r = np.linalg.norm(C[i] - C[j]) 111 | hat_C[i, j] = r 112 | hat_C[j, i] = r 113 | np.fill_diagonal(hat_C, 1) 114 | hat_C = (hat_C ** 2) * np.log(hat_C) 115 | # print(C.shape, hat_C.shape) 116 | delta_C = np.concatenate( # F+3 x F+3 117 | [ 118 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 119 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 120 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 121 | ], 122 | axis=0 123 | ) 124 | inv_delta_C = np.linalg.inv(delta_C) 125 | return inv_delta_C # F+3 x F+3 126 | 127 | def _build_P(self, I_r_width, I_r_height): 128 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width 129 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height 130 | P = np.stack( # self.I_r_width x self.I_r_height x 2 131 | np.meshgrid(I_r_grid_x, I_r_grid_y), 132 | axis=2 133 | ) 134 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 135 | 136 | def _build_P_hat(self, F, C, P): 137 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 138 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n 
x F x 2 139 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 140 | P_diff = P_tile - C_tile # n x F x 2 141 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 142 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 143 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 144 | return P_hat # n x F+3 145 | 146 | def build_P_prime(self, batch_C_prime): 147 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 148 | batch_size = batch_C_prime.size(0) 149 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 150 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 151 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 152 | batch_size, 3, 2).float().cuda()), dim=1) # batch_size x F+3 x 2 153 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 154 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 155 | return batch_P_prime # batch_size x n x 2 156 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class CTCLabelConverter(object): 5 | """ Convert between text-label and text-index """ 6 | 7 | def __init__(self, character): 8 | # character (str): set of the possible characters. 9 | dict_character = list(character) 10 | 11 | self.dict = {} 12 | for i, char in enumerate(dict_character): 13 | # NOTE: 0 is reserved for 'blank' token required by CTCLoss 14 | self.dict[char] = i + 1 15 | 16 | self.character = ['[blank]'] + dict_character # dummy '[blank]' token for CTCLoss (index 0) 17 | 18 | def encode(self, text): 19 | """convert text-label into text-index. 20 | input: 21 | text: text labels of each image. [batch_size] 22 | 23 | output: 24 | text: concatenated text index for CTCLoss. 25 | [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] 26 | length: length of each text. [batch_size] 27 | """ 28 | length = [len(s) for s in text] 29 | text = ''.join(text) 30 | text = [self.dict[char] for char in text] 31 | 32 | return (torch.IntTensor(text), torch.IntTensor(length)) 33 | 34 | def decode(self, text_index, length): 35 | """ convert text-index into text-label. """ 36 | texts = [] 37 | index = 0 38 | for l in length: 39 | t = text_index[index:index + l] 40 | 41 | char_list = [] 42 | for i in range(l): 43 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. 44 | char_list.append(self.character[t[i]]) 45 | text = ''.join(char_list) 46 | 47 | texts.append(text) 48 | index += l 49 | return texts 50 | 51 | 52 | class AttnLabelConverter(object): 53 | """ Convert between text-label and text-index """ 54 | 55 | def __init__(self, character): 56 | # character (str): set of the possible characters. 57 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. 58 | list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] 59 | list_character = list(character) 60 | self.character = list_token + list_character 61 | 62 | self.dict = {} 63 | for i, char in enumerate(self.character): 64 | # print(i, char) 65 | self.dict[char] = i 66 | 67 | def encode(self, text, batch_max_length=25): 68 | """ convert text-label into text-index. 69 | input: 70 | text: text labels of each image. [batch_size] 71 | batch_max_length: max length of text label in the batch. 
25 by default 72 | 73 | output: 74 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. 75 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. 76 | length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] 77 | """ 78 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. 79 | # batch_max_length = max(length) # this is not allowed for multi-gpu setting 80 | batch_max_length += 1 81 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. 82 | batch_text = torch.cuda.LongTensor(len(text), batch_max_length + 1).fill_(0) 83 | for i, t in enumerate(text): 84 | text = list(t) 85 | text.append('[s]') 86 | text = [self.dict[char] for char in text] 87 | batch_text[i][1:1 + len(text)] = torch.cuda.LongTensor(text) # batch_text[:, 0] = [GO] token 88 | return (batch_text, torch.cuda.IntTensor(length)) 89 | 90 | def decode(self, text_index, length): 91 | """ convert text-index into text-label. """ 92 | texts = [] 93 | for index, l in enumerate(length): 94 | text = ''.join([self.character[i] for i in text_index[index, :]]) 95 | texts.append(text) 96 | return texts 97 | 98 | 99 | class TransformerConverter(object): 100 | """ Convert between text-label and text-index """ 101 | 102 | def __init__(self, character, batch_max_length=94): 103 | # character (str): set of the possible characters. 104 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. 105 | list_token = ['❶', '❷'] # ['[s]','[UNK]','[PAD]','[GO]'] 106 | list_character = list(character) 107 | self.character = list_token + list_character 108 | self.batch_max_length = batch_max_length 109 | 110 | self.dict = {} 111 | for i, char in enumerate(self.character): 112 | # print(i, char) 113 | self.dict[char] = i 114 | 115 | def encode(self, text): 116 | """ convert text-label into text-index. 117 | input: 118 | text: text labels of each image. [batch_size] 119 | batch_max_length: max length of text label in the batch. 25 by default 120 | 121 | output: 122 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. 123 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. 124 | length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] 125 | """ 126 | length = [ ] # +1 for [EOS] at end of sentence. 127 | 128 | # additional +1 for [EOS] at last step. batch_text is padded with [UNK] token 129 | batch_text = torch.cuda.LongTensor(len(text), self.batch_max_length + 1).fill_(0) 130 | for i, t in enumerate(text): 131 | per_text = list(t.replace('\u3000', '')) 132 | per_text.append('❷') 133 | length.append(len(per_text)) 134 | per_text = [self.dict[char] for char in per_text] 135 | 136 | batch_text[i][:len(per_text)] = torch.cuda.LongTensor(per_text) # batch_text[:, 0] = [GO] token 137 | return (batch_text, torch.cuda.IntTensor(length)) 138 | 139 | def decode(self, text_index, length): 140 | """ convert text-index into text-label. 
""" 141 | texts = [] 142 | for index, l in enumerate(length): 143 | text = ''.join([self.character[i] for i in text_index[index, :]]) 144 | idx = text.find('❷') 145 | texts.append(text[:idx]) 146 | return texts 147 | 148 | 149 | class SRNConverter(object): 150 | """ Convert between text-label and text-index """ 151 | 152 | def __init__(self, character, PAD=36): 153 | # character (str): set of the possible characters. 154 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. 155 | # list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] 156 | list_character = list(character) 157 | self.character = list_character 158 | self.PAD = PAD 159 | 160 | self.dict = {} 161 | for i, char in enumerate(self.character): 162 | # print(i, char) 163 | self.dict[char] = i 164 | 165 | def encode(self, text, batch_max_length=25): 166 | """ convert text-label into text-index. 167 | input: 168 | text: text labels of each image. [batch_size] 169 | batch_max_length: max length of text label in the batch. 25 by default 170 | 171 | output: 172 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. 173 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. 174 | length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] 175 | """ 176 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. 177 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. 178 | batch_text = torch.cuda.LongTensor(len(text), batch_max_length + 1).fill_(self.PAD) 179 | # mask_text = torch.cuda.LongTensor(len(text), batch_max_length).fill_(0) 180 | for i, t in enumerate(text): 181 | t = list(t + self.character[-2]) 182 | text = [self.dict[char] for char in t] 183 | # t_mask = [1 for i in range(len(text) + 1)] 184 | batch_text[i][0:len(text)] = torch.cuda.LongTensor(text) # batch_text[:, len_text+1] = [EOS] token 185 | # mask_text[i][0:len(text)+1] = torch.cuda.LongTensor(t_mask) 186 | return (batch_text, torch.cuda.IntTensor(length)) 187 | 188 | def decode(self, text_index, length): 189 | """ convert text-index into text-label. """ 190 | texts = [] 191 | for index, l in enumerate(length): 192 | text = ''.join([self.character[i] for i in text_index[index, :]]) 193 | idx = text.find('$') 194 | texts.append(text[:idx]) 195 | return texts 196 | 197 | 198 | class Averager(object): 199 | """Compute average for torch.Tensor, used for loss average.""" 200 | 201 | def __init__(self): 202 | self.reset() 203 | 204 | def add(self, v): 205 | count = v.data.numel() 206 | v = v.data.sum() 207 | self.n_count += count 208 | self.sum += v 209 | 210 | def reset(self): 211 | self.n_count = 0 212 | self.sum = 0 213 | 214 | def val(self): 215 | res = 0 216 | if self.n_count != 0: 217 | res = self.sum / float(self.n_count) 218 | return res 219 | -------------------------------------------------------------------------------- /modules/bert.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 Dong-Hyun Lee, Kakao Brain. 
2 | # (Strongly inspired by original Google BERT code and Hugging Face's code) 3 | 4 | """ Transformer Model Classes & Config Class """ 5 | 6 | import math 7 | import json 8 | from typing import NamedTuple 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def split_last(x, shape): 17 | "split the last dimension to given shape" 18 | shape = list(shape) 19 | assert shape.count(-1) <= 1 20 | if -1 in shape: 21 | shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape)) 22 | return x.view(*x.size()[:-1], *shape) 23 | 24 | 25 | def merge_last(x, n_dims): 26 | "merge the last n_dims to a dimension" 27 | s = x.size() 28 | assert n_dims > 1 and n_dims < len(s) 29 | return x.view(*s[:-n_dims], -1) 30 | 31 | 32 | def gelu(x): 33 | "Implementation of the gelu activation function by Hugging Face" 34 | return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) 35 | 36 | 37 | class LayerNorm(nn.Module): 38 | "A layernorm module in the TF style (epsilon inside the square root)." 39 | def __init__(self, cfg, variance_epsilon=1e-12): 40 | super().__init__() 41 | self.gamma = nn.Parameter(torch.ones(cfg.dim)) 42 | self.beta = nn.Parameter(torch.zeros(cfg.dim)) 43 | self.variance_epsilon = variance_epsilon 44 | 45 | def forward(self, x): 46 | u = x.mean(-1, keepdim=True) 47 | s = (x - u).pow(2).mean(-1, keepdim=True) 48 | x = (x - u) / torch.sqrt(s + self.variance_epsilon) 49 | return self.gamma * x + self.beta 50 | 51 | 52 | class Embeddings(nn.Module): 53 | "The embedding module from word, position and token_type embeddings." 54 | def __init__(self, cfg): 55 | super().__init__() 56 | self.pos_embed = nn.Embedding(cfg.p_dim, cfg.dim) # position embedding 57 | self.norm = LayerNorm(cfg) 58 | self.drop = nn.Dropout(cfg.p_drop_hidden) 59 | 60 | def forward(self, x): 61 | seq_len = x.size(1) 62 | pos = torch.arange(seq_len, dtype=torch.long, device=x.device) 63 | pos = pos.unsqueeze(0).expand(x.size(0), -1) # (S,) -> (B, S) 64 | 65 | e = x + self.pos_embed(pos) 66 | return self.drop(self.norm(e)) 67 | 68 | 69 | class MultiHeadedSelfAttention(nn.Module): 70 | """ Multi-Headed Dot Product Attention """ 71 | def __init__(self, cfg): 72 | super().__init__() 73 | self.proj_q = nn.Linear(cfg.dim, cfg.dim) 74 | self.proj_k = nn.Linear(cfg.dim, cfg.dim) 75 | self.proj_v = nn.Linear(cfg.dim, cfg.dim) 76 | self.drop = nn.Dropout(cfg.p_drop_attn) 77 | self.scores = None # for visualization 78 | self.n_heads = cfg.n_heads 79 | 80 | def forward(self, x, mask): 81 | """ 82 | x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim)) 83 | mask : (B(batch_size) x S(seq_len)) 84 | * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W 85 | """ 86 | # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W) 87 | q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x) 88 | q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) 89 | for x in [q, k, v]) 90 | # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S) 91 | scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1)) 92 | if mask is not None: 93 | mask = mask[:, None, None, :].float() 94 | scores -= 10000.0 * (1.0 - mask) 95 | scores = self.drop(F.softmax(scores, dim=-1)) 96 | # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W) 97 | h = (scores @ v).transpose(1, 2).contiguous() 98 | # -merge-> (B, S, D) 99 | h = merge_last(h, 2) 100 | self.scores = scores 101 | return h 102 | 103 | 104 | class PositionWiseFeedForward(nn.Module): 105 | """ 
FeedForward Neural Networks for each position """ 106 | def __init__(self, cfg): 107 | super().__init__() 108 | self.fc1 = nn.Linear(cfg.dim, cfg.dim_ff) 109 | self.fc2 = nn.Linear(cfg.dim_ff, cfg.dim) 110 | #self.activ = lambda x: activ_fn(cfg.activ_fn, x) 111 | 112 | def forward(self, x): 113 | # (B, S, D) -> (B, S, D_ff) -> (B, S, D) 114 | return self.fc2(gelu(self.fc1(x))) 115 | 116 | 117 | class Block(nn.Module): 118 | """ Transformer Block """ 119 | def __init__(self, cfg): 120 | super().__init__() 121 | self.attn = MultiHeadedSelfAttention(cfg) 122 | self.proj = nn.Linear(cfg.dim, cfg.dim) 123 | self.norm1 = LayerNorm(cfg) 124 | self.pwff = PositionWiseFeedForward(cfg) 125 | self.norm2 = LayerNorm(cfg) 126 | self.drop = nn.Dropout(cfg.p_drop_hidden) 127 | 128 | def forward(self, x, mask): 129 | h = self.attn(x, mask) 130 | h = self.norm1(x + self.drop(self.proj(h))) 131 | h = self.norm2(h + self.drop(self.pwff(h))) 132 | return h 133 | 134 | 135 | class Transformer(nn.Module): 136 | """ Transformer with Self-Attentive Blocks""" 137 | def __init__(self, cfg, n_layers): 138 | super().__init__() 139 | self.embed = Embeddings(cfg) 140 | self.blocks = nn.ModuleList([Block(cfg) for _ in range(n_layers)]) 141 | 142 | def forward(self, x, mask): 143 | h = self.embed(x) 144 | for block in self.blocks: 145 | h = block(h, mask) 146 | return h 147 | 148 | 149 | class Parallel_Attention(nn.Module): 150 | ''' the Parallel Attention Module for 2D attention 151 | reference the origin paper: https://arxiv.org/abs/1906.05708 152 | ''' 153 | def __init__(self, cfg): 154 | super().__init__() 155 | self.atten_w1 = nn.Linear(cfg.dim_c, cfg.dim_c) 156 | self.atten_w2 = nn.Linear(cfg.dim_c, cfg.max_vocab_size) 157 | self.activ_fn = nn.Tanh() 158 | self.soft = nn.Softmax(dim=1) 159 | self.drop = nn.Dropout(0.1) 160 | 161 | def forward(self, origin_I, bert_out, mask=None): 162 | bert_out = self.activ_fn(self.drop(self.atten_w1(bert_out))) 163 | atten_w = self.soft(self.atten_w2(bert_out)) # b*200*94 164 | x = torch.bmm(origin_I.transpose(1,2), atten_w) # b*512*94 165 | return x 166 | 167 | 168 | class MultiHeadAttention(nn.Module): 169 | ''' Multi-Head Attention module ''' 170 | 171 | def __init__(self, n_head=8, d_k=64, d_model=128, max_vocab_size=94, dropout=0.1): 172 | ''' d_k: the attention dim 173 | d_model: the encoder output feature 174 | max_vocab_size: the output maxium length of sequence 175 | ''' 176 | super(MultiHeadAttention, self).__init__() 177 | 178 | self.n_head, self.d_k = n_head, d_k 179 | self.temperature = np.power(d_k, 0.5) 180 | self.max_vocab_size = max_vocab_size 181 | 182 | self.w_encoder = nn.Linear(d_model, n_head * d_k) 183 | self.w_atten = nn.Linear(d_model, n_head * max_vocab_size) 184 | self.w_out = nn.Linear(n_head * d_k, d_model) 185 | self.activ_fn = nn.Tanh() 186 | 187 | self.softmax = nn.Softmax(dim=1) # at the d_in dimension 188 | self.dropout = nn.Dropout(dropout) 189 | 190 | nn.init.normal_(self.w_encoder.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 191 | nn.init.normal_(self.w_atten.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 192 | nn.init.xavier_normal_(self.w_out.weight) 193 | 194 | 195 | def forward(self, encoder_feature, bert_out, mask=None): 196 | 197 | d_k, n_head, max_vocab_size = self.d_k, self.n_head, self.max_vocab_size 198 | 199 | sz_b, d_in, _ = encoder_feature.size() 200 | 201 | # 原始特征 202 | encoder_feature = encoder_feature.view(sz_b, d_in, n_head, d_k) 203 | encoder_feature = encoder_feature.permute(2, 0, 1, 3).contiguous().view(-1, d_in, 
d_k) # 32*200*64 204 | 205 | # 求解权值 206 | alpha = self.activ_fn(self.dropout(self.w_encoder(bert_out))) 207 | alpha = self.w_atten(alpha).view(sz_b, d_in, n_head, max_vocab_size) # 4*200*8*94 208 | alpha = alpha.permute(2, 0, 1, 3).contiguous().view(-1, d_in, max_vocab_size) # 32*200*94 209 | alpha = alpha / self.temperature 210 | alpha = self.dropout(self.softmax(alpha)) # 32*200*94 211 | 212 | # 输出部分 213 | output = torch.bmm(encoder_feature.transpose(1,2), alpha) # 32*64*94 214 | output = output.view(n_head, sz_b, d_k, max_vocab_size) 215 | output = output.permute(1, 3, 0, 2).contiguous().view(sz_b, max_vocab_size, -1) # 4*94*512 216 | output = self.dropout(self.w_out(output)) 217 | output = output.transpose(1,2) 218 | 219 | return output 220 | 221 | 222 | class Two_Stage_Decoder(nn.Module): 223 | def __init__(self, cfg): 224 | super().__init__() 225 | self.out_w = nn.Linear(cfg.dim_c, cfg.len_alphabet) 226 | self.relation_attention = Transformer(cfg, cfg.decoder_atten_layers) 227 | self.out_w1 = nn.Linear(cfg.dim_c, cfg.len_alphabet) 228 | 229 | def forward(self, x): 230 | x1 = self.out_w(x) 231 | x2 = self.relation_attention(x, mask=None) 232 | x2 = self.out_w1(x2) # 两个分支的输出部分采用不同的网络 233 | 234 | return x1, x2 235 | 236 | 237 | class Bert_Ocr(nn.Module): 238 | def __init__(self, cfg): 239 | super().__init__() 240 | self.cfg = cfg 241 | self.transformer = Transformer(cfg, cfg.attention_layers) 242 | self.attention = Parallel_Attention(cfg) 243 | # self.attention = MultiHeadAttention(d_model=cfg.dim, max_vocab_size=cfg.max_vocab_size) 244 | self.decoder = Two_Stage_Decoder(cfg) 245 | 246 | def forward(self, encoder_feature, mask=None): 247 | bert_out = self.transformer(encoder_feature, mask) # 做一个self_attention//4*200*512 248 | glimpses = self.attention(encoder_feature, bert_out, mask) # 原始序列和目标序列的转化//4*512*94 249 | res = self.decoder(glimpses.transpose(1,2)) 250 | return res 251 | 252 | 253 | class Config(object): 254 | '''参数设置''' 255 | """ Relation Attention Module """ 256 | p_drop_attn = 0.1 257 | p_drop_hidden = 0.1 258 | dim = 512 # the encode output feature 259 | attention_layers = 2 # the layers of transformer 260 | n_heads = 8 261 | dim_ff = 1024 * 2 # 位置前向传播的隐含层维度 262 | 263 | ''' Parallel Attention Module ''' 264 | dim_c = dim 265 | max_vocab_size = 26 # 一张图片含有字符的最大长度 266 | 267 | """ Two-stage Decoder """ 268 | len_alphabet = 39 # 字符类别数量 269 | decoder_atten_layers = 2 270 | 271 | 272 | def numel(model): 273 | return sum(p.numel() for p in model.parameters()) 274 | 275 | 276 | if __name__ == '__main__': 277 | 278 | cfg = Config() 279 | mask = None 280 | x = torch.randn(4, 200, cfg.dim) 281 | net = Bert_Ocr(cfg) 282 | res1, res2 = net(x, mask) 283 | print(res1.shape, res2.shape) 284 | print('参数总量为:', numel(net)) 285 | -------------------------------------------------------------------------------- /modules/resnet_fpn.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Pytorch Faster R-CNN and FPN 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Zheqi He and Xinlei Chen, Yixiao Ge 5 | # https://github.com/yxgeee/pytorch-FPN/blob/master/lib/nets/resnet_v1.py 6 | # -------------------------------------------------------- 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import math 15 | import 
torch.utils.model_zoo as model_zoo 16 | 17 | 18 | __all__ = [ 19 | 'ResNet_FPN', 20 | 'ResNet', 21 | 'resnet18', 22 | 'resnet34', 23 | 'resnet50', 24 | 'resnet101', 25 | 'resnet152'] 26 | 27 | 28 | model_urls = { 29 | 'resnet18': 'https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth', 30 | 'resnet34': 'https://s3.amazonaws.com/pytorch/models/resnet34-333f7ec4.pth', 31 | 'resnet50': 'https://s3.amazonaws.com/pytorch/models/resnet50-19c8e357.pth', 32 | 'resnet101': 'https://s3.amazonaws.com/pytorch/models/resnet101-5d3b4d8f.pth', 33 | 'resnet152': 'https://s3.amazonaws.com/pytorch/models/resnet152-b121ed2d.pth', 34 | } 35 | 36 | 37 | def conv3x3(in_planes, out_planes, stride=1): 38 | "3x3 convolution with padding" 39 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 40 | padding=1, bias=False) 41 | 42 | 43 | class BasicBlock(nn.Module): 44 | expansion = 1 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None): 47 | super(BasicBlock, self).__init__() 48 | self.conv1 = conv3x3(inplanes, planes, stride) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = conv3x3(planes, planes) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | residual = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | residual = self.downsample(x) 68 | 69 | out += residual 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Module): 76 | expansion = 4 77 | 78 | def __init__(self, inplanes, planes, stride=1, downsample=None): 79 | super(Bottleneck, self).__init__() 80 | self.conv1 = nn.Conv2d( 81 | inplanes, 82 | planes, 83 | kernel_size=1, 84 | stride=stride, 85 | bias=False) # change 86 | self.bn1 = nn.BatchNorm2d(planes) 87 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 88 | padding=1, bias=False) 89 | self.bn2 = nn.BatchNorm2d(planes) 90 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 91 | self.bn3 = nn.BatchNorm2d(planes * 4) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.downsample = downsample 94 | self.stride = stride 95 | 96 | def forward(self, x): 97 | residual = x 98 | 99 | out = self.conv1(x) 100 | out = self.bn1(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv2(out) 104 | out = self.bn2(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv3(out) 108 | out = self.bn3(out) 109 | 110 | if self.downsample is not None: 111 | residual = self.downsample(x) 112 | 113 | out += residual 114 | out = self.relu(out) 115 | 116 | return out 117 | 118 | 119 | class BuildBlock(nn.Module): 120 | def __init__(self, planes=512): 121 | super(BuildBlock, self).__init__() 122 | 123 | self.planes = planes 124 | # Top-down layers, use nn.ConvTranspose2d to replace 125 | # nn.Conv2d+F.upsample? 
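# NOTE: rough shape walk-through for the FPN built below, assuming the Bottleneck
# backbones (resnet50/101/152) that ResNet_FPN actually instantiates, so c3/c4/c5
# carry 512/1024/2048 channels. toplayer1 is a 1x1 conv that reduces c5 to `planes`
# (512); latlayer1/latlayer2 are 1x1 convs that project c4 and c3 to the same width;
# _upsample_add bilinearly upsamples the coarser map to the lateral map's size and
# sums them; toplayer2/toplayer3 are 3x3 convs that smooth the merged maps. For the
# (2, 1, 64, 256) test input in __main__ below, c3/c4/c5 should be roughly 512x8x32,
# 1024x4x16 and 2048x2x8, and the returned p3 about 512x8x32 (stride 8 w.r.t. the input).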
126 | self.toplayer1 = nn.Conv2d( 127 | 2048, 128 | planes, 129 | kernel_size=1, 130 | stride=1, 131 | padding=0) # Reduce channels 132 | self.toplayer2 = nn.Conv2d( 133 | 512, planes, kernel_size=3, stride=1, padding=1) 134 | self.toplayer3 = nn.Conv2d( 135 | 512, planes, kernel_size=3, stride=1, padding=1) 136 | 137 | # Lateral layers 138 | self.latlayer1 = nn.Conv2d( 139 | 1024, planes, kernel_size=1, stride=1, padding=0) 140 | self.latlayer2 = nn.Conv2d( 141 | 512, planes, kernel_size=1, stride=1, padding=0) 142 | 143 | def _upsample_add(self, x, y): 144 | _, _, H, W = y.size() 145 | return F.upsample( 146 | x, 147 | size=( 148 | H, 149 | W), 150 | mode='bilinear', 151 | align_corners=True) + y 152 | 153 | def forward(self, c3, c4, c5): 154 | # Top-down 155 | p5 = self.toplayer1(c5) 156 | p4 = self._upsample_add(p5, self.latlayer1(c4)) 157 | p4 = self.toplayer2(p4) 158 | p3 = self._upsample_add(p4, self.latlayer2(c3)) 159 | p3 = self.toplayer3(p3) 160 | 161 | return p3, p4, p5 162 | 163 | 164 | class ResNet(nn.Module): 165 | def __init__(self, block, layers, num_classes=1000): 166 | self.inplanes = 64 167 | super(ResNet, self).__init__() 168 | # the symbol is referred to fots. 169 | # Conv1 /2 170 | self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, 171 | bias=False) 172 | self.bn1 = nn.BatchNorm2d(64) 173 | self.relu = nn.ReLU(inplace=True) 174 | # Pool1 /4 175 | # maxpool different from pytorch-resnet, to match tf-faster-rcnn 176 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 177 | self.layer1 = self._make_layer( 178 | block, 64, layers[0], stride=1) # Res2 /4 179 | self.layer2 = self._make_layer( 180 | block, 128, layers[1], stride=2) # Res3 /8 181 | self.layer3 = self._make_layer( 182 | block, 256, layers[2], stride=2) # Res4 /16 183 | # use stride 1 for the last conv4 layer (same as tf-faster-rcnn) 184 | self.layer4 = self._make_layer( 185 | block, 512, layers[3], stride=2) # Res5 /32 186 | 187 | for m in self.modules(): 188 | if isinstance(m, nn.Conv2d): 189 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 190 | m.weight.data.normal_(0, math.sqrt(2. / n)) 191 | elif isinstance(m, nn.BatchNorm2d): 192 | m.weight.data.fill_(1) 193 | m.bias.data.zero_() 194 | 195 | def _make_layer(self, block, planes, blocks, stride=1): 196 | downsample = None 197 | if stride != 1 or self.inplanes != planes * block.expansion: 198 | downsample = nn.Sequential( 199 | nn.Conv2d(self.inplanes, planes * block.expansion, 200 | kernel_size=1, stride=stride, bias=False), 201 | nn.BatchNorm2d(planes * block.expansion), 202 | ) 203 | 204 | layers = [] 205 | layers.append(block(self.inplanes, planes, stride, downsample)) 206 | self.inplanes = planes * block.expansion 207 | for i in range(1, blocks): 208 | layers.append(block(self.inplanes, planes)) 209 | 210 | return nn.Sequential(*layers) 211 | 212 | 213 | def resnet18(pretrained=False): 214 | """Constructs a ResNet-18 model. 215 | Args: 216 | pretrained (bool): If True, returns a model pre-trained on ImageNet 217 | """ 218 | model = ResNet(BasicBlock, [2, 2, 2, 2]) 219 | if pretrained: 220 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 221 | return model 222 | 223 | 224 | def resnet34(pretrained=False): 225 | """Constructs a ResNet-34 model. 
226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | """ 229 | model = ResNet(BasicBlock, [3, 4, 6, 3]) 230 | if pretrained: 231 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 232 | return model 233 | 234 | 235 | def resnet50(pretrained=False): 236 | """Constructs a ResNet-50 model. 237 | Args: 238 | pretrained (bool): If True, returns a model pre-trained on ImageNet 239 | """ 240 | model = ResNet(Bottleneck, [3, 4, 6, 3]) 241 | if pretrained: 242 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 243 | return model 244 | 245 | 246 | def resnet101(pretrained=False): 247 | """Constructs a ResNet-101 model. 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | """ 251 | model = ResNet(Bottleneck, [3, 4, 23, 3]) 252 | if pretrained: 253 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 254 | return model 255 | 256 | 257 | def resnet152(pretrained=False): 258 | """Constructs a ResNet-152 model. 259 | Args: 260 | pretrained (bool): If True, returns a model pre-trained on ImageNet 261 | """ 262 | model = ResNet(Bottleneck, [3, 8, 36, 3]) 263 | if pretrained: 264 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 265 | return model 266 | 267 | 268 | class ResNet_FPN(nn.Module): 269 | def __init__(self, num_layers=50): 270 | super(ResNet_FPN, self).__init__() 271 | self._num_layers = num_layers 272 | self._layers = {} 273 | 274 | self._init_head_tail() 275 | self.out_planes = self.fpn.planes 276 | 277 | def forward(self, x): 278 | c2 = self.head1(x) 279 | c3 = self.head2(c2) 280 | c4 = self.head3(c3) 281 | c5 = self.head4(c4) 282 | p3, p4, p5 = self.fpn( c3, c4, c5) 283 | # net_conv = [p2, p3, p4, p5] 284 | 285 | # return p2, [x, self.resnet.conv1(x), c2] 286 | return p3 287 | 288 | def _init_head_tail(self): 289 | # choose different blocks for different number of layers 290 | if self._num_layers == 50: 291 | self.resnet = resnet50() 292 | 293 | elif self._num_layers == 101: 294 | self.resnet = resnet101() 295 | 296 | elif self._num_layers == 152: 297 | self.resnet = resnet152() 298 | 299 | else: 300 | # other numbers are not supported 301 | raise NotImplementedError 302 | 303 | # Build Building Block for FPN 304 | self.fpn = BuildBlock() 305 | self.head1 = nn.Sequential( 306 | self.resnet.conv1, 307 | self.resnet.bn1, 308 | self.resnet.relu, 309 | self.resnet.maxpool, 310 | self.resnet.layer1) # /4 311 | self.head2 = nn.Sequential(self.resnet.layer2) # /8 312 | self.head3 = nn.Sequential(self.resnet.layer3) # /16 313 | self.head4 = nn.Sequential(self.resnet.layer4) # /32 314 | 315 | 316 | if __name__=='__main__': 317 | model = ResNet_FPN() 318 | 319 | x = torch.randn((2,1,64,256)) 320 | y = model(x) 321 | print(y.shape) -------------------------------------------------------------------------------- /modules/feature_extraction.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | # 2020-05-11 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class VGG_FeatureExtractor(nn.Module): 8 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ 9 | 10 | def __init__(self, input_channel, output_channel=512): 11 | super(VGG_FeatureExtractor, self).__init__() 12 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 13 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 14 | self.ConvNet = nn.Sequential( 15 | 
nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 16 | nn.MaxPool2d(2, 2), # 64x16x50 17 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), 18 | nn.MaxPool2d(2, 2), # 128x8x25 19 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25 20 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), 21 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 22 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), 23 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 24 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), 25 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), 26 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 27 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 28 | 29 | def forward(self, input): 30 | return self.ConvNet(input) 31 | 32 | 33 | class RCNN_FeatureExtractor(nn.Module): 34 | """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ 35 | 36 | def __init__(self, input_channel, output_channel=512): 37 | super(RCNN_FeatureExtractor, self).__init__() 38 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 39 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 40 | self.ConvNet = nn.Sequential( 41 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 42 | nn.MaxPool2d(2, 2), # 64 x 16 x 50 43 | GRCL(self.output_channel[0], self.output_channel[0], num_iteration=5, kernel_size=3, pad=1), 44 | nn.MaxPool2d(2, 2), # 64 x 8 x 25 45 | GRCL(self.output_channel[0], self.output_channel[1], num_iteration=5, kernel_size=3, pad=1), 46 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26 47 | GRCL(self.output_channel[1], self.output_channel[2], num_iteration=5, kernel_size=3, pad=1), 48 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27 49 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False), 50 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26 51 | 52 | def forward(self, input): 53 | return self.ConvNet(input) 54 | 55 | 56 | class ResNet_FeatureExtractor(nn.Module): 57 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 58 | 59 | def __init__(self, input_channel, output_channel=512): 60 | super(ResNet_FeatureExtractor, self).__init__() 61 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) 62 | 63 | def forward(self, input): 64 | return self.ConvNet(input) 65 | 66 | 67 | # For Gated RCNN 68 | class GRCL(nn.Module): 69 | 70 | def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad): 71 | super(GRCL, self).__init__() 72 | self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False) 73 | self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False) 74 | self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False) 75 | self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False) 76 | 77 | self.BN_x_init = nn.BatchNorm2d(output_channel) 78 | 79 | self.num_iteration = num_iteration 80 | self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)] 81 | self.GRCL = nn.Sequential(*self.GRCL) 82 | 83 | def forward(self, input): 84 | """ The input of GRCL is consistant 
over time t, which is denoted by u(0) 85 | thus wgf_u / wf_u is also consistant over time t. 86 | """ 87 | wgf_u = self.wgf_u(input) 88 | wf_u = self.wf_u(input) 89 | x = F.relu(self.BN_x_init(wf_u)) 90 | 91 | for i in range(self.num_iteration): 92 | x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x)) 93 | 94 | return x 95 | 96 | 97 | class GRCL_unit(nn.Module): 98 | 99 | def __init__(self, output_channel): 100 | super(GRCL_unit, self).__init__() 101 | self.BN_gfu = nn.BatchNorm2d(output_channel) 102 | self.BN_grx = nn.BatchNorm2d(output_channel) 103 | self.BN_fu = nn.BatchNorm2d(output_channel) 104 | self.BN_rx = nn.BatchNorm2d(output_channel) 105 | self.BN_Gx = nn.BatchNorm2d(output_channel) 106 | 107 | def forward(self, wgf_u, wgr_x, wf_u, wr_x): 108 | G_first_term = self.BN_gfu(wgf_u) 109 | G_second_term = self.BN_grx(wgr_x) 110 | G = F.sigmoid(G_first_term + G_second_term) 111 | 112 | x_first_term = self.BN_fu(wf_u) 113 | x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G) 114 | x = F.relu(x_first_term + x_second_term) 115 | 116 | return x 117 | 118 | 119 | class BasicBlock(nn.Module): 120 | expansion = 1 121 | 122 | def __init__(self, inplanes, planes, stride=1, downsample=None): 123 | super(BasicBlock, self).__init__() 124 | self.conv1 = self._conv3x3(inplanes, planes) 125 | self.bn1 = nn.BatchNorm2d(planes) 126 | self.conv2 = self._conv3x3(planes, planes) 127 | self.bn2 = nn.BatchNorm2d(planes) 128 | self.relu = nn.ReLU(inplace=True) 129 | self.downsample = downsample 130 | self.stride = stride 131 | 132 | def _conv3x3(self, in_planes, out_planes, stride=1): 133 | "3x3 convolution with padding" 134 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 135 | padding=1, bias=False) 136 | 137 | def forward(self, x): 138 | residual = x 139 | 140 | out = self.conv1(x) 141 | out = self.bn1(out) 142 | out = self.relu(out) 143 | 144 | out = self.conv2(out) 145 | out = self.bn2(out) 146 | 147 | if self.downsample is not None: 148 | residual = self.downsample(x) 149 | out += residual 150 | out = self.relu(out) 151 | 152 | return out 153 | 154 | 155 | class ResNet(nn.Module): 156 | 157 | def __init__(self, input_channel, output_channel, block, layers): 158 | super(ResNet, self).__init__() 159 | 160 | self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] 161 | 162 | self.inplanes = int(output_channel / 8) 163 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), 164 | kernel_size=3, stride=1, padding=1, bias=False) 165 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 166 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, 167 | kernel_size=3, stride=1, padding=1, bias=False) 168 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 169 | self.relu = nn.ReLU(inplace=True) 170 | 171 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 172 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 173 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ 174 | 0], kernel_size=3, stride=1, padding=1, bias=False) 175 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 176 | 177 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 178 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) 179 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ 180 | 1], kernel_size=3, stride=1, padding=1, bias=False) 181 | self.bn2 = 
nn.BatchNorm2d(self.output_channel_block[1]) 182 | 183 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 184 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) 185 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ 186 | 2], kernel_size=3, stride=1, padding=1, bias=False) 187 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 188 | 189 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) 190 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 191 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) 192 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 193 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 194 | 3], kernel_size=2, stride=1, padding=0, bias=False) 195 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 196 | 197 | def _make_layer(self, block, planes, blocks, stride=1): 198 | downsample = None 199 | if stride != 1 or self.inplanes != planes * block.expansion: 200 | downsample = nn.Sequential( 201 | nn.Conv2d(self.inplanes, planes * block.expansion, 202 | kernel_size=1, stride=stride, bias=False), 203 | nn.BatchNorm2d(planes * block.expansion), 204 | ) 205 | 206 | layers = [] 207 | layers.append(block(self.inplanes, planes, stride, downsample)) 208 | self.inplanes = planes * block.expansion 209 | for i in range(1, blocks): 210 | layers.append(block(self.inplanes, planes)) 211 | 212 | return nn.Sequential(*layers) 213 | 214 | def forward(self, x): 215 | x = self.conv0_1(x) 216 | x = self.bn0_1(x) 217 | x = self.relu(x) 218 | x = self.conv0_2(x) 219 | x = self.bn0_2(x) 220 | x = self.relu(x) 221 | 222 | x = self.maxpool1(x) 223 | x = self.layer1(x) 224 | x = self.conv1(x) 225 | x = self.bn1(x) 226 | x = self.relu(x) 227 | 228 | x = self.maxpool2(x) 229 | x = self.layer2(x) 230 | x = self.conv2(x) 231 | x = self.bn2(x) 232 | x = self.relu(x) 233 | 234 | x = self.maxpool3(x) 235 | x = self.layer3(x) 236 | x = self.conv3(x) 237 | x = self.bn3(x) 238 | x = self.relu(x) 239 | 240 | x = self.layer4(x) 241 | x = self.conv4_1(x) 242 | x = self.bn4_1(x) 243 | x = self.relu(x) 244 | x = self.conv4_2(x) 245 | x = self.bn4_2(x) 246 | x = self.relu(x) 247 | 248 | return x 249 | -------------------------------------------------------------------------------- /modules/SRN_Resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class VGG_FeatureExtractor(nn.Module): 7 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ 8 | 9 | def __init__(self, input_channel, output_channel=512): 10 | super(VGG_FeatureExtractor, self).__init__() 11 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 12 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 13 | self.ConvNet = nn.Sequential( 14 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 15 | nn.MaxPool2d(2, 2), # 64x16x50 16 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), 17 | nn.MaxPool2d(2, 2), # 128x8x25 18 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25 19 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), 20 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 21 | nn.Conv2d(self.output_channel[2], 
self.output_channel[3], 3, 1, 1, bias=False), 22 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 23 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), 24 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), 25 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 26 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 27 | 28 | def forward(self, input): 29 | return self.ConvNet(input) 30 | 31 | 32 | class RCNN_FeatureExtractor(nn.Module): 33 | """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ 34 | 35 | def __init__(self, input_channel, output_channel=512): 36 | super(RCNN_FeatureExtractor, self).__init__() 37 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 38 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 39 | self.ConvNet = nn.Sequential( 40 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 41 | nn.MaxPool2d(2, 2), # 64 x 16 x 50 42 | GRCL(self.output_channel[0], self.output_channel[0], num_iteration=5, kernel_size=3, pad=1), 43 | nn.MaxPool2d(2, 2), # 64 x 8 x 25 44 | GRCL(self.output_channel[0], self.output_channel[1], num_iteration=5, kernel_size=3, pad=1), 45 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26 46 | GRCL(self.output_channel[1], self.output_channel[2], num_iteration=5, kernel_size=3, pad=1), 47 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27 48 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False), 49 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26 50 | 51 | def forward(self, input): 52 | return self.ConvNet(input) 53 | 54 | 55 | class ResNet_FeatureExtractor(nn.Module): 56 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 57 | 58 | def __init__(self, input_channel, output_channel=512): 59 | super(ResNet_FeatureExtractor, self).__init__() 60 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) 61 | 62 | def forward(self, input): 63 | return self.ConvNet(input) 64 | 65 | 66 | # For Gated RCNN 67 | class GRCL(nn.Module): 68 | 69 | def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad): 70 | super(GRCL, self).__init__() 71 | self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False) 72 | self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False) 73 | self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False) 74 | self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False) 75 | 76 | self.BN_x_init = nn.BatchNorm2d(output_channel) 77 | 78 | self.num_iteration = num_iteration 79 | self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)] 80 | self.GRCL = nn.Sequential(*self.GRCL) 81 | 82 | def forward(self, input): 83 | """ The input of GRCL is consistant over time t, which is denoted by u(0) 84 | thus wgf_u / wf_u is also consistant over time t. 
85 | """ 86 | wgf_u = self.wgf_u(input) 87 | wf_u = self.wf_u(input) 88 | x = F.relu(self.BN_x_init(wf_u)) 89 | 90 | for i in range(self.num_iteration): 91 | x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x)) 92 | 93 | return x 94 | 95 | 96 | class GRCL_unit(nn.Module): 97 | 98 | def __init__(self, output_channel): 99 | super(GRCL_unit, self).__init__() 100 | self.BN_gfu = nn.BatchNorm2d(output_channel) 101 | self.BN_grx = nn.BatchNorm2d(output_channel) 102 | self.BN_fu = nn.BatchNorm2d(output_channel) 103 | self.BN_rx = nn.BatchNorm2d(output_channel) 104 | self.BN_Gx = nn.BatchNorm2d(output_channel) 105 | 106 | def forward(self, wgf_u, wgr_x, wf_u, wr_x): 107 | G_first_term = self.BN_gfu(wgf_u) 108 | G_second_term = self.BN_grx(wgr_x) 109 | G = F.sigmoid(G_first_term + G_second_term) 110 | 111 | x_first_term = self.BN_fu(wf_u) 112 | x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G) 113 | x = F.relu(x_first_term + x_second_term) 114 | 115 | return x 116 | 117 | 118 | class BasicBlock(nn.Module): 119 | expansion = 1 120 | 121 | def __init__(self, inplanes, planes, stride=1, downsample=None): 122 | super(BasicBlock, self).__init__() 123 | self.conv1 = self._conv3x3(inplanes, planes) 124 | self.bn1 = nn.BatchNorm2d(planes) 125 | self.conv2 = self._conv3x3(planes, planes) 126 | self.bn2 = nn.BatchNorm2d(planes) 127 | self.relu = nn.ReLU(inplace=True) 128 | self.downsample = downsample 129 | self.stride = stride 130 | 131 | def _conv3x3(self, in_planes, out_planes, stride=1): 132 | "3x3 convolution with padding" 133 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 134 | padding=1, bias=False) 135 | 136 | def forward(self, x): 137 | residual = x 138 | 139 | out = self.conv1(x) 140 | out = self.bn1(out) 141 | out = self.relu(out) 142 | 143 | out = self.conv2(out) 144 | out = self.bn2(out) 145 | 146 | if self.downsample is not None: 147 | residual = self.downsample(x) 148 | out += residual 149 | out = self.relu(out) 150 | 151 | return out 152 | 153 | 154 | class ResNet(nn.Module): 155 | 156 | def __init__(self, input_channel, output_channel, block, layers): 157 | super(ResNet, self).__init__() 158 | 159 | self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] 160 | 161 | self.inplanes = int(output_channel / 8) 162 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), 163 | kernel_size=3, stride=1, padding=1, bias=False) 164 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 165 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, 166 | kernel_size=3, stride=1, padding=1, bias=False) 167 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 168 | self.relu = nn.ReLU(inplace=True) 169 | 170 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 171 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 172 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ 173 | 0], kernel_size=3, stride=1, padding=1, bias=False) 174 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 175 | 176 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 177 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) 178 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ 179 | 1], kernel_size=3, stride=1, padding=1, bias=False) 180 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 181 | 182 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), 
padding=(0, 1)) 183 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) 184 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ 185 | 2], kernel_size=3, stride=1, padding=1, bias=False) 186 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 187 | 188 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) 189 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 190 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) 191 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 192 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 193 | 3], kernel_size=2, stride=1, padding=0, bias=False) 194 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 195 | 196 | def _make_layer(self, block, planes, blocks, stride=1): 197 | downsample = None 198 | if stride != 1 or self.inplanes != planes * block.expansion: 199 | downsample = nn.Sequential( 200 | nn.Conv2d(self.inplanes, planes * block.expansion, 201 | kernel_size=1, stride=stride, bias=False), 202 | nn.BatchNorm2d(planes * block.expansion), 203 | ) 204 | 205 | layers = [] 206 | layers.append(block(self.inplanes, planes, stride, downsample)) 207 | self.inplanes = planes * block.expansion 208 | for i in range(1, blocks): 209 | layers.append(block(self.inplanes, planes)) 210 | 211 | return nn.Sequential(*layers) 212 | 213 | def forward(self, x): 214 | x = self.conv0_1(x) 215 | x = self.bn0_1(x) 216 | x = self.relu(x) 217 | x = self.conv0_2(x) 218 | x = self.bn0_2(x) 219 | x = self.relu(x) 220 | 221 | x = self.maxpool1(x) 222 | x = self.layer1(x) 223 | x = self.conv1(x) 224 | x = self.bn1(x) 225 | x = self.relu(x) 226 | 227 | x = self.maxpool2(x) 228 | x = self.layer2(x) 229 | x = self.conv2(x) 230 | x = self.bn2(x) 231 | x = self.relu(x) 232 | 233 | x = self.maxpool3(x) 234 | x = self.layer3(x) 235 | x = self.conv3(x) 236 | x = self.bn3(x) 237 | x = self.relu(x) 238 | 239 | x = self.layer4(x) 240 | x = self.conv4_1(x) 241 | x = self.bn4_1(x) 242 | x = self.relu(x) 243 | x = self.conv4_2(x) 244 | x = self.bn4_2(x) 245 | x = self.relu(x) 246 | 247 | return x 248 | 249 | if __name__=='__main__': 250 | x = torch.rand(4, 1, 32, 100) 251 | # x = x.cuda() 252 | model = ResNet_FeatureExtractor(1, 5112) 253 | # model = model.cuda() 254 | y = model(x) 255 | print(y.shape) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import string 4 | import argparse 5 | 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import torch.utils.data 9 | import numpy as np 10 | from nltk.metrics.distance import edit_distance 11 | 12 | from utils import CTCLabelConverter, AttnLabelConverter, Averager, TransformerConverter, SRNConverter 13 | from dataset import hierarchical_dataset, AlignCollate 14 | from model import Model 15 | from modules.SRN_modules import cal_performance2 16 | 17 | 18 | def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False): 19 | """ evaluation with 10 benchmark evaluation datasets """ 20 | # The evaluation datasets, dataset order is same with Table 1 in our paper. 
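# Each name below is an LMDB sub-directory expected under opt.eval_data; validation()
# is run once per dataset, per-set accuracy and normalized edit distance are printed,
# and the total_accuracy reported at the end is a sample-weighted average over all ten sets.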
21 | eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 22 | 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'] 23 | 24 | if calculate_infer_time: 25 | evaluation_batch_size = 1 # batch_size should be 1 to calculate the GPU inference time per image. 26 | else: 27 | evaluation_batch_size = opt.batch_size 28 | 29 | list_accuracy = [] 30 | total_forward_time = 0 31 | total_evaluation_data_number = 0 32 | total_correct_number = 0 33 | print('-' * 80) 34 | for eval_data in eval_data_list: 35 | eval_data_path = os.path.join(opt.eval_data, eval_data) 36 | AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 37 | eval_data = hierarchical_dataset(root=eval_data_path, opt=opt) 38 | evaluation_loader = torch.utils.data.DataLoader( 39 | eval_data, batch_size=evaluation_batch_size, 40 | shuffle=False, 41 | num_workers=int(opt.workers), 42 | collate_fn=AlignCollate_evaluation, pin_memory=True) 43 | 44 | _, accuracy_by_best_model, norm_ED_by_best_model, _, _, infer_time, length_of_data = validation( 45 | model, criterion, evaluation_loader, converter, opt) 46 | list_accuracy.append(f'{accuracy_by_best_model:0.3f}') 47 | total_forward_time += infer_time 48 | total_evaluation_data_number += len(eval_data) 49 | total_correct_number += accuracy_by_best_model * length_of_data 50 | print('Acc %0.3f\t normalized_ED %0.3f' % (accuracy_by_best_model, norm_ED_by_best_model)) 51 | print('-' * 80) 52 | 53 | averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000 54 | total_accuracy = total_correct_number / total_evaluation_data_number 55 | params_num = sum([np.prod(p.size()) for p in model.parameters()]) 56 | 57 | evaluation_log = 'accuracy: ' 58 | for name, accuracy in zip(eval_data_list, list_accuracy): 59 | evaluation_log += f'{name}: {accuracy}\t' 60 | evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t' 61 | evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}' 62 | print(evaluation_log) 63 | with open(f'./result/{opt.experiment_name}/log_all_evaluation.txt', 'a') as log: 64 | log.write(evaluation_log + '\n') 65 | 66 | return None 67 | 68 | 69 | def validation(model, criterion, evaluation_loader, converter, opt): 70 | """ validation or evaluation """ 71 | print('start validation') 72 | for p in model.parameters(): 73 | p.requires_grad = False 74 | 75 | n_correct = 0 76 | norm_ED = 0 77 | length_of_data = 0 78 | infer_time = 0 79 | valid_loss_avg = Averager() 80 | 81 | for i, (image_tensors, labels) in enumerate(evaluation_loader): 82 | batch_size = image_tensors.size(0) 83 | length_of_data = length_of_data + batch_size 84 | image = image_tensors.cuda() 85 | # For max length prediction 86 | length_for_pred = torch.cuda.IntTensor([opt.batch_max_length] * batch_size) 87 | text_for_pred = torch.cuda.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0) 88 | 89 | if 'SRN' in opt.Prediction: 90 | text_for_loss, length_for_loss = converter.encode(labels) 91 | else: 92 | text_for_loss, length_for_loss = converter.encode(labels) 93 | 94 | start_time = time.time() 95 | if 'CTC' in opt.Prediction: 96 | preds = model(image, text_for_pred).log_softmax(2) 97 | forward_time = time.time() - start_time 98 | 99 | # Calculate evaluation loss for CTC deocder. 
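# PyTorch's nn.CTCLoss expects log-probabilities of shape (T, B, C) together with the
# per-sample input and target lengths, hence the permute below and preds_size being
# filled with the full output length T for every item in the batch. Greedy decoding
# afterwards takes the per-timestep argmax and the CTC converter collapses repeated
# indices and blanks (e.g. an index sequence 1,1,0,2 would become the two-character
# string for indices 1 and 2, assuming index 0 is used as the CTC blank).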
100 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 101 | preds = preds.permute(1, 0, 2) # to use CTCloss format 102 | cost = criterion(preds, text_for_loss, preds_size, length_for_loss) 103 | 104 | # Select max probabilty (greedy decoding) then decode index to character 105 | _, preds_index = preds.max(2) 106 | preds_index = preds_index.transpose(1, 0).contiguous().view(-1) 107 | preds_str = converter.decode(preds_index.data, preds_size.data) 108 | 109 | elif 'Bert' in opt.Prediction: 110 | with torch.no_grad(): 111 | pad_mask = None 112 | preds = model(image, pad_mask) 113 | forward_time = time.time() - start_time 114 | 115 | cost = criterion(preds[0].view(-1, preds[0].shape[-1]), text_for_loss.contiguous().view(-1)) + \ 116 | criterion(preds[1].view(-1, preds[1].shape[-1]), text_for_loss.contiguous().view(-1)) 117 | 118 | # select max probabilty (greedy decoding) then decode index to character 119 | _, preds_index = preds[1].max(2) 120 | length_for_pred = torch.cuda.IntTensor([preds_index.size(-1)] * batch_size) 121 | preds_str = converter.decode(preds_index, length_for_pred) 122 | labels = converter.decode(text_for_loss, length_for_loss) 123 | 124 | elif 'SRN' in opt.Prediction: 125 | with torch.no_grad(): 126 | preds = model(image, None) 127 | forward_time = time.time() - start_time 128 | 129 | cost, train_correct = criterion(preds, text_for_loss, opt.SRN_PAD) 130 | 131 | # select max probabilty (greedy decoding) then decode index to character 132 | _, preds_index = preds[2].max(2) 133 | preds_str = converter.decode(preds_index, length_for_pred) 134 | labels = converter.decode(text_for_loss, length_for_loss) 135 | 136 | else: 137 | preds = model(image, text_for_pred, is_train=False) 138 | forward_time = time.time() - start_time 139 | 140 | preds = preds[:, :text_for_loss.shape[1] - 1, :] 141 | target = text_for_loss[:, 1:] # without [GO] Symbol 142 | cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1)) 143 | 144 | # select max probabilty (greedy decoding) then decode index to character 145 | _, preds_index = preds.max(2) 146 | preds_str = converter.decode(preds_index, length_for_pred) 147 | labels = converter.decode(text_for_loss[:, 1:], length_for_loss) 148 | 149 | infer_time += forward_time 150 | valid_loss_avg.add(cost) 151 | 152 | # calculate accuracy. 
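# Accuracy here is exact-match word accuracy: a prediction only counts if the whole
# string equals the ground truth. For attention-style predictions the '[s]' end token
# and everything after it is stripped from both strings first. norm_ED accumulates
# nltk's edit_distance(pred, gt) normalized by the ground-truth length (empty ground
# truths contribute 1) and is reported as a running sum rather than an average.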
153 | for pred, gt in zip(preds_str, labels): 154 | if 'Attn' in opt.Prediction: 155 | pred = pred[:pred.find('[s]')] # prune after "end of sentence" token ([s]) 156 | gt = gt[:gt.find('[s]')] 157 | 158 | if pred == gt: 159 | n_correct += 1 160 | else: 161 | temp = 1 162 | 163 | if len(gt) == 0: 164 | norm_ED += 1 165 | else: 166 | norm_ED += edit_distance(pred, gt) / len(gt) 167 | 168 | accuracy = n_correct / float(length_of_data) * 100 169 | 170 | return valid_loss_avg.val(), accuracy, norm_ED, preds_str, labels, infer_time, length_of_data 171 | 172 | 173 | def test(opt): 174 | """ model configuration """ 175 | if 'CTC' in opt.Prediction: 176 | converter = CTCLabelConverter(opt.character) 177 | elif 'Bert' in opt.Prediction: 178 | converter = TransformerConverter(opt.character, opt.batch_max_length) 179 | elif 'SRN' in opt.Prediction: 180 | converter = SRNConverter(opt.character, opt.SRN_PAD) 181 | else: 182 | converter = AttnLabelConverter(opt.character) 183 | opt.num_class = len(converter.character) 184 | 185 | if opt.rgb: 186 | opt.input_channel = 3 187 | model = Model(opt) 188 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 189 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 190 | opt.SequenceModeling, opt.Prediction) 191 | model = torch.nn.DataParallel(model).cuda() 192 | 193 | # load model 194 | print('loading pretrained model from %s' % opt.saved_model) 195 | model.load_state_dict(torch.load(opt.saved_model)) 196 | opt.experiment_name = '_'.join(opt.saved_model.split('/')[1:]) 197 | # print(model) 198 | 199 | """ keep evaluation model and result logs """ 200 | os.makedirs(f'./result/{opt.experiment_name}', exist_ok=True) 201 | os.system(f'cp {opt.saved_model} ./result/{opt.experiment_name}/') 202 | 203 | """ setup loss """ 204 | if 'CTC' in opt.Prediction: 205 | criterion = torch.nn.CTCLoss(zero_infinity=True).cuda() 206 | elif 'SRN' in opt.Prediction: 207 | # criterion = torch.nn.CrossEntropyLoss().cuda() 208 | criterion = cal_performance2 209 | else: 210 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda() # ignore [GO] token = ignore index 0 211 | 212 | """ evaluation """ 213 | model.eval() 214 | if opt.benchmark_all_eval: # evaluation with 10 benchmark evaluation datasets 215 | benchmark_all_eval(model, criterion, converter, opt) 216 | else: 217 | AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 218 | eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt) 219 | evaluation_loader = torch.utils.data.DataLoader( 220 | eval_data, batch_size=opt.batch_size, 221 | shuffle=False, 222 | num_workers=int(opt.workers), 223 | collate_fn=AlignCollate_evaluation, pin_memory=True) 224 | _, accuracy_by_best_model, _, _, _, _, _ = validation( 225 | model, criterion, evaluation_loader, converter, opt) 226 | 227 | print(accuracy_by_best_model) 228 | with open('./result/{0}/log_evaluation.txt'.format(opt.experiment_name), 'a') as log: 229 | log.write(str(accuracy_by_best_model) + '\n') 230 | 231 | 232 | if __name__ == '__main__': 233 | parser = argparse.ArgumentParser() 234 | parser.add_argument('--eval_data', default='/home/deepblue/deepbluetwo/chenjun/1_OCR/data/data_lmdb_release/evaluation', help='path to evaluation dataset') 235 | parser.add_argument('--benchmark_all_eval', default=True, help='evaluate 10 benchmark evaluation datasets') 236 | parser.add_argument('--workers', type=int, help='number of data loading 
workers', default=4) 237 | parser.add_argument('--batch_size', type=int, default=64, help='input batch size') 238 | parser.add_argument('--saved_model', default='./saved_models/None-ResNet-SRN-SRN-Seed666/iter_65000.pth', help="path to saved_model to evaluation") 239 | """ Data processing """ 240 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 241 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 242 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 243 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 244 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz$#', help='character label') 245 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 246 | parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') 247 | parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') 248 | """ Model Architecture """ 249 | parser.add_argument('--Transformation', type=str, default='None', help='Transformation stage. None|TPS') 250 | parser.add_argument('--FeatureExtraction', type=str, default='ResNet', help='FeatureExtraction stage. VGG|RCNN|ResNet|AsterRes') 251 | parser.add_argument('--SequenceModeling', type=str, default='SRN', help='SequenceModeling stage. None|BiLSTM|Bert') 252 | parser.add_argument('--Prediction', type=str, default='SRN', help='Prediction stage. CTC|Attn|Bert_pred') 253 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 254 | parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') 255 | parser.add_argument('--output_channel', type=int, default=512, 256 | help='the number of output channel of Feature extractor') 257 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 258 | parser.add_argument('--position_dim', type=int, default=26, help='the length sequence out from cnn encoder,resnet:65;resnetfpn:256') 259 | 260 | parser.add_argument('--SRN_PAD', type=int, default=36, help='the pad character for srn') 261 | parser.add_argument('--batch_max_character', type=int, default=25, help='the max sequence length') 262 | opt = parser.parse_args() 263 | 264 | """ vocab / character number configuration """ 265 | if opt.sensitive: 266 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 267 | 268 | opt.alphabet_size = len(opt.character) # 269 | opt.SRN_PAD = len(opt.character)-1 270 | cudnn.benchmark = True 271 | cudnn.deterministic = True 272 | opt.num_gpu = torch.cuda.device_count() 273 | 274 | test(opt) 275 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import six 5 | import math 6 | import lmdb 7 | import torch 8 | 9 | from natsort import natsorted 10 | from PIL import Image 11 | import numpy as np 12 | from torch.utils.data import Dataset, ConcatDataset, Subset 13 | from torch._utils import _accumulate 14 | import torchvision.transforms as transforms 15 | 16 | 17 | class Batch_Balanced_Dataset(object): 18 | 19 | def __init__(self, opt): 20 | """ 21 | Modulate the data ratio in the batch. 
22 | For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5", 23 | the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST. 24 | """ 25 | print('-' * 80) 26 | print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}') 27 | assert len(opt.select_data) == len(opt.batch_ratio) 28 | 29 | _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 30 | self.data_loader_list = [] 31 | self.dataloader_iter_list = [] 32 | self.nums_samples = 0. 33 | batch_size_list = [] 34 | Total_batch_size = 0 35 | for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio): 36 | _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1) 37 | print('-' * 80) 38 | _dataset = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d]) 39 | total_number_dataset = len(_dataset) 40 | 41 | """ 42 | The total number of data can be modified with opt.total_data_usage_ratio. 43 | ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage. 44 | See 4.2 section in our paper. 45 | """ 46 | 47 | # opt.total_data_usage_ratio = 1.0 if selected_d == 'ICDAR2019' else 0.5 48 | 49 | number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio)) 50 | dataset_split = [number_dataset, total_number_dataset - number_dataset] 51 | indices = range(total_number_dataset) 52 | _dataset, _ = [Subset(_dataset, indices[offset - length:offset]) 53 | for offset, length in zip(_accumulate(dataset_split), dataset_split)] 54 | print(f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}') 55 | print(f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}') 56 | batch_size_list.append(str(_batch_size)) 57 | Total_batch_size += _batch_size 58 | 59 | self.nums_samples += len(_dataset) 60 | _data_loader = torch.utils.data.DataLoader( 61 | _dataset, batch_size=_batch_size, 62 | shuffle=True, 63 | num_workers=int(opt.workers), 64 | collate_fn=_AlignCollate, pin_memory=False) 65 | self.data_loader_list.append(_data_loader) 66 | self.dataloader_iter_list.append(iter(_data_loader)) 67 | print('-' * 80) 68 | print('Total_batch_size: ', '+'.join(batch_size_list), '=', str(Total_batch_size)) 69 | opt.batch_size = Total_batch_size 70 | print('-' * 80) 71 | 72 | def get_batch(self): 73 | balanced_batch_images = [] 74 | balanced_batch_texts = [] 75 | for i, data_loader_iter in enumerate(self.dataloader_iter_list): 76 | try: 77 | image, text = data_loader_iter.next() 78 | balanced_batch_images.append(image) 79 | balanced_batch_texts += text 80 | except StopIteration: 81 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) 82 | image, text = self.dataloader_iter_list[i].next() 83 | balanced_batch_images.append(image) 84 | balanced_batch_texts += text 85 | except ValueError: 86 | pass 87 | 88 | balanced_batch_images = torch.cat(balanced_batch_images, 0) 89 | 90 | return balanced_batch_images, balanced_batch_texts 91 | 92 | def __len__(self): 93 | return self.nums_samples 94 | 95 | 96 | def hierarchical_dataset(root, opt, select_data='/'): 97 | """ select_data='/' contains all sub-directory of root directory """ 98 | dataset_list = [] 99 | print(f'dataset_root: {root}\t dataset: {select_data[0]}') 100 | for dirpath, dirnames, filenames in os.walk(root): 101 | if not dirnames: 102 | select_flag = False 103 | for selected_d 
in select_data: 104 | if selected_d in dirpath: 105 | select_flag = True 106 | break 107 | 108 | if select_flag: 109 | dataset = LmdbDataset(dirpath, opt) 110 | print(f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}') 111 | dataset_list.append(dataset) 112 | 113 | concatenated_dataset = ConcatDataset(dataset_list) 114 | 115 | return concatenated_dataset 116 | 117 | 118 | class LmdbDataset(Dataset): 119 | 120 | def __init__(self, root, opt): 121 | 122 | self.root = root 123 | self.opt = opt 124 | self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) 125 | if not self.env: 126 | print('cannot create lmdb from %s' % (root)) 127 | sys.exit(0) 128 | 129 | with self.env.begin(write=False) as txn: 130 | nSamples = int(txn.get('num-samples'.encode())) 131 | self.nSamples = nSamples 132 | 133 | if self.opt.data_filtering_off: 134 | # for fast check with no filtering 135 | self.filtered_index_list = [index + 1 for index in range(self.nSamples)] 136 | else: 137 | # Filtering 138 | self.filtered_index_list = [] 139 | for index in range(self.nSamples): 140 | index += 1 # lmdb starts with 1 141 | label_key = 'label-%09d'.encode() % index 142 | label = txn.get(label_key).decode('utf-8') 143 | 144 | if len(label) > self.opt.batch_max_length or len(label) == 0: 145 | # print(f'The length of the label is longer than max_length: length \ 146 | # {len(label)}, {label} in dataset {self.root}') 147 | continue 148 | 149 | # By default, images containing characters which are not in opt.character are filtered. 150 | # You can add [UNK] token to `opt.character` in utils.py instead of this filtering. 151 | out_of_char = f'[^{self.opt.character}]' 152 | if re.search(out_of_char, label.lower()): 153 | continue 154 | 155 | self.filtered_index_list.append(index) 156 | 157 | self.nSamples = len(self.filtered_index_list) 158 | 159 | def __len__(self): 160 | return self.nSamples 161 | 162 | def __getitem__(self, index): 163 | assert index <= len(self), 'index range error' 164 | index = self.filtered_index_list[index] 165 | 166 | with self.env.begin(write=False) as txn: 167 | label_key = 'label-%09d'.encode() % index 168 | label = txn.get(label_key).decode('utf-8') 169 | img_key = 'image-%09d'.encode() % index 170 | imgbuf = txn.get(img_key) 171 | 172 | buf = six.BytesIO() 173 | buf.write(imgbuf) 174 | buf.seek(0) 175 | try: 176 | if self.opt.rgb: 177 | img = Image.open(buf).convert('RGB') # for color image 178 | else: 179 | img = Image.open(buf).convert('L') 180 | 181 | except IOError: 182 | print(f'Corrupted image for {index}') 183 | # make dummy image and dummy label for corrupted image. 
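# When PIL cannot decode the stored buffer, a blank imgW x imgH image in the matching
# mode ('RGB' or 'L') is substituted together with a placeholder label, so one corrupted
# LMDB record does not crash the whole DataLoader worker; the placeholder text then
# passes through the same lower-casing and out-of-charset filtering as any real label.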
184 | if self.opt.rgb: 185 | img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) 186 | else: 187 | img = Image.new('L', (self.opt.imgW, self.opt.imgH)) 188 | label = '[dummy_label]' 189 | 190 | if not self.opt.sensitive: 191 | label = label.lower() 192 | 193 | # We only train and evaluate on alphanumerics (or pre-defined character set in train.py) 194 | out_of_char = f'[^{self.opt.character}]' 195 | label = re.sub(out_of_char, '', label) 196 | 197 | return (img, label) 198 | 199 | 200 | class RawDataset(Dataset): 201 | 202 | def __init__(self, root, opt): 203 | self.opt = opt 204 | self.image_path_list = [] 205 | for dirpath, dirnames, filenames in os.walk(root): 206 | for name in filenames: 207 | _, ext = os.path.splitext(name) 208 | ext = ext.lower() 209 | if ext == '.jpg' or ext == '.jpeg' or ext == '.png': 210 | self.image_path_list.append(os.path.join(dirpath, name)) 211 | 212 | self.image_path_list = natsorted(self.image_path_list) 213 | self.nSamples = len(self.image_path_list) 214 | 215 | def __len__(self): 216 | return self.nSamples 217 | 218 | def __getitem__(self, index): 219 | 220 | try: 221 | if self.opt.rgb: 222 | img = Image.open(self.image_path_list[index]).convert('RGB') # for color image 223 | else: 224 | img = Image.open(self.image_path_list[index]).convert('L') 225 | 226 | except IOError: 227 | print(f'Corrupted image for {index}') 228 | # make dummy image and dummy label for corrupted image. 229 | if self.opt.rgb: 230 | img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) 231 | else: 232 | img = Image.new('L', (self.opt.imgW, self.opt.imgH)) 233 | 234 | w, h = img.size 235 | if h > 2.0 * w: 236 | img = img.transpose(Image.ROTATE_90) 237 | 238 | return (img, self.image_path_list[index]) 239 | 240 | 241 | class baidu_raw_dataset(RawDataset): 242 | 243 | def __init__(self, root, opt): 244 | super(baidu_raw_dataset, self).__init__(root, opt) 245 | 246 | def __getitem__(self, index): 247 | try: 248 | if self.opt.rgb: 249 | img = Image.open(self.image_path_list[index]).convert('RGB') # for color image 250 | else: 251 | img = Image.open(self.image_path_list[index]).convert('L') 252 | 253 | except IOError: 254 | print(f'Corrupted image for {index}') 255 | # make dummy image and dummy label for corrupted image. 
256 | if self.opt.rgb: 257 | img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) 258 | else: 259 | img = Image.new('L', (self.opt.imgW, self.opt.imgH)) 260 | 261 | w, h = img.size 262 | if h > 1.5 * w: 263 | img = img.transpose(Image.ROTATE_90) 264 | 265 | return (img, self.image_path_list[index]) 266 | 267 | 268 | 269 | class ResizeNormalize(object): 270 | 271 | def __init__(self, size, interpolation=Image.BICUBIC): 272 | self.size = size 273 | self.interpolation = interpolation 274 | self.toTensor = transforms.ToTensor() 275 | 276 | def __call__(self, img): 277 | img = img.resize(self.size, self.interpolation) 278 | img = self.toTensor(img) 279 | img.sub_(0.5).div_(0.5) 280 | return img 281 | 282 | 283 | class NormalizePAD(object): 284 | 285 | def __init__(self, max_size, PAD_type='right'): 286 | self.toTensor = transforms.ToTensor() 287 | self.max_size = max_size 288 | self.max_width_half = math.floor(max_size[2] / 2) 289 | self.PAD_type = PAD_type 290 | 291 | def __call__(self, img): 292 | img = self.toTensor(img) 293 | img.sub_(0.5).div_(0.5) 294 | c, h, w = img.size() 295 | Pad_img = torch.FloatTensor(*self.max_size).fill_(0) 296 | Pad_img[:, :, :w] = img # right pad 297 | if self.max_size[2] != w: # add border Pad 298 | Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w) 299 | 300 | return Pad_img 301 | 302 | 303 | class AlignCollate(object): 304 | 305 | def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False): 306 | self.imgH = imgH 307 | self.imgW = imgW 308 | self.keep_ratio_with_pad = keep_ratio_with_pad 309 | 310 | def __call__(self, batch): 311 | batch = filter(lambda x: x is not None, batch) 312 | images, labels = zip(*batch) 313 | 314 | if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper 315 | resized_max_w = self.imgW 316 | transform = NormalizePAD((1, self.imgH, resized_max_w)) 317 | 318 | resized_images = [] 319 | for image in images: 320 | w, h = image.size 321 | ratio = w / float(h) 322 | if math.ceil(self.imgH * ratio) > self.imgW: 323 | resized_w = self.imgW 324 | else: 325 | resized_w = math.ceil(self.imgH * ratio) 326 | 327 | resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC) 328 | resized_images.append(transform(resized_image)) 329 | # resized_image.save('./image_test/%d_test.jpg' % w) 330 | 331 | image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0) 332 | 333 | else: 334 | transform = ResizeNormalize((self.imgW, self.imgH)) 335 | image_tensors = [transform(image) for image in images] 336 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) 337 | 338 | return image_tensors, labels 339 | 340 | 341 | def tensor2im(image_tensor, imtype=np.uint8): 342 | image_numpy = image_tensor.cpu().float().numpy() 343 | if image_numpy.shape[0] == 1: 344 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 345 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 346 | return image_numpy.astype(imtype) 347 | 348 | 349 | def save_image(image_numpy, image_path): 350 | image_pil = Image.fromarray(image_numpy) 351 | image_pil.save(image_path) 352 | -------------------------------------------------------------------------------- /src/baidudataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import random 5 | import torch 6 | from torch.utils.data import Dataset 7 | from torch.utils.data import sampler 8 | import torchvision.transforms as transforms 9 | # import lmdb 10 | 
import six
11 | import sys
12 | from PIL import Image
13 | import numpy as np
14 | import cv2
15 | import os
16 | import math
17 | # import matplotlib.pyplot as plt
18 | from PIL import ImageFile
19 | ImageFile.LOAD_TRUNCATED_IMAGES = True
20 | 
21 | 
22 | def random_crop(img, coor, r_int=[-10, -8, -6, -2, 0, 2, 4, 6, 8], train=True):
23 |     '''Randomly expand or shrink the annotated text region.
24 |     img: cv2 image (numpy array)
25 |     coor: list of corner coordinates (8 values; 4-value boxes are handled as well)
26 |     '''
27 |     # find the enclosing axis-aligned rectangle
28 |     h,w,_ = img.shape
29 |     coor = [int(x) for x in coor]
30 |     minx = min(coor[::2])
31 |     maxx = max(coor[::2])
32 |     miny = min(coor[1::2])
33 |     maxy = max(coor[1::2])
34 | 
35 |     newx1, newy1 = minx, miny  # top-left corner
36 |     newx3, newy3 = maxx, maxy  # bottom-right corner
37 | 
38 |     randint = random.choice(r_int)  # randomly expand or shrink
39 | 
40 |     if len(coor) == 8:  # 8-value (quadrilateral) annotation: expand/shrink on all four sides
41 |         newx1 += randint; newx1 = max(newx1, 0)
42 |         newy1 += randint; newy1 = max(newy1, 0)
43 |         newx3 -= randint; newx3 = min(newx3, w)
44 |         newy3 -= randint; newy3 = min(newy3, h)
45 | 
46 |         newimg = img[newy1:newy3, newx1:newx3, :]  # crop the region
47 |         # check whether the crop needs to be rotated
48 |         if abs(newy3 - newy1) / abs(newx3-newx1) > 1.5:  # taller than wide: treat as vertical text
49 |             newimg = np.rot90(newimg)  # rotate by 90 degrees
50 |         # batched training needs equal widths; keep-ratio resizing and padding are done in the dataloader, so each batch is padded to its longest sample
51 |         if train:
52 |             pass
53 | 
54 |     elif len(coor) == 4:  # 4-value annotation: only expand/shrink along the width
55 |         newx1 += randint; newx1 = max(newx1, 0)
56 |         newx3 -= randint; newx3 = min(newx3, w)
57 | 
58 |         newimg = img[newy1:newy3, newx1:newx3, :]  # crop the region
59 | 
60 |     return Image.fromarray(newimg)
61 | 
62 | 
63 | 
64 | class BAIDUset(Dataset):
65 |     '''
66 |     Dataset for the Baidu OCR recognition task; images are already cropped, and both horizontal and vertical text images occur.
67 |     '''
68 |     def __init__(self, opt, csv_root, transform=None, target_transform=None):
69 |         self.opt = opt
70 |         self.root = opt.root
71 |         with open(csv_root) as f:
72 |             self.labels = f.readlines()
73 |         self.transform = transform
74 |         self.target_transform = target_transform
75 | 
76 |     def __len__(self):
77 |         return len(self.labels)
78 | 
79 |     def __getitem__(self, idx):
80 |         per_label = self.labels[idx].rstrip().split('\t')
81 |         # imgpath = os.path.join(self.root, per_label[0])  # image path
82 |         # text = per_label[1].replace(' ','')  # text label of the image
83 |         imgpath = os.path.join(self.root, per_label[0].rstrip())  # image path
84 |         text = per_label[1].strip()
85 | 
86 |         try:
87 |             if self.opt.rgb:
88 |                 img = Image.open(imgpath).convert('RGB')  # for color image
89 |             else:
90 |                 img = Image.open(imgpath).convert('L')
91 | 
92 |         except IOError:
93 |             print(imgpath)
94 |             print(f'Corrupted image for {idx}')
95 |             # make dummy image and dummy label for corrupted image.
96 |             if self.opt.rgb:
97 |                 img = Image.new('RGB', (self.opt.imgW, self.opt.imgH))
98 |             else:
99 |                 img = Image.new('L', (self.opt.imgW, self.opt.imgH))
100 | 
101 |         '''
102 |         for vertical text, rotate 90
103 |         '''
104 |         w, h = img.size
105 |         if h > 1.5 * w and len(text) > 2:
106 |             img = img.transpose(Image.ROTATE_90)
107 | 
108 |         return (img, text)
109 | 
110 | 
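# Example wiring (a sketch only; train.py builds exactly this pipeline -- `opt` must
# provide .root, .rgb, .imgW and .imgH, and the label file is tab-separated
# "image_name<TAB>text"):
#
#     train_set = BAIDUset(opt, opt.train_csv)
#     train_loader = torch.utils.data.DataLoader(
#         train_set, batch_size=opt.batch_size, shuffle=True,
#         num_workers=int(opt.workers),
#         collate_fn=BaiduCollate(opt.imgH, opt.imgW, keep_ratio=False))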
111 | class BaiduCollate(object):
112 |     '''Pad every batch to the widest image in the batch.
113 |     '''
114 | 
115 |     def __init__(self, imgH=32, imgW=128, keep_ratio=True):
116 |         self.imgH = imgH
117 |         self.imgW = imgW
118 |         self.keep_ratio = keep_ratio
119 | 
120 |     def __call__(self, batch):
121 |         images, labels = zip(*batch)
122 | 
123 |         imgH = self.imgH
124 |         resized_images = []
125 |         if self.keep_ratio:  # keep the aspect ratio when resizing
126 |             for image in images:
127 |                 w, h = image.size
128 |                 scale = 1.0 * h / self.imgH
129 |                 image = image.resize((int(w/scale),int(h/scale)), Image.BILINEAR)
130 |                 # image = image.resize((self.imgW, int(h/scale)), Image.BILINEAR)
131 |                 resized_images.append(image)
132 |         else:
133 |             for image in images:
134 |                 image = image.resize((self.imgW, self.imgH), Image.BILINEAR)
135 |                 resized_images.append(image)
136 | 
137 |         # pad to the maximum width in the batch
138 |         maxw = max([x.size[0] for x in resized_images])
139 |         # images = [baidu_pad(x, maxw, imgH) for x in resized_images]  # padding done
140 |         images = resized_images
141 | 
142 |         transform = resizeNormalize((maxw, imgH))
143 |         images = [transform(image) for image in images]
144 |         images = torch.cat([t.unsqueeze(0) for t in images], 0)
145 | 
146 |         return images, labels
147 | 
148 | 
149 | def baidu_pad(img, maxw, h=32):
150 |     img = np.array(img)
151 |     if img.ndim == 2:
152 |         h, w = img.shape
153 |         randint = random.randint(-10, -1)
154 |         newimg = np.ones((h, maxw)) * img[randint, randint]
155 |         newimg[:,:w] = img  # pad up to the maximum width
156 |     else:
157 |         h,w,c = img.shape
158 |         randint = random.randint(-10, -1)
159 |         newimg = np.ones((h, maxw, c), dtype='uint8') * img[randint, randint, :]
160 |         newimg[:,:w, :] = img  # pad up to the maximum width
161 | 
162 |     return Image.fromarray(newimg)
163 | 
164 | class ImgDataset(Dataset):
165 |     '''
166 |     Reads images directly from disk and can randomly jitter the text region; intended for 4-point (quadrilateral) annotations.
167 |     '''
168 |     def __init__(self, root=None, csv_root=None,transform=None, target_transform=None, training=True):
169 |         self.root = root
170 |         with open(csv_root) as f:
171 |             self.labels = f.readlines()
172 |         self.transform = transform
173 |         self.target_transform = target_transform
174 |         self.train = training
175 | 
176 |     def __len__(self):
177 |         return len(self.labels)
178 | 
179 |     def __getitem__(self, idx):
180 |         per_label = self.labels[idx].rstrip().split(',')
181 |         imgpath = os.path.join(self.root, per_label[0])  # image path
182 |         img = cv2.imread(imgpath)
183 |         coor = per_label[1:9]  # the 4 annotated corner points (8 coordinates)
184 |         if self.train:
185 |             img = random_crop(img, coor)  # random crop and possible 90-degree rotation
186 |         else:
187 |             img = random_crop(img, coor, r_int=[0, 0])  # no random jitter at eval time
188 | 
189 |         if self.transform is not None:
190 |             img = self.transform(img)
191 | 
192 |         label = str(per_label[9])
193 |         # label = label.decode('utf-8')  # Python 2 leftover: str is already unicode in Python 3
194 | 
195 |         if self.target_transform is not None:  # target_transform is normally None
196 |             label = self.target_transform(label)
197 | 
198 |         return (img, label)
199 | 
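# Annotation format (inferred from the parsing code above and below; the real label
# files may differ, so treat these sample lines as hypothetical):
#   ImgDataset expects comma-separated lines:    img_name,x1,y1,x2,y2,x3,y3,x4,y4,label
#   FourCoorDataset expects tab-separated lines: img_name<TAB>label<TAB>x1<TAB>y1<TAB>x2<TAB>y2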
200 | class FourCoorDataset(ImgDataset):
201 |     def __init__(self, root=None, csv_root=None,transform=None, target_transform=None, training=True):
202 |         super(FourCoorDataset, self).__init__(root=root, csv_root=csv_root,transform=transform, target_transform=target_transform, training=training)
203 | 
204 |     def __getitem__(self, idx):
205 |         per_label = self.labels[idx].rstrip().split('\t')
206 |         imgpath = os.path.join(self.root, per_label[0])  # image path
207 |         img = cv2.imread(imgpath)
208 |         img = img[:, :, ::-1]
209 |         coor = per_label[2:6]  # the 4 annotated coordinate values
210 |         if self.train:
211 |             img = random_crop(img, coor)  # random crop and possible 90-degree rotation
212 |         else:
213 |             img = random_crop(img, coor, r_int=[0, 0])  # no random jitter at eval time
214 |         # img = img.convert('L')
215 | 
216 |         if self.transform is not None:
217 |             img = self.transform(img)
218 | 
219 |         label = str(per_label[1].lstrip())
220 |         # label = label.decode('utf-8')  # Python 2 leftover
221 | 
222 |         if self.target_transform is not None:  # target_transform is normally None
223 |             label = self.target_transform(label)
224 | 
225 |         return (img, label)
226 | 
227 | 
228 | class lmdbDataset(Dataset):
229 |     '''Read samples from an LMDB database. NOTE: the `import lmdb` at the top of this file is commented out and must be re-enabled before this class can be used.
230 |     '''
231 |     def __init__(self, root=None, transform=None, target_transform=None):
232 |         self.root = root
233 |         self.env = lmdb.open(
234 |             root,
235 |             max_readers=1,
236 |             readonly=True,
237 |             lock=False,
238 |             readahead=False,
239 |             meminit=False)
240 | 
241 |         if not self.env:
242 |             print('cannot create lmdb from %s' % (root))
243 |             sys.exit(0)
244 | 
245 |         with self.env.begin(write=False) as txn:
246 |             nSamples = int(txn.get('num-samples'.encode()))
247 |             self.nSamples = nSamples
248 | 
249 |         self.transform = transform
250 |         self.target_transform = target_transform
251 | 
252 |     def __len__(self):
253 |         return self.nSamples
254 | 
255 |     def __getitem__(self, index):
256 |         assert index <= len(self), 'index range error'
257 |         index += 1
258 |         with self.env.begin(write=False) as txn:
259 |             img_key = ('image-%09d' % index).encode()  # LMDB keys are bytes under Python 3
260 |             imgbuf = txn.get(img_key)
261 | 
262 |             buf = six.BytesIO()
263 |             buf.write(imgbuf)
264 |             buf.seek(0)
265 |             try:
266 |                 # img = Image.open(buf).convert('L')
267 |                 if 'test' in self.root:
268 |                     img = Image.open(buf).convert('L')
269 |                 else:
270 |                     img = Image.open(buf)  # keep color: the brightness augmentation needs an RGB image
271 |             except IOError:
272 |                 print('Corrupted image for %d' % index)
273 |                 return self[index + 1]
274 | 
275 |             if self.transform is not None:
276 |                 img = self.transform(img)
277 | 
278 |             label_key = ('label-%09d' % index).encode()
279 |             label = txn.get(label_key).decode('utf-8')
280 | 
281 |             if self.target_transform is not None:
282 |                 label = self.target_transform(label)
283 | 
284 |         return (img, label)
285 | 
286 | 
287 | class resizeNormalize(object):
288 | 
289 |     def __init__(self, size, interpolation=Image.BILINEAR):
290 |         self.size = size
291 |         self.interpolation = interpolation
292 |         self.toTensor = transforms.ToTensor()
293 | 
294 |     def __call__(self, img):
295 |         img = img.resize(self.size, self.interpolation)
296 |         img = self.toTensor(img)
297 |         img.sub_(0.5).div_(0.5)
298 |         return img
299 | 
300 | 
301 | class randomSequentialSampler(sampler.Sampler):
302 | 
303 |     def __init__(self, data_source, batch_size):
304 |         self.num_samples = len(data_source)
305 |         self.batch_size = batch_size
306 | 
307 |     def __iter__(self):
308 |         n_batch = len(self) // self.batch_size
309 |         tail = len(self) % self.batch_size
310 |         index = torch.LongTensor(len(self)).fill_(0)
311 |         for i in range(n_batch):
312 |             random_start = random.randint(0, len(self) - self.batch_size)
313 |             batch_index = random_start + torch.arange(0, self.batch_size)
314 |             index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
315 |         # deal with tail
316 |         if tail:
317 |             random_start = random.randint(0, len(self) - self.batch_size)
318 |             tail_index = random_start + torch.arange(0, tail)
319 |             index[(i + 1) * self.batch_size:] = tail_index
320 | 
321 |         return iter(index)
322 | 
323 |     def __len__(self):
324 |         return self.num_samples
325 | 
326 | 
327 | class alignCollate(object):
328 | 
329 |     def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
330 |         self.imgH = imgH
331 |         self.imgW = imgW
332 |         self.keep_ratio = keep_ratio
333 |         self.min_ratio = min_ratio
334 | 
335 |     def __call__(self, batch):
336 |         images, labels = zip(*batch)
337 | 
338 |         imgH = self.imgH
339 |         imgW = self.imgW
340 |         if self.keep_ratio:
341 |             ratios = []
342 |             for image in images:
343 |                 w, h = image.size
344 |                 ratios.append(w / float(h))
345 |             ratios.sort()
346 |             max_ratio = ratios[-1]  # the largest aspect ratio in the batch
347 |             imgW = int(np.floor(max_ratio * imgH))
348 |             imgW = max(imgH * self.min_ratio, imgW)  # ensure imgW >= imgH * min_ratio
349 | 
350 |         transform = resizeNormalize((imgW, imgH))
351 |         images = [transform(image) for image in images]
352 |         images = torch.cat([t.unsqueeze(0) for t in images], 0)
353 | 
354 |         return images, labels
355 | 
356 | def pad_pil(img, maxw, h=32):
357 |     img = np.array(img)
358 |     if img.ndim == 2:
359 |         h, w = img.shape
360 |         randint = random.randint(-10, -1)
361 |         newimg = np.ones((h, maxw)) * img[randint, randint]
362 |         newimg[:,:w] = img  # pad up to the maximum width
363 |     else:
364 |         h,w,c = img.shape
365 |         randint = random.randint(-10, -1)
366 |         newimg = np.ones((h, maxw, c), dtype='uint8') * img[randint, randint, :]
367 |         newimg[:,:w, :] = img  # pad up to the maximum width
368 | 
369 |     return Image.fromarray(newimg)
370 | 
371 | 
372 | class OwnalignCollate(object):
373 |     '''Pad every batch to the widest image in the batch.
374 |     '''
375 | 
376 |     def __init__(self, imgH=32, keep_ratio=True, min_ratio=1):
377 |         self.imgH = imgH
378 |         self.keep_ratio = keep_ratio
379 | 
380 |     def __call__(self, batch):
381 |         images, labels = zip(*batch)
382 | 
383 |         imgH = self.imgH
384 |         resized_images = []
385 |         if self.keep_ratio:  # keep the aspect ratio when resizing
386 |             for image in images:
387 |                 w, h = image.size
388 |                 scale = 1.0 * h / self.imgH
389 |                 image = image.resize((int(w/scale),int(h/scale)), Image.ANTIALIAS)
390 |                 resized_images.append(image)
391 | 
392 |         # pad to the maximum width in the batch
393 |         maxw = max([x.size[0] for x in resized_images])
394 |         images = [pad_pil(x, maxw, imgH) for x in resized_images]  # padding done
395 | 
396 |         transform = resizeNormalize((maxw, imgH))
397 |         images = [transform(image) for image in images]
398 |         images = torch.cat([t.unsqueeze(0) for t in images], 0)
399 | 
400 |         return images, labels
401 | 
402 | 
403 | class TransformerCollate(OwnalignCollate):
404 |     '''
405 |     converter: converts a label string into a list of integer ids
406 |     '''
407 |     def __init__(self, imgH=32, keep_ratio=True, converter=None):
408 |         super(TransformerCollate, self).__init__(imgH=32, keep_ratio=True)
409 |         self.converter = converter
410 | 
411 |     def __call__(self, batch):
412 | 
413 |         images, labels = zip(*batch)
414 | 
415 |         # pad images to the longest one in the batch
416 |         imgH = self.imgH
417 |         resized_images = []
418 |         if self.keep_ratio:  # keep the aspect ratio when resizing
419 |             for image in images:
420 |                 w, h = image.size
421 |                 scale = 1.0 * h / self.imgH
422 |                 image = image.resize((int(w/scale),int(h/scale)), Image.ANTIALIAS)
423 |                 resized_images.append(image)
424 | 
425 |         # build the src_seq mask: padded positions are set to PAD
426 |         lengthw = [x.size[0] for x in resized_images]  # width of each image
427 |         maxw = max([x.size[0] for x in resized_images])  # maximum width
428 |         srcw = [math.floor(x / 4.0 + 1) for x in lengthw]  # source-sequence length covered by real image content
429 |         max_seq = math.floor(maxw / 4.0 + 1)  # sequence length after the image passes through the encoder (stride 4)
430 |         src_seq = np.array([
431 |             [Constants.UNK] * int(inst) + [Constants.PAD] * int(max_seq - inst) for inst in srcw
432 |         ])
433 |         src_seq = torch.LongTensor(src_seq)  # padded source sequence
434 |         # pad images to the maximum width
435 | 
436 |         images = [pad_pil(x, maxw, imgH) for x in resized_images]  # padding done
437 | 
438 |         transform = resizeNormalize((maxw, imgH))
439 |         images = [transform(image) for image in images]
440 |         images = torch.cat([t.unsqueeze(0) for t in images], 0)
441 | 
442 |         # encode labels into integer ids and pad them to the batch maximum
443 |         tlabel = [self.converter.encode(x) for x in labels]
444 |         tgt_seq, tgt_pos = paired(tlabel)
445 | 
446 |         return images, src_seq, tgt_seq, tgt_pos
447 | 
448 | 
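# NOTE: TransformerCollate above and paired() below reference a `Constants` namespace
# (PAD/UNK/BOS/EOS ids) that is never imported in this file; it comes from the
# transformer code base this module was adapted from. A minimal stand-in, assuming the
# same ordering as the 'PUBE' special tokens used by TransformerConverter further down
# (an assumption, not something defined in this repo), would be:
#
#     class Constants:
#         PAD, UNK, BOS, EOS = 0, 1, 2, 3
#
# Whatever values are used must match the ids produced by the label converter.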
449 | def paired(insts):
450 |     ''' Pad the instance to the max seq length in batch '''
451 | 
452 |     max_len = max(len(inst) for inst in insts)
453 | 
454 |     batch_seq = np.array([
455 |         inst + [Constants.PAD] * (max_len - len(inst))
456 |         for inst in insts])
457 | 
458 |     batch_pos = np.array([
459 |         [pos_i+1 if w_i != Constants.PAD else 0
460 |          for pos_i, w_i in enumerate(inst)] for inst in batch_seq])
461 | 
462 |     batch_seq = torch.LongTensor(batch_seq)
463 |     batch_pos = torch.LongTensor(batch_pos)
464 | 
465 |     return batch_seq, batch_pos
466 | 
467 | 
468 | class TransformerConverter(object):
469 |     """Convert between str and label.
470 | 
471 |     NOTE:
472 |         The special tokens 'P' (pad), 'U' (unk), 'B' (begin) and 'E' (end) are prepended to the alphabet.
473 | 
474 |     Args:
475 |         alphabet (str): set of the possible characters.
476 |         ignore_case (bool, default=True): whether or not to ignore all of the case.
477 |     """
478 | 
479 |     def __init__(self, alphabet, ignore_case=True):
480 |         self._ignore_case = ignore_case
481 |         if self._ignore_case:
482 |             alphabet = alphabet.lower()
483 |         self.alphabet = 'PUBE' + alphabet  # special tokens: pad / unk / begin / end
484 | 
485 |         self.dict = {}
486 |         for i, char in enumerate(self.alphabet):
487 |             # NOTE: index 0 is the 'P' (pad) token
488 |             self.dict[char] = i
489 | 
490 |     def encode(self, text):
491 |         """Support batch or single str.
492 | 
493 |         Args:
494 |             text (str or list of str): texts to convert.
495 | 
496 |         Returns:
497 |             list of int: encoded character ids, with the begin token prepended
498 |             and the end token appended.
499 |         """
500 |         if isinstance(text, str):
501 |             text = [
502 |                 self.dict[char.lower() if self._ignore_case else char]
503 |                 for char in text
504 |             ]
505 |             text.insert(0, self.dict['B'])  # prepend the begin-of-sequence token
506 |             text.append(self.dict['E'])  # append the end-of-sequence token
507 |         return text
508 | 
509 |     def decode(self, preds, raw=True):
510 |         '''Support batch and single str decode
511 | 
512 |         :param preds: numpy.array
513 |         :return: text: predict string
514 |                  texts: predict string list
515 |         '''
516 |         assert isinstance(preds, np.ndarray), 'preds must be np.ndarray type'
517 |         if preds.ndim == 1:
518 |             if raw:  # keep every predicted token
519 |                 return ''.join([self.alphabet[i] for i in preds])
520 |             else:
521 |                 charlist = [x for x in preds if x > Constants.EOS]  # only ids above the special tokens are real characters
522 |                 return ''.join([self.alphabet[i] for i in charlist])
523 |         else:
524 |             # batch mode
525 |             assert preds.ndim > 1, 'The batch mode is wrong'
526 | 
527 |             texts = [self.decode(t, raw) for t in preds]
528 | 
529 |             return texts
530 | 
--------------------------------------------------------------------------------
/modules/SRN_modules.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | # chenjun
3 | # date:2020-04-18
4 | import torch.nn as nn
5 | import torch
6 | import torch.nn.functional as F
7 | import numpy as np
8 | 
9 | 
10 | # def get_non_pad_mask(seq, PAD):
11 | #     assert seq.dim() == 2
12 | #     return seq.ne(PAD).type(torch.float).unsqueeze(-1)
13 | 
14 | def get_pad_mask(seq, pad_idx):
15 |     return (seq == pad_idx).unsqueeze(-2)
16 | 
17 | 
18 | def get_subsequent_mask(seq):
19 |     ''' For masking out the subsequent info.
''' 20 | 21 | sz_b, len_s = seq.size() 22 | subsequent_mask = torch.triu( 23 | torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1) # 返回上三角矩阵 24 | subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1) # b x ls x ls 25 | 26 | return subsequent_mask 27 | 28 | 29 | def get_attn_key_pad_mask(seq_k, seq_q, PAD): 30 | ''' For masking out the padding part of key sequence. 31 | seq_k:src_seq 32 | seq_q:tgt_seq 33 | ''' 34 | 35 | # Expand to fit the shape of key query attention matrix. 36 | len_q = seq_q.size(1) # 目标序列 37 | padding_mask = seq_k.eq(PAD) # 源序列 38 | padding_mask = padding_mask.unsqueeze(1).expand(-1, len_q, -1) # b x lq x lk 39 | 40 | return padding_mask 41 | 42 | 43 | class PositionalEncoding(nn.Module): 44 | 45 | def __init__(self, d_hid, n_position=200): 46 | super(PositionalEncoding, self).__init__() 47 | 48 | # Not a parameter 49 | self.register_buffer('pos_table', self._get_sinusoid_encoding_table(n_position, d_hid)) 50 | 51 | def _get_sinusoid_encoding_table(self, n_position, d_hid): 52 | ''' Sinusoid position encoding table ''' 53 | # TODO: make it with torch instead of numpy 54 | 55 | def get_position_angle_vec(position): 56 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 57 | 58 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 59 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 60 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 61 | 62 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 63 | 64 | def forward(self, x): 65 | return x + self.pos_table[:, :x.size(1)].clone().detach() 66 | 67 | 68 | class ScaledDotProductAttention(nn.Module): 69 | ''' Scaled Dot-Product Attention ''' 70 | 71 | def __init__(self, temperature, attn_dropout=0.1): 72 | super(ScaledDotProductAttention, self).__init__() 73 | self.temperature = temperature 74 | self.dropout = nn.Dropout(attn_dropout) 75 | self.softmax = nn.Softmax(dim=2) 76 | 77 | def forward(self, q, k, v, mask=None): 78 | 79 | attn = torch.bmm(q, k.transpose(1, 2)) 80 | attn = attn / self.temperature 81 | 82 | if mask is not None: 83 | # print(mask.shape, attn.shape, v.shape) 84 | attn = attn.masked_fill(mask, -1e9) 85 | 86 | attn = self.softmax(attn) # 第3个维度为权重 87 | attn = self.dropout(attn) 88 | output = torch.bmm(attn, v) 89 | 90 | return output, attn 91 | 92 | 93 | class MultiHeadAttention(nn.Module): 94 | ''' Multi-Head Attention module ''' 95 | 96 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 97 | super(MultiHeadAttention, self).__init__() 98 | 99 | self.n_head = n_head 100 | self.d_k = d_k 101 | self.d_v = d_v 102 | 103 | self.w_qs = nn.Linear(d_model, n_head * d_k) 104 | self.w_ks = nn.Linear(d_model, n_head * d_k) 105 | self.w_vs = nn.Linear(d_model, n_head * d_v) 106 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 107 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 108 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 109 | 110 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 111 | self.layer_norm = nn.LayerNorm(d_model) 112 | 113 | self.fc = nn.Linear(n_head * d_v, d_model) 114 | nn.init.xavier_normal_(self.fc.weight) 115 | 116 | self.dropout = nn.Dropout(dropout) 117 | 118 | 119 | def forward(self, q, k, v, mask=None): 120 | 121 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 122 | 123 | sz_b, len_q, _ = q.size() 
124 | sz_b, len_k, _ = k.size() 125 | sz_b, len_v, _ = v.size() 126 | 127 | residual = q 128 | 129 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) # 4*21*512 ---- 4*21*8*64 130 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 131 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 132 | 133 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 134 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 135 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 136 | 137 | mask = mask.repeat(n_head, 1, 1) if mask is not None else None # (n*b) x .. x .. 138 | output, attn = self.attention(q, k, v, mask=mask) 139 | 140 | output = output.view(n_head, sz_b, len_q, d_v) 141 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 142 | 143 | output = self.dropout(self.fc(output)) 144 | output = self.layer_norm(output + residual) 145 | 146 | return output, attn 147 | 148 | class PositionwiseFeedForward(nn.Module): 149 | ''' A two-feed-forward-layer module ''' 150 | 151 | def __init__(self, d_in, d_hid, dropout=0.1): 152 | super(PositionwiseFeedForward, self).__init__() 153 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 154 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 155 | self.layer_norm = nn.LayerNorm(d_in) 156 | self.dropout = nn.Dropout(dropout) 157 | 158 | def forward(self, x): 159 | residual = x 160 | output = x.transpose(1, 2) 161 | output = self.w_2(F.relu(self.w_1(output))) 162 | output = output.transpose(1, 2) 163 | output = self.dropout(output) 164 | output = self.layer_norm(output + residual) 165 | return output 166 | 167 | 168 | class EncoderLayer(nn.Module): 169 | ''' Compose with two layers ''' 170 | 171 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 172 | super(EncoderLayer, self).__init__() 173 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 174 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 175 | 176 | def forward(self, enc_input, slf_attn_mask=None): 177 | enc_output, enc_slf_attn = self.slf_attn( 178 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 179 | enc_output = self.pos_ffn(enc_output) 180 | return enc_output, enc_slf_attn 181 | 182 | 183 | class Torch_transformer_encoder(nn.Module): 184 | ''' 185 | use pytorch transformer for sequence learning 186 | 187 | ''' 188 | def __init__(self, d_word_vec=512, n_layers=2, n_head=8, d_model=512, dim_feedforward=1024, n_position=256): 189 | super(Torch_transformer_encoder, self).__init__() 190 | 191 | self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position) 192 | encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=n_head, dim_feedforward=dim_feedforward) 193 | self.layer_norm = nn.LayerNorm(d_model) 194 | self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers, norm=self.layer_norm) 195 | self.dropout = nn.Dropout(p=0.1) 196 | 197 | def forward(self, cnn_feature, src_mask=None, return_attns=False): 198 | enc_slf_attn_list = [] 199 | 200 | # -- Forward 201 | enc_output = self.dropout(self.position_enc(cnn_feature)) # position embeding 202 | 203 | enc_output = self.encoder(enc_output) 204 | 205 | enc_output = self.layer_norm(enc_output) 206 | 207 | if return_attns: 208 | return enc_output, enc_slf_attn_list 209 | return enc_output, 210 | 211 | 212 | 213 | class Transforme_Encoder(nn.Module): 214 | ''' to capture the global spatial dependencies''' 215 | ''' 216 | 
d_word_vec: dimensionality of the positional encoding / feature space
217 |     n_layers: number of transformer layers
218 |     n_head: number of attention heads
219 |     d_k: 64
220 |     d_v: 64
221 |     d_model: 512,
222 |     d_inner: 1024
223 |     n_position: maximum length handled by the positional encoding
224 |     '''
225 |     def __init__(
226 |             self, d_word_vec=512, n_layers=2, n_head=8, d_k=64, d_v=64,
227 |             d_model=512, d_inner=1024, dropout=0.1, n_position=256):
228 | 
229 |         super().__init__()
230 | 
231 |         self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
232 |         self.dropout = nn.Dropout(p=dropout)
233 |         self.layer_stack = nn.ModuleList([
234 |             EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
235 |             for _ in range(n_layers)])
236 |         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
237 | 
238 |     def forward(self, cnn_feature, src_mask, return_attns=False):
239 | 
240 |         enc_slf_attn_list = []
241 | 
242 |         # -- Forward
243 |         enc_output = self.dropout(self.position_enc(cnn_feature))  # position embedding
244 | 
245 |         for enc_layer in self.layer_stack:
246 |             enc_output, enc_slf_attn = enc_layer(enc_output, slf_attn_mask=src_mask)
247 |             enc_slf_attn_list += [enc_slf_attn] if return_attns else []
248 | 
249 |         enc_output = self.layer_norm(enc_output)
250 | 
251 |         if return_attns:
252 |             return enc_output, enc_slf_attn_list
253 |         return enc_output,
254 | 
255 | 
256 | class PVAM(nn.Module):
257 |     ''' Parallel Visual attention module (decodes all character positions in parallel)'''
258 |     '''
259 |     n_dim: 512, feature dimension of the reading-order embedding
260 |     N_max_character: 25, maximum number of characters per image
261 |     n_position: length of the feature sequence produced by the CNN
262 |     '''
263 |     def __init__(self, n_dim=512, N_max_character=25, n_position=256):
264 | 
265 |         super(PVAM, self).__init__()
266 |         self.character_len = N_max_character
267 | 
268 |         self.f0_embedding = nn.Embedding(N_max_character, n_dim)
269 | 
270 |         self.w0 = nn.Linear(N_max_character, n_position)
271 |         self.wv = nn.Linear(n_dim, n_dim)
272 |         # first linear(512,25)
273 |         self.we = nn.Linear(n_dim, N_max_character)
274 | 
275 |         self.active = nn.Tanh()
276 |         self.softmax = nn.Softmax(dim=2)
277 | 
278 |     def forward(self, enc_output):
279 |         reading_order = torch.arange(self.character_len, dtype=torch.long, device=enc_output.device)
280 |         reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1)  # (S,) -> (B, S)
281 |         reading_order_embed = self.f0_embedding(reading_order)  # b,25,512
282 | 
283 |         t = self.w0(reading_order_embed.permute(0,2,1))  # b,512,256
284 |         t = self.active(t.permute(0,2,1) + self.wv(enc_output))  # b,256,512
285 |         # first linear(512,25)
286 |         attn = self.we(t)  # b,256,25
287 | 
288 |         attn = self.softmax(attn.permute(0,2,1))  # b,25,256
289 | 
290 |         g_output = torch.bmm(attn, enc_output)  # b,25,512
291 |         return g_output
292 | 
293 | 
294 | class GSRM(nn.Module):
295 |     # global semantic reasoning module
296 |     '''
297 |     n_dim: feature dimension of the embedding
298 |     n_class: vocabulary size, needed by the embedding
299 |     PAD: padding index, used to compute the mask
300 |     '''
301 |     def __init__(self, n_dim=512, n_class=37, PAD=37-1, n_layers=4, n_position=25):
302 | 
303 |         super(GSRM, self).__init__()
304 | 
305 |         self.PAD = PAD
306 |         self.argmax_embed = nn.Embedding(n_class, n_dim)
307 | 
308 |         self.transformer_units = Transforme_Encoder(n_layers=n_layers, n_position=n_position)  # for global context information
309 |         # self.transformer_units = Torch_transformer_encoder(n_layers=n_layers, n_position=n_position)
310 | 
311 |     def forward(self, e_out):
312 |         '''
313 |         e_out: b,25,37 | the character logits produced from the PVAM output
314 |         '''
315 |         e_argmax = e_out.argmax(dim=-1)  # b, 25
316 |         e = self.argmax_embed(e_argmax)  # b,25,512
317 | 
318 |         e_mask = get_pad_mask(e_argmax, self.PAD)  # b,25,1 (computed but unused: the encoder below is called with mask=None)
319 |         s = self.transformer_units(e, None)  # b,25,512
320 | 
321 |         return s
322 | 
323 | 
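# Shape walk-through (a quick sketch that mirrors the smoke test in __main__ at the
# bottom of this file; the sizes follow the inline comments above):
#
#     feat = torch.rand(2, 256, 512)                       # CNN feature sequence (b, n_position, n_dim)
#     feat = Transforme_Encoder()(feat, src_mask=None)[0]  # global spatial context, (2, 256, 512)
#     g = PVAM()(feat)                                     # aligned per-character visual features, (2, 25, 512)
#     s = GSRM()(torch.rand(2, 25, 37))[0]                 # semantic features reasoned from logits, (2, 25, 512)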
324 | class SRN_Decoder(nn.Module):
325 |     # the wrapper of decoder layers
326 |     '''
327 |     n_dim: feature dimension
328 |     n_class: number of character classes
329 |     N_max_character: at most 25 characters per image
330 |     n_position: length of the CNN feature sequence
331 |     the module produces three outputs (PVAM, GSRM and fused predictions)
332 |     '''
333 |     def __init__(self, n_dim=512, n_class=37, N_max_character=25, n_position=256, GSRM_layer=4 ):
334 | 
335 |         super(SRN_Decoder, self).__init__()
336 | 
337 |         self.pvam = PVAM(N_max_character=N_max_character, n_position=n_position)
338 |         self.w_e = nn.Linear(n_dim, n_class)  # output layer
339 | 
340 |         self.GSRM = GSRM(n_class=n_class, PAD=n_class-1, n_dim=n_dim, n_position=N_max_character, n_layers=GSRM_layer)
341 |         self.w_s = nn.Linear(n_dim, n_class)  # output layer
342 | 
343 |         self.w_f = nn.Linear(n_dim, n_class)  # output layer
344 | 
345 |     def forward(self, cnn_feature ):
346 |         '''cnn_feature: b,256,512 | the output from cnn'''
347 | 
348 |         g_output = self.pvam(cnn_feature)  # b,25,512
349 |         e_out = self.w_e(g_output)  # b,25,37 ----> cross entropy loss | the first output
350 | 
351 |         s = self.GSRM(e_out)[0]  # b,25,512
352 |         s_out = self.w_s(s)  # b,25,37
353 | 
354 |         # TODO: change the add to a gated unit
355 |         f = g_output + s  # b,25,512
356 |         f_out = self.w_f(f)
357 | 
358 |         return e_out, s_out, f_out
359 | 
360 | 
361 | def cal_performance(preds, gold, mask=None, smoothing='1'):
362 |     ''' Apply label smoothing if needed '''
363 | 
364 |     loss = 0.
365 |     n_correct = 0
366 |     weights = [1.0, 0.15, 2.0]
367 |     for ori_pred, weight in zip(preds, weights):
368 |         pred = ori_pred.view(-1, ori_pred.shape[-1])
369 |         # debug show
370 |         t_gold = gold.view(ori_pred.shape[0], -1)
371 |         t_pred_index = ori_pred.max(2)[1]
372 | 
373 |         mask = mask.view(-1) if mask is not None else None
374 |         non_pad_mask = mask.ne(0) if mask is not None else None
375 |         tloss = cal_loss(pred, gold, non_pad_mask, smoothing)
376 |         if torch.isnan(tloss):
377 |             print('have nan loss')
378 |             continue
379 |         else:
380 |             loss += tloss * weight
381 | 
382 |         pred = pred.max(1)[1]
383 |         gold = gold.contiguous().view(-1)
384 |         n_correct = pred.eq(gold)
385 |         n_correct = n_correct.masked_select(non_pad_mask).sum().item() if mask is not None else None
386 | 
387 |     return loss, n_correct
388 | 
389 | 
390 | def cal_loss(pred, gold, mask, smoothing):
391 |     ''' Calculate cross entropy loss, apply label smoothing if needed. '''
392 | 
393 |     gold = gold.contiguous().view(-1)
394 | 
395 |     if smoothing=='0':
396 |         eps = 0.1
397 |         n_class = pred.size(1)
398 | 
399 |         one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
400 |         one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
401 |         log_prb = F.log_softmax(pred, dim=1)
402 | 
403 |         non_pad_mask = gold.ne(0)
404 |         loss = -(one_hot * log_prb).sum(dim=1)
405 |         loss = loss.masked_select(non_pad_mask).sum()  # average later
406 |     elif smoothing == '1':
407 |         if mask is not None:
408 |             loss = F.cross_entropy(pred, gold, reduction='none')
409 |             loss = loss.masked_select(mask)
410 |             loss = loss.sum() / mask.sum()
411 |         else:
412 |             loss = F.cross_entropy(pred, gold)
413 |     else:
414 |         # loss = F.cross_entropy(pred, gold, ignore_index=PAD)
415 |         loss = F.cross_entropy(pred, gold)
416 | 
417 |     return loss
418 | 
419 | 
420 | def cal_performance2(preds, gold, PAD, smoothing='1'):
421 |     ''' Apply label smoothing if needed '''
422 | 
423 |     loss = 0.
424 | n_correct = 0 425 | weights = [1.0, 0.15, 2.0] 426 | for ori_pred, weight in zip(preds, weights): 427 | pred = ori_pred.view(-1, ori_pred.shape[-1]) 428 | # debug show 429 | t_gold = gold.view(ori_pred.shape[0], -1) 430 | t_pred_index = ori_pred.max(2)[1] 431 | 432 | tloss = cal_loss2(pred, gold, PAD, smoothing=smoothing) 433 | if torch.isnan(tloss): 434 | print('have nan loss') 435 | continue 436 | else: 437 | loss += tloss * weight 438 | 439 | pred = pred.max(1)[1] 440 | gold = gold.contiguous().view(-1) 441 | n_correct = pred.eq(gold) 442 | non_pad_mask = gold.ne(PAD) 443 | n_correct = n_correct.masked_select(non_pad_mask).sum().item() 444 | 445 | return loss, n_correct 446 | 447 | 448 | def cal_loss2(pred, gold, PAD, smoothing='1'): 449 | ''' Calculate cross entropy loss, apply label smoothing if needed. ''' 450 | 451 | gold = gold.contiguous().view(-1) 452 | 453 | if smoothing=='0': 454 | eps = 0.1 455 | n_class = pred.size(1) 456 | 457 | one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1) 458 | one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1) 459 | log_prb = F.log_softmax(pred, dim=1) 460 | 461 | non_pad_mask = gold.ne(0) 462 | loss = -(one_hot * log_prb).sum(dim=1) 463 | loss = loss.masked_select(non_pad_mask).sum() # average later 464 | elif smoothing == '1': 465 | loss = F.cross_entropy(pred, gold, ignore_index=PAD) 466 | else: 467 | # loss = F.cross_entropy(pred, gold, ignore_index=PAD) 468 | loss = F.cross_entropy(pred, gold) 469 | 470 | return loss 471 | 472 | 473 | if __name__=='__main__': 474 | cnn_feature = torch.rand((2,256,512)) 475 | model1 = Transforme_Encoder() 476 | image = model1(cnn_feature,src_mask=None)[0] 477 | model = SRN_Decoder(N_max_character=30) 478 | 479 | outs = model(image) 480 | for out in outs: 481 | print(out.shape) 482 | 483 | # image = torch.rand((4,3,32,60)) 484 | # tgt_seq = torch.tensor([[ 2, 24, 2176, 882, 2480, 612, 1525, 480, 875, 147, 1700, 715, 485 | # 1465, 3], 486 | # [ 2, 369, 1781, 882, 703, 879, 2855, 2415, 502, 1154, 833, 1465, 487 | # 3, 0], 488 | # [ 2, 2943, 334, 328, 480, 330, 1644, 1449, 163, 147, 1823, 1184, 489 | # 1465, 3], 490 | # [ 2, 24, 396, 480, 703, 1646, 897, 1711, 1508, 703, 2321, 147, 491 | # 642, 1465]], device='cuda:0') 492 | # tgt_pos = torch.tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], 493 | # [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0], 494 | # [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], 495 | # [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]], 496 | # device='cuda:0') 497 | # src_seq = torch.tensor([[ 2, 598, 2088, 822, 2802, 1156, 157, 1099, 1000, 598, 1707, 1345, 498 | # 3, 0, 0, 0], 499 | # [ 2, 598, 2348, 822, 598, 1222, 471, 948, 986, 423, 1345, 3, 500 | # 0, 0, 0, 0], 501 | # [ 2, 2437, 2470, 901, 2473, 598, 1735, 84, 1, 2277, 1979, 499, 502 | # 962, 1345, 3, 0], 503 | # [ 2, 598, 186, 1904, 598, 868, 1339, 1604, 84, 598, 608, 1728, 504 | # 1345, 3, 0, 0]], device='cuda:0') 505 | 506 | # device = torch.device('cuda') 507 | # image = image.cuda() 508 | # transformer = Transformer() 509 | # transformer = transformer.to(device) 510 | # transformer.train() 511 | # out = transformer(image, tgt_seq, tgt_pos, src_seq) 512 | 513 | # gold = tgt_seq[:, 1:] # 从第二列开始 514 | 515 | # # backward 516 | # loss, n_correct = cal_performance(out, gold, smoothing=True) 517 | # print(loss, n_correct) 518 | # a = 1 -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import random 5 | import string 6 | import argparse 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn.init as init 11 | import torch.optim as optim 12 | from torch.optim import lr_scheduler 13 | import torch.utils.data 14 | import numpy as np 15 | 16 | from utils import CTCLabelConverter, AttnLabelConverter, Averager, TransformerConverter, SRNConverter 17 | from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset 18 | from model import Model 19 | from test import validation 20 | from src.baidudataset import BAIDUset, BaiduCollate 21 | from modules.optimizer.ranger import Ranger 22 | # from modules.SRN_modules import cal_performance 23 | from modules.SRN_modules import cal_performance2 as cal_performance 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | 26 | def train(opt): 27 | """ dataset preparation """ 28 | if opt.select_data == 'baidu': 29 | train_set = BAIDUset(opt, opt.train_csv) 30 | train_loader = torch.utils.data.DataLoader( 31 | train_set, batch_size=opt.batch_size, 32 | shuffle=True, num_workers=int(opt.workers), 33 | collate_fn=BaiduCollate(opt.imgH, opt.imgW, keep_ratio=False) 34 | ) 35 | val_set = BAIDUset(opt, opt.val_csv) 36 | valid_loader = torch.utils.data.DataLoader( 37 | val_set, batch_size=opt.batch_size, 38 | shuffle=True, 39 | num_workers=int(opt.workers), 40 | collate_fn=BaiduCollate(opt.imgH, opt.imgW, keep_ratio=False), pin_memory=True) 41 | 42 | else: 43 | opt.select_data = opt.select_data.split('-') 44 | opt.batch_ratio = opt.batch_ratio.split('-') 45 | train_dataset = Batch_Balanced_Dataset(opt) 46 | 47 | AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 48 | valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt) 49 | valid_loader = torch.utils.data.DataLoader( 50 | valid_dataset, batch_size=opt.batch_size, 51 | shuffle=True, # 'True' to check training progress with validation function. 52 | num_workers=int(opt.workers), 53 | collate_fn=AlignCollate_valid, pin_memory=True) 54 | print('-' * 80) 55 | 56 | """ model configuration """ 57 | if 'CTC' in opt.Prediction: 58 | converter = CTCLabelConverter(opt.character) 59 | elif 'Bert' in opt.Prediction: 60 | converter = TransformerConverter(opt.character, opt.max_seq) 61 | elif 'SRN' in opt.Prediction: 62 | converter = SRNConverter(opt.character, opt.SRN_PAD) 63 | else: 64 | converter = AttnLabelConverter(opt.character) 65 | opt.num_class = len(converter.character) 66 | 67 | if opt.rgb: 68 | opt.input_channel = 3 69 | model = Model(opt) 70 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 71 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 72 | opt.SequenceModeling, opt.Prediction) 73 | 74 | # weight initialization 75 | for name, param in model.named_parameters(): 76 | if 'localization_fc2' in name: 77 | print(f'Skip {name} as it is already initialized') 78 | continue 79 | try: 80 | if 'bias' in name: 81 | init.constant_(param, 0.0) 82 | elif 'weight' in name: 83 | init.kaiming_normal_(param) 84 | except Exception as e: # for batchnorm. 
85 |                 if 'weight' in name:
86 |                     param.data.fill_(1)
87 |                 continue
88 | 
89 |     # data parallel for multi-GPU
90 |     model = torch.nn.DataParallel(model).cuda()
91 |     model.train()
92 |     if opt.continue_model != '':
93 |         print(f'loading pretrained model from {opt.continue_model}')
94 |         model.load_state_dict(torch.load(opt.continue_model))
95 |     print("Model:")
96 |     print(model)
97 | 
98 |     """ setup loss """
99 |     if 'CTC' in opt.Prediction:
100 |         criterion = torch.nn.CTCLoss(zero_infinity=True).cuda()
101 |     elif 'Bert' in opt.Prediction:
102 |         criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()
103 |     elif 'SRN' in opt.Prediction:
104 |         criterion = cal_performance
105 |     else:
106 |         criterion = torch.nn.CrossEntropyLoss(ignore_index=0).cuda()  # ignore [GO] token = ignore index 0
107 |     # loss averager
108 |     loss_avg = Averager()
109 | 
110 |     # keep only the parameters that require gradients
111 |     filtered_parameters = []
112 |     params_num = []
113 |     for p in filter(lambda p: p.requires_grad, model.parameters()):
114 |         filtered_parameters.append(p)
115 |         params_num.append(np.prod(p.size()))
116 |     print('Trainable params num : ', sum(params_num))
117 |     # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]
118 | 
119 |     # setup optimizer
120 |     if opt.adam:
121 |         optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))
122 |     elif opt.ranger:
123 |         optimizer = Ranger(filtered_parameters, lr=opt.lr)
124 |     else:
125 |         optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps)
126 |     print("Optimizer:")
127 |     print(optimizer)
128 | 
129 |     lrScheduler = lr_scheduler.MultiStepLR(optimizer, [2, 4, 5], gamma=0.1)  # decay the learning rate
130 | 
131 |     """ final options """
132 |     # print(opt)
133 |     with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
134 |         opt_log = '------------ Options -------------\n'
135 |         args = vars(opt)
136 |         for k, v in args.items():
137 |             opt_log += f'{str(k)}: {str(v)}\n'
138 |         opt_log += '---------------------------------------\n'
139 |         print(opt_log)
140 |         opt_file.write(opt_log)
141 | 
142 |     """ start training """
143 |     start_iter = 0
144 |     if opt.continue_model != '':
145 |         start_iter = int(opt.continue_model.split('_')[-1].split('.')[0])
146 |         print(f'continue to train, start_iter: {start_iter}')
147 | 
148 |     start_time = time.time()
149 |     best_accuracy = -1
150 |     best_norm_ED = 1e+6
151 |     i = start_iter
152 |     if opt.select_data == 'baidu':
153 |         train_iter = iter(train_loader)
154 |         step_per_epoch = len(train_set) / opt.batch_size
155 |         print('steps per epoch:', step_per_epoch)
156 |     else:
157 |         step_per_epoch = train_dataset.nums_samples / opt.batch_size
158 |         print('steps per epoch:', step_per_epoch)
159 | 
160 | 
161 |     while(True):
162 |         # try:
163 |         # train part
164 |         for p in model.parameters():
165 |             p.requires_grad = True
166 | 
167 |         if opt.select_data == 'baidu':
168 |             try:
169 |                 image_tensors, labels = next(train_iter)
170 |             except StopIteration:
171 |                 train_iter = iter(train_loader)
172 |                 image_tensors, labels = next(train_iter)
173 |         else:
174 |             image_tensors, labels = train_dataset.get_batch()
175 | 
176 |         image = image_tensors.cuda()
177 |         if 'SRN' in opt.Prediction:
178 |             text, length = converter.encode(labels)
179 |         else:
180 |             text, length = converter.encode(labels)
181 |         batch_size = image.size(0)
182 | 
183 |         if 'CTC' in opt.Prediction:
184 |             preds = model(image, text).log_softmax(2)
185 |             preds_size = torch.IntTensor([preds.size(1)] * batch_size)
186 |             preds = preds.permute(1, 0, 2)
187 | 
188 | 
# (ctc_a) For PyTorch 1.2.0 and 1.3.0. To avoid ctc_loss issue, disabled cudnn for the computation of the ctc_loss 189 | # https://github.com/jpuigcerver/PyLaia/issues/16 190 | torch.backends.cudnn.enabled = False 191 | cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device)) 192 | torch.backends.cudnn.enabled = True 193 | 194 | elif 'Bert' in opt.Prediction: 195 | pad_mask = None 196 | # print(image.shape) 197 | preds = model(image, pad_mask) 198 | cost = criterion(preds[0].view(-1, preds[0].shape[-1]), text.contiguous().view(-1)) + \ 199 | criterion(preds[1].view(-1, preds[1].shape[-1]), text.contiguous().view(-1)) 200 | 201 | elif 'SRN' in opt.Prediction: 202 | preds = model(image, None) 203 | cost, train_correct = criterion(preds, text, opt.SRN_PAD) 204 | 205 | else: 206 | preds = model(image, text[:, :-1]) # align with Attention.forward 207 | target = text[:, 1:] # without [GO] Symbol 208 | cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 209 | 210 | model.zero_grad() 211 | cost.backward() 212 | torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) 213 | optimizer.step() 214 | 215 | loss_avg.add(cost) 216 | 217 | if i % opt.disInterval == 0: 218 | elapsed_time = time.time() - start_time 219 | print(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}') 220 | start_time = time.time() 221 | 222 | # validation part 223 | if i % opt.valInterval == 0 and i > start_iter: 224 | elapsed_time = time.time() - start_time 225 | print(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}') 226 | # for log 227 | with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log: 228 | log.write(f'[{i}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n') 229 | loss_avg.reset() 230 | 231 | # model.eval() 232 | # valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation( 233 | # # model, criterion, valid_loader, converter, opt) 234 | valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation( 235 | model, criterion, valid_loader, converter, opt) 236 | model.train() 237 | 238 | for pred, gt in zip(preds[:5], labels[:5]): 239 | if 'Attn' in opt.Prediction: 240 | pred = pred[:pred.find('[s]')] 241 | gt = gt[:gt.find('[s]')] 242 | print(f'pred: {pred:20s}, gt: {gt:20s}, {str(pred == gt)}') 243 | log.write(f'pred: {pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n') 244 | 245 | valid_log = f'[{i}/{opt.num_iter}] valid loss: {valid_loss:0.5f}' 246 | valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}' 247 | print(valid_log) 248 | log.write(valid_log + '\n') 249 | 250 | # keep best accuracy model 251 | if current_accuracy > best_accuracy: 252 | best_accuracy = current_accuracy 253 | torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth') 254 | if current_norm_ED < best_norm_ED: 255 | best_norm_ED = current_norm_ED 256 | torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth') 257 | best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}' 258 | print(best_model_log) 259 | log.write(best_model_log + '\n') 260 | 261 | # save model per 1e+5 iter. 
262 |         if (i + 1) % opt.saveInterval == 0:
263 |             torch.save(
264 |                 model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth')
265 | 
266 |         if i == opt.num_iter:
267 |             print('end the training')
268 |             sys.exit()
269 | 
270 |         if i > 0 and i % int(step_per_epoch) == 0:  # adjust the learning rate
271 |             print('decay the learning rate by 1/10')
272 |             lrScheduler.step()
273 | 
274 |         i += 1
275 |         # except:
276 |         #     import sys, traceback
277 |         #     traceback.print_exc(file=sys.stdout)
278 |         #     continue
279 | 
280 | 
281 | if __name__ == '__main__':
282 |     parser = argparse.ArgumentParser()
283 |     parser.add_argument('--experiment_name', help='Where to store logs and models')
284 |     parser.add_argument('--train_data', default='/home/deepblue/deepbluetwo/chenjun/1_OCR/data/data_lmdb_release/training', help='path to training dataset')
285 |     parser.add_argument('--valid_data', default='/home/deepblue/deepbluetwo/chenjun/1_OCR/data/data_lmdb_release/validation', help='path to validation dataset')
286 |     parser.add_argument('--manualSeed', type=int, default=666, help='for random seed setting')
287 |     parser.add_argument('--workers', type=int, help='number of data loading workers', default=6)
288 |     parser.add_argument('--batch_size', type=int, default=256, help='input batch size')
289 |     parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for')
290 |     parser.add_argument('--valInterval', type=int, default=5000, help='Interval between each validation')
291 |     parser.add_argument('--saveInterval', type=int, default=5000, help='Interval between each save')
292 |     parser.add_argument('--disInterval', type=int, default=5, help='Interval between each loss display')
293 |     parser.add_argument('--continue_model', default='', help="path to model to continue training")
294 |     # parser.add_argument('--continue_model', default='./saved_models/None-ResNet-SRN-SRN-Seed666/iter_150000.pth', help="path to model to continue training")
295 |     parser.add_argument('--adam', default=True, help='Whether to use adam (default is Adadelta)')
296 |     parser.add_argument('--ranger', default=False, help='use RAdam + Lookahead for optimizer')
297 |     parser.add_argument('--lr', type=float, default=0.0001, help='learning rate, default=1.0 for Adadelta')
298 |     parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9')
299 |     parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95')
300 |     parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8')
301 |     parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5')
302 | 
303 |     """ all baidu images """
304 |     # parser.add_argument('--root', type=str, default='/root/shenlan/deepblue/1_OCR/data/train_images', help='the path of images')
305 |     # parser.add_argument('--train_csv', type=str, default='/root/shenlan/deepblue/1_OCR/text_reco/dataset/BAIDU/add_train_30w.txt', help='the train samples')
306 |     # parser.add_argument('--val_csv', type=str, default='/root/shenlan/deepblue/1_OCR/text_reco/dataset/BAIDU/add_val.txt', help='the val samples')
307 |     # parser.add_argument('--baidu_alphabet', type=str, default='/root/shenlan/deepblue/1_OCR/text_reco/dataset/BAIDU/baidu_alphabet_30w.txt')
308 | 
309 |     '''a small baidu image set'''
310 |     parser.add_argument('--root', type=str, default='./dataset/BAIDU/images/', help='the path of images')
311 |     parser.add_argument('--train_csv', type=str, default='./dataset/BAIDU/small_train.txt', help='the train samples')
312 |     parser.add_argument('--val_csv', type=str, default='./dataset/BAIDU/small_train.txt', help='the val samples')
313 |     parser.add_argument('--baidu_alphabet', type=str, default='./dataset/BAIDU/baidu_alphabet.txt')
314 | 
315 |     '''bert_ocr setting'''
316 |     parser.add_argument('--max_seq', type=int, default=26, help='the maximum sequence length')
317 |     parser.add_argument('--position_dim', type=int, default=26, help='the sequence length out of the cnn encoder, resnet:65, resnetfpn:256')
318 | 
319 |     '''SRN setting'''
320 |     parser.add_argument('--SRN_PAD', type=int, default=37, help='index used for EOS/padding')
321 |     parser.add_argument('--batch_max_character', type=int, default=25, help='the maximum number of characters in one image')
322 |     parser.add_argument('--alphabet_size', type=int, default=None, help='the number of character categories')
323 | 
324 |     parser.add_argument('--select_data', type=str, default='MJ-ST',
325 |                         help='select training data MJ-ST | MJ-ST-ICDAR2019 | baidu')
326 |     parser.add_argument('--batch_ratio', type=str, default='1.0-1.0',
327 |                         help='assign ratio for each selected data in the batch')
328 |     parser.add_argument('--total_data_usage_ratio', type=str, default='1.0',
329 |                         help='total data usage ratio, this ratio is multiplied to total number of data.')
330 |     parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length')
331 |     parser.add_argument('--imgH', type=int, default=32, help='the height of the input image')
332 |     parser.add_argument('--imgW', type=int, default=100, help='the width of the input image')
333 |     parser.add_argument('--rgb', action='store_true', help='use rgb input')
334 |     parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz$#', help='character label')
335 |     parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode')
336 |     parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize')
337 |     parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode')
338 |     """ Model Architecture """
339 |     parser.add_argument('--Transformation', type=str, default='None', help='Transformation stage. None|TPS')
340 |     parser.add_argument('--FeatureExtraction', type=str, default='ResNet', help='FeatureExtraction stage. VGG|RCNN|ResNet|AsterRes|ResnetFpn')
341 |     parser.add_argument('--SequenceModeling', type=str, default='SRN', help='SequenceModeling stage. None|BiLSTM|Bert|SRN')
342 |     parser.add_argument('--Prediction', type=str, default='SRN', help='Prediction stage.
CTC|Attn|Bert_pred|SRN') 343 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 344 | parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') 345 | parser.add_argument('--output_channel', type=int, default=512, 346 | help='the number of output channel of Feature extractor') 347 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 348 | 349 | opt = parser.parse_args() 350 | 351 | if not opt.experiment_name: 352 | opt.experiment_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 353 | opt.experiment_name += f'-Seed{opt.manualSeed}' 354 | # print(opt.experiment_name) 355 | 356 | os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True) 357 | 358 | """ vocab / character number configuration """ 359 | if opt.sensitive: 360 | # opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 361 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 362 | 363 | if opt.select_data == 'baidu': 364 | with open(opt.baidu_alphabet) as f: 365 | opt.character = f.readlines()[0] 366 | # opt.character = opt.baidu_alphabet 367 | opt.alphabet_size = len(opt.character) # +2 for [UNK]+[EOS] 368 | 369 | '''SRN setting''' 370 | opt.SRN_PAD = len(opt.character) - 1 371 | 372 | 373 | """ Seed and GPU setting """ 374 | # print("Random Seed: ", opt.manualSeed) 375 | random.seed(opt.manualSeed) 376 | np.random.seed(opt.manualSeed) 377 | torch.manual_seed(opt.manualSeed) 378 | torch.cuda.manual_seed(opt.manualSeed) 379 | 380 | cudnn.benchmark = True 381 | cudnn.deterministic = True 382 | opt.num_gpu = torch.cuda.device_count() 383 | # opt.num_gpu = 1 384 | # print('device count', opt.num_gpu) 385 | if opt.num_gpu > 1: 386 | print('------ Use multi-GPU setting ------') 387 | print('if you stuck too long time with multi-GPU setting, try to set --workers 0') 388 | # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1 389 | opt.workers = opt.workers * opt.num_gpu 390 | 391 | """ previous version 392 | print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size) 393 | opt.batch_size = opt.batch_size * opt.num_gpu 394 | print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.') 395 | If you dont care about it, just commnet out these line.) 396 | opt.num_iter = int(opt.num_iter / opt.num_gpu) 397 | """ 398 | 399 | train(opt) 400 | --------------------------------------------------------------------------------
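A minimal smoke test for the SRN head and its loss (a sketch only: it mirrors the __main__ check in modules/SRN_modules.py and the way train.py calls cal_performance2; the 37-class setup is the module default, and the gold labels below are random, purely illustrative values):

    import torch
    from modules.SRN_modules import Transforme_Encoder, SRN_Decoder, cal_performance2

    cnn_feature = torch.rand(2, 256, 512)                      # stand-in for the backbone output (b, 256, 512)
    feature = Transforme_Encoder()(cnn_feature, src_mask=None)[0]
    decoder = SRN_Decoder(n_dim=512, n_class=37, N_max_character=25, n_position=256)
    e_out, s_out, f_out = decoder(feature)                     # three (2, 25, 37) logit tensors

    gold = torch.randint(0, 37, (2, 25))                       # padded target ids; PAD = 36 here
    loss, n_correct = cal_performance2((e_out, s_out, f_out), gold, PAD=36)
    print(loss.item(), n_correct)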