├── imgs
│   └── dataset1.png
├── modules
│   ├── sequence_modeling.py
│   ├── transformation.py
│   ├── feature_extraction.py
│   └── prediction.py
├── create_lmdb_dataset.py
├── model.py
├── utils.py
├── README.md
├── test.py
├── train.py
├── meta_train.py
├── self_training.py
├── meta_self_learning.py
└── dataset.py

/imgs/dataset1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bupt-ai-cz/Meta-SelfLearning/HEAD/imgs/dataset1.png
--------------------------------------------------------------------------------
/modules/sequence_modeling.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | 
4 | class BidirectionalLSTM(nn.Module):
5 | 
6 |     def __init__(self, input_size, hidden_size, output_size):
7 |         super(BidirectionalLSTM, self).__init__()
8 |         self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
9 |         self.linear = nn.Linear(hidden_size * 2, output_size)
10 | 
11 |     def forward(self, input):
12 |         """
13 |         input : visual feature [batch_size x T x input_size]
14 |         output : contextual feature [batch_size x T x output_size]
15 |         """
16 |         self.rnn.flatten_parameters()
17 |         recurrent, _ = self.rnn(input)  # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
18 |         output = self.linear(recurrent)  # batch_size x T x output_size
19 |         return output
20 | 
--------------------------------------------------------------------------------
/create_lmdb_dataset.py:
--------------------------------------------------------------------------------
1 | """ a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """
2 | '''
3 | python create_lmdb_dataset.py --inputPath data/ --gtFile data/gt.txt --outputPath result/
4 | The actual image path is obtained by joining inputPath with the relative path recorded in gtFile.
5 | Note: replace the spaces in the label file with \t, e.g. $ sed -s 's/ /\t/g' label.txt > label_t.txt
6 | '''
7 | import fire
8 | import os
9 | import lmdb
10 | import cv2
11 | 
12 | import numpy as np
13 | 
14 | 
15 | def checkImageIsValid(imageBin):
16 |     if imageBin is None:
17 |         return False
18 |     imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
19 |     img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
20 |     imgH, imgW = img.shape[0], img.shape[1]
21 |     if imgH * imgW == 0:
22 |         return False
23 |     return True
24 | 
25 | 
26 | def writeCache(env, cache):
27 |     with env.begin(write=True) as txn:
28 |         for k, v in cache.items():
29 |             txn.put(k, v)
30 | 
31 | 
32 | def createDataset(inputPath, gtFile, outputPath, checkValid=True):
33 |     """
34 |     Create LMDB dataset for training and evaluation. 
35 | ARGS: 36 | inputPath : input folder path where starts imagePath 37 | outputPath : LMDB output path 38 | gtFile : list of image path and label 39 | checkValid : if true, check the validity of every image 40 | """ 41 | os.makedirs(outputPath, exist_ok=True) 42 | env = lmdb.open(outputPath, map_size=1099511627776) 43 | cache = {} 44 | cnt = 1 45 | 46 | with open(gtFile, 'r', encoding='utf-8') as data: 47 | datalist = data.readlines() 48 | 49 | nSamples = len(datalist) 50 | for i in range(nSamples): 51 | imagePath, label = datalist[i].strip('\n').split('\t') 52 | imagePath = os.path.join(inputPath, imagePath) 53 | 54 | # # only use alphanumeric data 55 | # if re.search('[^a-zA-Z0-9]', label): 56 | # continue 57 | 58 | if not os.path.exists(imagePath): 59 | print('%s does not exist' % imagePath) 60 | continue 61 | with open(imagePath, 'rb') as f: 62 | imageBin = f.read() 63 | if checkValid: 64 | try: 65 | if not checkImageIsValid(imageBin): 66 | print('%s is not a valid image' % imagePath) 67 | continue 68 | except: 69 | print('error occured', i) 70 | with open(outputPath + '/error_image_log.txt', 'a') as log: 71 | log.write('%s-th image data occured error\n' % str(i)) 72 | continue 73 | 74 | imageKey = 'image-%09d'.encode() % cnt 75 | labelKey = 'label-%09d'.encode() % cnt 76 | cache[imageKey] = imageBin 77 | cache[labelKey] = label.encode() 78 | 79 | if cnt % 1000 == 0: 80 | writeCache(env, cache) 81 | cache = {} 82 | print('Written %d / %d' % (cnt, nSamples)) 83 | cnt += 1 84 | nSamples = cnt-1 85 | cache['num-samples'.encode()] = str(nSamples).encode() 86 | writeCache(env, cache) 87 | print('Created dataset with %d samples' % nSamples) 88 | 89 | 90 | if __name__ == '__main__': 91 | fire.Fire(createDataset) 92 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """
16 | 
17 | import torch.nn as nn
18 | from torch.nn import Transformer
19 | 
20 | from modules.transformation import TPS_SpatialTransformerNetwork
21 | from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, ResNet_FeatureExtractor, SEResNet50_FeatureExtractor
22 | from modules.sequence_modeling import BidirectionalLSTM
23 | from modules.prediction import Attention
24 | 
25 | 
26 | class Model(nn.Module):
27 | 
28 |     def __init__(self, opt):
29 |         super(Model, self).__init__()
30 |         self.opt = opt
31 |         self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction,
32 |                        'Seq': opt.SequenceModeling, 'Pred': opt.Prediction}
33 | 
34 |         """ Transformation """
35 |         if opt.Transformation == 'TPS':
36 |             self.Transformation = TPS_SpatialTransformerNetwork(
37 |                 F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), I_channel_num=opt.input_channel)
38 |         else:
39 |             print('No Transformation module specified')
40 | 
41 |         """ FeatureExtraction """
42 |         if opt.FeatureExtraction == 'VGG':
43 |             self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel)
44 |         elif opt.FeatureExtraction == 'RCNN':
45 |             self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel)
46 |         elif opt.FeatureExtraction == 'ResNet':
47 |             self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel)
48 |         elif opt.FeatureExtraction == 'SEResNet':
49 |             self.FeatureExtraction = SEResNet50_FeatureExtractor(opt.input_channel, opt.output_channel)
50 |         else:
51 |             raise Exception('No FeatureExtraction module specified')
52 |         self.FeatureExtraction_output = opt.output_channel  # int(imgH/16-1) * 512
53 |         self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1
54 | 
55 |         """ Sequence modeling"""
56 |         if opt.SequenceModeling == 'BiLSTM':
57 |             self.SequenceModeling = nn.Sequential(
58 |                 BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size),
59 |                 BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size))
60 |             self.SequenceModeling_output = opt.hidden_size
61 |         else:
62 |             print('No SequenceModeling module specified')
63 |             self.SequenceModeling_output = self.FeatureExtraction_output
64 | 
65 |         """ Prediction """
66 |         if opt.Prediction == 'CTC':
67 |             self.Prediction = nn.Linear(self.SequenceModeling_output, opt.num_class)
68 |         elif opt.Prediction == 'Attn':
69 |             self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, opt.num_class)
70 |         elif opt.Prediction == 'Transformer':
71 |             self.Prediction = Transformer(self.SequenceModeling_output, opt.num_class)
72 |         else:
73 |             raise Exception('Prediction should be one of CTC, Attn, or Transformer')
74 | 
75 |     def forward(self, input, text, is_train=True, is_domain=False):
76 |         '''
77 |         is_domain: if True, also return the features from the feature-extraction stage and the attention context history of the predictor, in addition to the prediction.
78 |         '''
79 |         """ Transformation stage """
80 |         if not self.stages['Trans'] == "None":
81 |             input = self.Transformation(input)
82 | 
83 |         """ Feature extraction stage """
84 |         visual_feature = self.FeatureExtraction(input)
85 |         visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
86 |         visual_feature = visual_feature.squeeze(3)  # b, t, c
87 | 
88 |         """ Sequence modeling stage """
89 |         if self.stages['Seq'] == 'BiLSTM':
90 |             contextual_feature = self.SequenceModeling(visual_feature)
91 |         else:
92 |             contextual_feature = visual_feature  # for convenience. 
this is NOT contextually modeled by BiLSTM 93 | 94 | """ Prediction stage """ 95 | if self.stages['Pred'] == 'CTC': 96 | # prediction = self.Prediction(contextual_feature.contiguous(), is_domain=is_domain) 97 | prediction = self.Prediction(contextual_feature.contiguous()) 98 | else: 99 | prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, batch_max_length=self.opt.batch_max_length, is_domain=is_domain) 100 | 101 | if is_domain: 102 | return prediction, visual_feature, self.Prediction.context_history 103 | 104 | return prediction 105 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 3 | 4 | 5 | class CTCLabelConverter(object): 6 | """ Convert between text-label and text-index """ 7 | 8 | def __init__(self, character): 9 | # character (str): set of the possible characters. 10 | dict_character = list(character) 11 | 12 | self.dict = {} 13 | for i, char in enumerate(dict_character): 14 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 15 | self.dict[char] = i + 1 16 | 17 | self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) 18 | 19 | def encode(self, text, batch_max_length=25): 20 | """convert text-label into text-index. 21 | input: 22 | text: text labels of each image. [batch_size] 23 | batch_max_length: max length of text label in the batch. 25 by default 24 | 25 | output: 26 | text: text index for CTCLoss. [batch_size, batch_max_length] 27 | length: length of each text. [batch_size] 28 | """ 29 | length = [len(s) for s in text] 30 | 31 | # The index used for padding (=0) would not affect the CTC loss calculation. 所有label统一长度,多余位用0来填充 32 | batch_text = torch.LongTensor(len(text), batch_max_length).fill_(0) 33 | for i, t in enumerate(text): 34 | text = list(t) 35 | text = [self.dict[char] for char in text] 36 | batch_text[i][:len(text)] = torch.LongTensor(text) 37 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 38 | 39 | def decode(self, text_index, length): 40 | """ convert text-index into text-label. """ 41 | texts = [] 42 | for index, l in enumerate(length): 43 | t = text_index[index, :] 44 | 45 | char_list = [] 46 | for i in range(l): 47 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. 48 | char_list.append(self.character[t[i]]) 49 | text = ''.join(char_list) 50 | 51 | texts.append(text) 52 | return texts 53 | 54 | 55 | class CTCLabelConverterForBaiduWarpctc(object): 56 | """ Convert between text-label and text-index for baidu warpctc """ 57 | 58 | def __init__(self, character): 59 | # character (str): set of the possible characters. 60 | dict_character = list(character) 61 | 62 | self.dict = {} 63 | for i, char in enumerate(dict_character): 64 | # NOTE: 0 is reserved for 'CTCblank' token required by CTCLoss 65 | self.dict[char] = i + 1 66 | 67 | self.character = ['[CTCblank]'] + dict_character # dummy '[CTCblank]' token for CTCLoss (index 0) 68 | 69 | def encode(self, text, batch_max_length=25): 70 | """convert text-label into text-index. 71 | input: 72 | text: text labels of each image. [batch_size] 73 | output: 74 | text: concatenated text index for CTCLoss. 75 | [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] 76 | length: length of each text. 
[batch_size]
77 |         """
78 |         length = [len(s) for s in text]
79 |         text = ''.join(text)
80 |         text = [self.dict[char] for char in text]
81 | 
82 |         return (torch.IntTensor(text), torch.IntTensor(length))
83 | 
84 |     def decode(self, text_index, length):
85 |         """ convert text-index into text-label. """
86 |         texts = []
87 |         index = 0
88 |         for l in length:
89 |             t = text_index[index:index + l]
90 | 
91 |             char_list = []
92 |             for i in range(l):
93 |                 if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):  # removing repeated characters and blank.
94 |                     char_list.append(self.character[t[i]])
95 |             text = ''.join(char_list)
96 | 
97 |             texts.append(text)
98 |             index += l
99 |         return texts
100 | 
101 | 
102 | class AttnLabelConverter(object):
103 |     """ Convert between text-label and text-index """
104 | 
105 |     def __init__(self, character):
106 |         # character (str): set of the possible characters.
107 |         # [GO] for the start token of the attention decoder. [s] for end-of-sentence token.
108 |         # [GO] is the start token (index 0); [s] is the end token (index 1).
109 |         list_token = ['[GO]', '[s]']  # ['[s]','[UNK]','[PAD]','[GO]']
110 |         list_character = list(character)
111 |         self.character = list_token + list_character
112 | 
113 |         self.dict = {}
114 |         for i, char in enumerate(self.character):
115 |             # print(i, char)
116 |             self.dict[char] = i
117 | 
118 |     def encode(self, text, batch_max_length=25):
119 |         """ convert text-label into text-index.
120 |         input:
121 |             text: text labels of each image. [batch_size], a list of strings
122 |             batch_max_length: max length of text label in the batch. 25 by default
123 | 
124 |         output:
125 |             text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token.
126 |                 text[:, 0] is [GO] token and text is padded with [GO] token after [s] token.
127 |             length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size]
128 |             For the attention decoder, a target looks like '0 3 5 3 4 ... 1 0 0 0 ...': 0 stands for [GO], 1 for [s], positions after [s] are padded with [GO], and the total length is batch_max_length + 2.
129 |         """
130 |         length = [len(s) + 1 for s in text]  # +1 for [s] at end of sentence. [GO] is not counted in the length.
131 |         # batch_max_length = max(length) # this is not allowed for multi-gpu setting
132 |         batch_max_length += 1
133 |         # additional +1 for [GO] at first step. batch_text is padded with [GO] (index 0) after the [s] token.
134 |         batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)  # i.e. the padded length is batch_max_length + 2
135 |         for i, t in enumerate(text):
136 |             text = list(t)
137 |             text.append('[s]')
138 |             text = [self.dict[char] for char in text]
139 |             batch_text[i][1:1 + len(text)] = torch.LongTensor(text)  # batch_text[:, 0] = [GO] token
140 |         return (batch_text.to(device), torch.IntTensor(length).to(device))  # each entry of length counts [s] but not [GO]
141 | 
142 |     def decode(self, text_index, length):  # length is not used by the attention decoder; it is only needed for CTC loss and kept here for a unified interface
143 |         """ convert text-index into text-label. 
""" 144 | texts = [] 145 | for index, l in enumerate(length): 146 | text = ''.join([self.character[i] for i in text_index[index, :]]) 147 | texts.append(text) 148 | return texts 149 | 150 | 151 | class Averager(object): 152 | """Compute average for torch.Tensor, used for loss average.""" 153 | 154 | def __init__(self): 155 | self.reset() 156 | 157 | def add(self, v): 158 | count = v.data.numel() # 返回Tensor中元素的数量,对于loss来说,就是1 159 | v = v.data.sum() # 返回Tensor中所有元素的和,对于loss来说,就是loss本身的值 160 | self.n_count += count 161 | self.sum += v 162 | 163 | def reset(self): 164 | self.n_count = 0 165 | self.sum = 0 166 | 167 | def val(self): 168 | res = 0 169 | if self.n_count != 0: 170 | res = self.sum / float(self.n_count) 171 | return res 172 | -------------------------------------------------------------------------------- /modules/transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | 7 | 8 | class TPS_SpatialTransformerNetwork(nn.Module): 9 | """ Rectification Network of RARE, namely TPS based STN """ 10 | 11 | def __init__(self, F, I_size, I_r_size, I_channel_num=1): 12 | """ Based on RARE TPS 13 | input: 14 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 15 | I_size : (height, width) of the input image I 16 | I_r_size : (height, width) of the rectified image I_r 17 | I_channel_num : the number of channels of the input image I 18 | output: 19 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 20 | """ 21 | super(TPS_SpatialTransformerNetwork, self).__init__() 22 | self.F = F 23 | self.I_size = I_size 24 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 25 | self.I_channel_num = I_channel_num 26 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 27 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 28 | 29 | def forward(self, batch_I): 30 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 31 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 32 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) 33 | 34 | if torch.__version__ > "1.2.0": 35 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) 36 | else: 37 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') 38 | 39 | return batch_I_r 40 | 41 | 42 | class LocalizationNetwork(nn.Module): 43 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 44 | 45 | def __init__(self, F, I_channel_num): 46 | super(LocalizationNetwork, self).__init__() 47 | self.F = F 48 | self.I_channel_num = I_channel_num 49 | self.conv = nn.Sequential( 50 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, 51 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True), 52 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 53 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), 54 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 55 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), 56 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 57 | nn.Conv2d(256, 512, 3, 
1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), 58 | nn.AdaptiveAvgPool2d(1) # batch_size x 512 59 | ) 60 | 61 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 62 | self.localization_fc2 = nn.Linear(256, self.F * 2) 63 | 64 | # Init fc2 in LocalizationNetwork 65 | self.localization_fc2.weight.data.fill_(0) 66 | """ see RARE paper Fig. 6 (a) """ 67 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 68 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 69 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 70 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 71 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 72 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 73 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) 74 | 75 | def forward(self, batch_I): 76 | """ 77 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 78 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 79 | """ 80 | batch_size = batch_I.size(0) 81 | features = self.conv(batch_I).view(batch_size, -1) 82 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) 83 | return batch_C_prime 84 | 85 | 86 | class GridGenerator(nn.Module): 87 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 88 | 89 | def __init__(self, F, I_r_size): 90 | """ Generate P_hat and inv_delta_C for later """ 91 | super(GridGenerator, self).__init__() 92 | self.eps = 1e-6 93 | self.I_r_height, self.I_r_width = I_r_size 94 | self.F = F 95 | self.C = self._build_C(self.F) # F x 2 96 | self.P = self._build_P(self.I_r_width, self.I_r_height) 97 | ## for multi-gpu, you need register buffer 98 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 99 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 100 | ## for fine-tuning with different image width, you may use below instead of self.register_buffer 101 | #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3 102 | #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3 103 | 104 | def _build_C(self, F): 105 | """ Return coordinates of fiducial points in I_r; C """ 106 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 107 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 108 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 109 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 110 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 111 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 112 | return C # F x 2 113 | 114 | def _build_inv_delta_C(self, F, C): 115 | """ Return inv_delta_C which is needed to calculate T """ 116 | hat_C = np.zeros((F, F), dtype=float) # F x F 117 | for i in range(0, F): 118 | for j in range(i, F): 119 | r = np.linalg.norm(C[i] - C[j]) 120 | hat_C[i, j] = r 121 | hat_C[j, i] = r 122 | np.fill_diagonal(hat_C, 1) 123 | hat_C = (hat_C ** 2) * np.log(hat_C) 124 | # print(C.shape, hat_C.shape) 125 | delta_C = np.concatenate( # F+3 x F+3 126 | [ 127 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 128 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 129 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 130 | ], 
131 | axis=0 132 | ) 133 | inv_delta_C = np.linalg.inv(delta_C) 134 | return inv_delta_C # F+3 x F+3 135 | 136 | def _build_P(self, I_r_width, I_r_height): 137 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width 138 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height 139 | P = np.stack( # self.I_r_width x self.I_r_height x 2 140 | np.meshgrid(I_r_grid_x, I_r_grid_y), 141 | axis=2 142 | ) 143 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 144 | 145 | def _build_P_hat(self, F, C, P): 146 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 147 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 148 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 149 | P_diff = P_tile - C_tile # n x F x 2 150 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 151 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 152 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 153 | return P_hat # n x F+3 154 | 155 | def build_P_prime(self, batch_C_prime): 156 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 157 | batch_size = batch_C_prime.size(0) 158 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 159 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 160 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 161 | batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2 162 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 163 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 164 | return batch_P_prime # batch_size x n x 2 165 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Meta Self-Learning for Multi-Source Domain Adaptation: A Benchmark [![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/intent/tweet?text=Codes%20and%20Data%20for%20Our%20Paper:%20"Meta%20Self-Learning%20for%20Multi-Source%20Domain%20Adaptation:%20A%20Benchmark"%20&url=https://github.com/bupt-ai-cz/Meta-SelfLearning) 2 | 3 | [Project](https://bupt-ai-cz.github.io/Meta-SelfLearning/) | [Arxiv](https://arxiv.org/abs/2108.10840) | [YouTube](https://youtu.be/NaakbL4tPJw) | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/meta-self-learning-for-multi-source-domain/scene-text-recognition-on-msda)](https://paperswithcode.com/sota/scene-text-recognition-on-msda?p=meta-self-learning-for-multi-source-domain) | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/meta-self-learning-for-multi-source-domain/domain-adaptation-on-msda)](https://paperswithcode.com/sota/domain-adaptation-on-msda?p=meta-self-learning-for-multi-source-domain) 4 | 5 | ## News 6 | - ⚡(2022-3-5): We released pre-trained models. Download from[百度云-BaiDuYun](https://pan.baidu.com/s/1sKMzdx20IbKdM1wgyEVrzA) with Fetch code: 7cwt 7 | 8 | --- 9 | ![dataset1](imgs/dataset1.png) 10 | --- 11 | 12 | ## Abstract 13 | 14 | In recent years, deep learning-based methods have shown promising results in computer vision area. However, a common deep learning model requires a large amount of labeled data, which is labor-intensive to collect and label. 
Moreover, the model can be severely degraded by the domain shift between training data and testing data. Text recognition is a broadly studied field in computer vision and suffers from the same problems noted above due to the diversity of fonts and complicated backgrounds. In this paper, we focus on the text recognition problem and make three main contributions toward these problems. First, we collect a multi-source domain adaptation dataset for text recognition, including five different domains with over five million images, which is, to the best of our knowledge, the first multi-domain text recognition dataset. Second, we propose a new method called Meta Self-Learning, which combines the self-learning method with the meta-learning paradigm and achieves better recognition results in the multi-source domain adaptation setting. Third, extensive experiments are conducted on the dataset to provide a benchmark and to show the effectiveness of our method.
15 | 
16 | ## Data Preparation
17 | 
18 | Download the dataset from [here](https://github.com/bupt-ai-cz/Meta-SelfLearning/issues/5).
19 | 
20 | Before using the raw data, you need to convert it into an LMDB dataset:
21 | ```
22 | python create_lmdb_dataset.py --inputPath data/ --gtFile data/gt.txt --outputPath result/
23 | ```
24 | The data folder should be organized as below:
25 | ```
26 | data
27 | ├── train_label.txt
28 | └── imgs
29 |     ├── 000000001.png
30 |     ├── 000000002.png
31 |     ├── 000000003.png
32 |     └── ...
33 | ```
34 | The format of train_label.txt should be `{imagepath}\t{label}\n`.
35 | For example,
36 | 
37 | ```
38 | imgs/000000001.png Tiredness
39 | imgs/000000002.png kills
40 | imgs/000000003.png A
41 | ```
42 | 
43 | ## Requirements
44 | * Python == 3.7
45 | * Pytorch == 1.7.0
46 | * torchvision == 0.8.1
47 | 
48 | - Linux or OSX
49 | - NVIDIA GPU + CUDA CuDNN (CPU mode and CUDA without CuDNN may work with minimal modification, but untested)
50 | 
51 | ## Arguments
52 | * `--train_data`: folder path to the training lmdb dataset.
53 | * `--valid_data`: folder path to the validation lmdb dataset.
54 | * `--select_data`: select training data; examples are shown below.
55 | * `--batch_ratio`: assign the ratio of each selected dataset within a batch.
56 | * `--Transformation`: select the Transformation module [None | TPS]; in our method, we use None only.
57 | * `--FeatureExtraction`: select the FeatureExtraction module [VGG | RCNN | ResNet]; in our method, we use ResNet only.
58 | * `--SequenceModeling`: select the SequenceModeling module [None | BiLSTM]; in our method, we use BiLSTM only.
59 | * `--Prediction`: select the Prediction module [CTC | Attn]; in our method, we use Attn only.
60 | * `--saved_model`: path to a pretrained model.
61 | * `--valInterval`: iteration interval for validation.
62 | * `--inner_loop`: number of update steps in the meta update; default is 1.
63 | * `--source_num`: number of source domains; default is 4.
64 | 
65 | ## Get started
66 | - Install [PyTorch](http://pytorch.org) and other dependencies (e.g., torchvision, [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate)).
67 | - For pip users, please type the command `pip install -r requirements.txt`.
68 | - For Conda users, you can create a new Conda environment using `conda env create -f environment.yml`.
69 | 
70 | - Clone this repo:
71 | ```bash
72 | git clone https://github.com/bupt-ai-cz/Meta-SelfLearning.git
73 | cd Meta-SelfLearning
74 | ```
75 | 
76 | #### To train the baseline model for the synthetic domain 
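The training commands that follow assume that each source domain has been converted into its own LMDB sub-folder under `--train_data`, with folder names matching the `--select_data` entries. This per-domain layout is an assumption based on the conversion step above (the exact folder names and paths are illustrative); adjust them to your own data:

```
# one LMDB per domain, e.g. produced by repeated calls such as
# python create_lmdb_dataset.py --inputPath data/car/ --gtFile data/car/gt.txt --outputPath data/train/car/
data/train
├── car/            # data.mdb, lock.mdb
├── doc/
├── street/
├── handwritten/
└── syn/
data/test
└── syn/
```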
77 | ```
78 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python train.py \
79 | --train_data data/train/ \
80 | --select_data car-doc-street-handwritten \
81 | --batch_ratio 0.25-0.25-0.25-0.25 \
82 | --valid_data data/test/syn \
83 | --Transformation None --FeatureExtraction ResNet \
84 | --SequenceModeling BiLSTM --Prediction Attn \
85 | --batch_size 96 --valInterval 5000
86 | ```
87 | 
88 | #### To train the meta_train model for the synthetic domain using the pretrained model
89 | ```
90 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python meta_train.py \
91 | --train_data data/train/ \
92 | --select_data car-doc-street-handwritten \
93 | --batch_ratio 0.25-0.25-0.25-0.25 \
94 | --valid_data data/test/syn/ \
95 | --Transformation None --FeatureExtraction ResNet \
96 | --SequenceModeling BiLSTM --Prediction Attn \
97 | --batch_size 96 --source_num 4 \
98 | --valInterval 5000 --inner_loop 1 \
99 | --saved_model saved_models/pretrained.pth
100 | ```
101 | 
102 | #### To train the pseudo-label model for the synthetic domain
103 | ```
104 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python self_training.py \
105 | --train_data data/train \
106 | --select_data car-doc-street-handwritten \
107 | --batch_ratio 0.25-0.25-0.25-0.25 \
108 | --valid_data data/train/syn \
109 | --test_data data/test/syn \
110 | --Transformation None --FeatureExtraction ResNet \
111 | --SequenceModeling BiLSTM --Prediction Attn \
112 | --batch_size 96 --source_num 4 \
113 | --warmup_threshold 28 --pseudo_threshold 0.9 \
114 | --pseudo_dataset_num 50000 --valInterval 5000 \
115 | --saved_model saved_models/pretrained.pth
116 | ```
117 | #### To train the meta self-learning model for the synthetic domain
118 | ```
119 | OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=0 python meta_self_learning.py \
120 | --train_data data/train \
121 | --select_data car-doc-street-handwritten \
122 | --batch_ratio 0.25-0.25-0.25-0.25 \
123 | --valid_data data/train/syn \
124 | --test_data data/test/syn \
125 | --Transformation None --FeatureExtraction ResNet \
126 | --SequenceModeling BiLSTM --Prediction Attn \
127 | --batch_size 96 --source_num 4 \
128 | --warmup_threshold 0 --pseudo_threshold 0.9 \
129 | --pseudo_dataset_num 50000 --valInterval 5000 --inner_loop 1 \
130 | --saved_model pretrained_model/pretrained.pth
131 | ```
132 | ## Citation
133 | If you use this data or code for your research, please cite our paper [Meta Self-Learning for Multi-Source Domain Adaptation: A Benchmark](https://arxiv.org/abs/2108.10840):
134 | 
135 | ```
136 | @inproceedings{qiu2021meta,
137 |   title={Meta Self-Learning for Multi-Source Domain Adaptation: A Benchmark},
138 |   author={Qiu, Shuhao and Zhu, Chuang and Zhou, Wenli},
139 |   booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
140 |   pages={1592--1601},
141 |   year={2021}
142 | }
143 | ```
144 | 
145 | ## License
146 | This Dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree to our license terms below:
147 | 
148 | 1. That you include a reference to our paper in any work that makes use of the dataset.
149 | 2. That you may not use the dataset or any derivative work for commercial purposes.
150 | 
151 | ## Privacy
152 | Part of the data is constructed based on the processing of existing databases. Part of the data is crawled online or captured by ourselves. Part of the data is newly generated. 
We prohibit you from using the Datasets in any manner to identify or invade the privacy of any person. If you have any privacy concerns, including to remove your information from the Dataset, please contact us. 153 | 154 | ## Contact 155 | * email: czhu@bupt.edu.cn; qiushuhao@bupt.edu.cn 156 | 157 | ## Reference 158 | * https://github.com/YCG09/chinese_ocr 159 | * [Synthetic data and artificial neural networks for natural scene text recognition](https://arxiv.org/abs/1406.2227) 160 | * [Icdar 2015 competition on robust reading](https://ieeexplore.ieee.org/abstract/document/7333942) 161 | * [Icdar 2013 robust reading competition](https://ieeexplore.ieee.org/abstract/document/6628859) 162 | * [Casia online and offline chinese handwriting databases](https://ieeexplore.ieee.org/abstract/document/6065272) 163 | * [Icdar2019 robust reading challenge on multi-lingual scene text detection and recognition—rrc-mlt-2019](https://ieeexplore.ieee.org/abstract/document/8978096) 164 | * [Icdar2017 robust reading challenge on multi-lingual scene text detection and script identification-rrc-mlt](https://ieeexplore.ieee.org/abstract/document/8270168) 165 | * [A robust arbitrary text detection system for natural scene images](https://www.sciencedirect.com/science/article/abs/pii/S0957417414004060) 166 | * [End-to-end scene text recognition](https://ieeexplore.ieee.org/abstract/document/6126402) 167 | * [Towards end-to-end license plate detection and recognition: A large dataset and baseline](https://openaccess.thecvf.com/content_ECCV_2018/html/Zhenbo_Xu_Towards_End-to-End_License_ECCV_2018_paper.html) 168 | -------------------------------------------------------------------------------- /modules/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | class VGG_FeatureExtractor(nn.Module): 5 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ 6 | 7 | def __init__(self, input_channel, output_channel=512): 8 | super(VGG_FeatureExtractor, self).__init__() 9 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 10 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 11 | self.ConvNet = nn.Sequential( 12 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 13 | nn.MaxPool2d(2, 2), # 64x16x50 14 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), 15 | nn.MaxPool2d(2, 2), # 128x8x25 16 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), # 256x8x25 17 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), 18 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 19 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), 20 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 21 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), 22 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), 23 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 24 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), nn.ReLU(True)) # 512x1x24 25 | 26 | def forward(self, input): 27 | return self.ConvNet(input) 28 | 29 | 30 | class RCNN_FeatureExtractor(nn.Module): 31 | """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ 32 | 33 | def __init__(self, input_channel, output_channel=512): 34 | 
super(RCNN_FeatureExtractor, self).__init__() 35 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 36 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 37 | self.ConvNet = nn.Sequential( 38 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 39 | nn.MaxPool2d(2, 2), # 64 x 16 x 50 40 | GRCL(self.output_channel[0], self.output_channel[0], num_iteration=5, kernel_size=3, pad=1), 41 | nn.MaxPool2d(2, 2), # 64 x 8 x 25 42 | GRCL(self.output_channel[0], self.output_channel[1], num_iteration=5, kernel_size=3, pad=1), 43 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26 44 | GRCL(self.output_channel[1], self.output_channel[2], num_iteration=5, kernel_size=3, pad=1), 45 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27 46 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False), 47 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26 48 | 49 | def forward(self, input): 50 | return self.ConvNet(input) 51 | 52 | 53 | class ResNet_FeatureExtractor(nn.Module): 54 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 55 | 56 | def __init__(self, input_channel, output_channel=512): 57 | super(ResNet_FeatureExtractor, self).__init__() 58 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) 59 | 60 | def forward(self, input): 61 | return self.ConvNet(input) 62 | 63 | 64 | class SEResNet50_FeatureExtractor(nn.Module): 65 | def __init__(self, input_channel, output_channel=512): 66 | super(SEResNet50_FeatureExtractor, self).__init__() 67 | # self.ConvNet = se_resnet50_wfc(input_channel=input_channel, output_channel=output_channel) 68 | self.ConvNet = ResNet(input_channel, output_channel, SEBlock, [1, 2, 5, 3]) 69 | 70 | def forward(self, input): 71 | return self.ConvNet(input) 72 | 73 | # For Gated RCNN 74 | class GRCL(nn.Module): 75 | 76 | def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad): 77 | super(GRCL, self).__init__() 78 | self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False) 79 | self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False) 80 | self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False) 81 | self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False) 82 | 83 | self.BN_x_init = nn.BatchNorm2d(output_channel) 84 | 85 | self.num_iteration = num_iteration 86 | self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)] 87 | self.GRCL = nn.Sequential(*self.GRCL) 88 | 89 | def forward(self, input): 90 | """ The input of GRCL is consistant over time t, which is denoted by u(0) 91 | thus wgf_u / wf_u is also consistant over time t. 
92 | """ 93 | wgf_u = self.wgf_u(input) 94 | wf_u = self.wf_u(input) 95 | x = F.relu(self.BN_x_init(wf_u)) 96 | 97 | for i in range(self.num_iteration): 98 | x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x)) 99 | 100 | return x 101 | 102 | 103 | class GRCL_unit(nn.Module): 104 | 105 | def __init__(self, output_channel): 106 | super(GRCL_unit, self).__init__() 107 | self.BN_gfu = nn.BatchNorm2d(output_channel) 108 | self.BN_grx = nn.BatchNorm2d(output_channel) 109 | self.BN_fu = nn.BatchNorm2d(output_channel) 110 | self.BN_rx = nn.BatchNorm2d(output_channel) 111 | self.BN_Gx = nn.BatchNorm2d(output_channel) 112 | 113 | def forward(self, wgf_u, wgr_x, wf_u, wr_x): 114 | G_first_term = self.BN_gfu(wgf_u) 115 | G_second_term = self.BN_grx(wgr_x) 116 | G = F.sigmoid(G_first_term + G_second_term) 117 | 118 | x_first_term = self.BN_fu(wf_u) 119 | x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G) 120 | x = F.relu(x_first_term + x_second_term) 121 | 122 | return x 123 | 124 | 125 | class SEBlock(nn.Module): 126 | expansion = 1 127 | 128 | def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=16): 129 | super(SEBlock, self).__init__() 130 | self.conv1 = self._conv3x3(inplanes, planes) 131 | self.bn1 = nn.BatchNorm2d(planes) 132 | self.conv2 = self._conv3x3(planes, planes) 133 | self.bn2 = nn.BatchNorm2d(planes) 134 | 135 | self.conv3 = nn.Conv2d(planes, planes, kernel_size=1, bias=False) 136 | self.bn3 = nn.BatchNorm2d(planes) 137 | self.se = SELayer(planes, reduction) 138 | 139 | self.relu = nn.ReLU(inplace=True) 140 | self.downsample = downsample 141 | self.stride = stride 142 | 143 | def _conv3x3(self, in_planes, out_planes, stride=1): 144 | "3x3 convolution with padding" 145 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 146 | padding=1, bias=False) 147 | 148 | def forward(self, x): 149 | residual = x 150 | 151 | out = self.conv1(x) 152 | out = self.bn1(out) 153 | out = self.relu(out) 154 | 155 | out = self.conv2(out) 156 | out = self.bn2(out) 157 | out = self.relu(out) 158 | 159 | out = self.conv3(out) 160 | out = self.bn3(out) 161 | out = self.se(out) 162 | 163 | if self.downsample is not None: 164 | residual = self.downsample(x) 165 | out += residual 166 | out = self.relu(out) 167 | 168 | return out 169 | 170 | 171 | class BasicBlock(nn.Module): 172 | expansion = 1 173 | 174 | def __init__(self, inplanes, planes, stride=1, downsample=None): 175 | super(BasicBlock, self).__init__() 176 | self.conv1 = self._conv3x3(inplanes, planes) 177 | self.bn1 = nn.BatchNorm2d(planes) 178 | self.conv2 = self._conv3x3(planes, planes) 179 | self.bn2 = nn.BatchNorm2d(planes) 180 | self.relu = nn.ReLU(inplace=True) 181 | self.downsample = downsample 182 | self.stride = stride 183 | 184 | def _conv3x3(self, in_planes, out_planes, stride=1): 185 | "3x3 convolution with padding" 186 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 187 | padding=1, bias=False) 188 | 189 | def forward(self, x): 190 | residual = x 191 | 192 | out = self.conv1(x) 193 | out = self.bn1(out) 194 | out = self.relu(out) 195 | 196 | out = self.conv2(out) 197 | out = self.bn2(out) 198 | 199 | if self.downsample is not None: 200 | residual = self.downsample(x) 201 | out += residual 202 | out = self.relu(out) 203 | 204 | return out 205 | 206 | 207 | class ResNet(nn.Module): 208 | 209 | def __init__(self, input_channel, output_channel, block, layers): 210 | super(ResNet, self).__init__() 211 | 212 | self.output_channel_block = [int(output_channel / 4), 
int(output_channel / 2), output_channel, output_channel] 213 | 214 | self.inplanes = int(output_channel / 8) 215 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), 216 | kernel_size=3, stride=1, padding=1, bias=False) 217 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 218 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, 219 | kernel_size=3, stride=1, padding=1, bias=False) 220 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 221 | self.relu = nn.ReLU(inplace=True) 222 | 223 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 224 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 225 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ 226 | 0], kernel_size=3, stride=1, padding=1, bias=False) 227 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 228 | 229 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 230 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) 231 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ 232 | 1], kernel_size=3, stride=1, padding=1, bias=False) 233 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 234 | 235 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 236 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) 237 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ 238 | 2], kernel_size=3, stride=1, padding=1, bias=False) 239 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 240 | 241 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) 242 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 243 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) 244 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 245 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 246 | 3], kernel_size=2, stride=1, padding=0, bias=False) 247 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 248 | 249 | def _make_layer(self, block, planes, blocks, stride=1): 250 | downsample = None 251 | if stride != 1 or self.inplanes != planes * block.expansion: 252 | downsample = nn.Sequential( 253 | nn.Conv2d(self.inplanes, planes * block.expansion, 254 | kernel_size=1, stride=stride, bias=False), 255 | nn.BatchNorm2d(planes * block.expansion), 256 | ) 257 | 258 | layers = [] 259 | layers.append(block(self.inplanes, planes, stride, downsample)) 260 | self.inplanes = planes * block.expansion 261 | for i in range(1, blocks): 262 | layers.append(block(self.inplanes, planes)) 263 | 264 | return nn.Sequential(*layers) 265 | 266 | def forward(self, x): 267 | x = self.conv0_1(x) 268 | x = self.bn0_1(x) 269 | x = self.relu(x) 270 | x = self.conv0_2(x) 271 | x = self.bn0_2(x) 272 | x = self.relu(x) 273 | 274 | x = self.maxpool1(x) 275 | x = self.layer1(x) 276 | x = self.conv1(x) 277 | x = self.bn1(x) 278 | x = self.relu(x) 279 | 280 | x = self.maxpool2(x) 281 | x = self.layer2(x) 282 | x = self.conv2(x) 283 | x = self.bn2(x) 284 | x = self.relu(x) 285 | 286 | x = self.maxpool3(x) 287 | x = self.layer3(x) 288 | x = self.conv3(x) 289 | x = self.bn3(x) 290 | x = self.relu(x) 291 | 292 | x = self.layer4(x) 293 | x = self.conv4_1(x) 294 | x = self.bn4_1(x) 295 | x = self.relu(x) 296 | x = self.conv4_2(x) 297 | x = self.bn4_2(x) 298 | x = self.relu(x) 299 | 300 | return x 
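# A minimal shape sanity check for ResNet_FeatureExtractor. It is an illustrative
# sketch that assumes a 1-channel 32x100 input (the imgH/imgW setting used for the
# shape comments elsewhere in this repo); with the strides and poolings above, the
# height collapses to 1 and the width becomes the sequence axis that model.py
# pools and feeds to the BiLSTM.
if __name__ == '__main__':
    import torch
    extractor = ResNet_FeatureExtractor(input_channel=1, output_channel=512)
    feature = extractor(torch.randn(2, 1, 32, 100))  # [batch, channel, H, W]
    print(feature.shape)                             # expected: [2, 512, 1, 26]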
301 | -------------------------------------------------------------------------------- /modules/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.init import xavier_uniform_ 5 | from torch.nn.init import constant_ 6 | from torch.nn.init import xavier_normal_ 7 | import math 8 | import copy 9 | 10 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 11 | 12 | 13 | class Attention(nn.Module): 14 | 15 | def __init__(self, input_size, hidden_size, num_classes): 16 | super(Attention, self).__init__() 17 | self.input_size = input_size 18 | self.attention_cell = AttentionCell(input_size, hidden_size, num_classes) 19 | self.hidden_size = hidden_size 20 | self.num_classes = num_classes # charset_len + [Go] + [s] 21 | self.generator = nn.Linear(hidden_size, num_classes) 22 | 23 | def _char_to_onehot(self, input_char, onehot_dim=38): 24 | input_char = input_char.unsqueeze(1) # [batch_size] -> [batch_size, 1] 25 | batch_size = input_char.size(0) 26 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) 27 | one_hot = one_hot.scatter_(1, input_char, 1) # scatter_(dim, index, src, reduce=None) → Tensor 28 | return one_hot 29 | 30 | def forward(self, batch_H, text, is_train=True, batch_max_length=25, is_domain=False): 31 | """ 32 | input: 33 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x contextual_feature_channels] 34 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. 35 | output: probability distribution at each step [batch_size x num_steps x num_classes] 36 | """ 37 | batch_size = batch_H.size(0) 38 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. 39 | 40 | output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device) 41 | hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 42 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device)) 43 | if is_domain: self.context_history = torch.FloatTensor(batch_size, num_steps, self.input_size).fill_(0).to(device) 44 | 45 | if is_train: 46 | for i in range(num_steps): 47 | # one-hot vectors for a i-th char. 
in a batch
48 |                 char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes)
49 |                 # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1})
50 |                 if not is_domain: hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots, is_domain)
51 |                 else:
52 |                     hidden, alpha, context = self.attention_cell(hidden, batch_H, char_onehots, is_domain)
53 |                     self.context_history[:, i, :] = context
54 |                 output_hiddens[:, i, :] = hidden[0]  # LSTM hidden index (0: hidden, 1: Cell)
55 |             probs = self.generator(output_hiddens)
56 |         else:
57 |             targets = torch.LongTensor(batch_size).fill_(0).to(device)  # [GO] token
58 |             probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device)
59 | 
60 |             for i in range(num_steps):
61 |                 char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes)
62 |                 if not is_domain: hidden, alpha = self.attention_cell(hidden, batch_H, char_onehots, is_domain)
63 |                 else:
64 |                     hidden, alpha, context = self.attention_cell(hidden, batch_H, char_onehots, is_domain)
65 |                     self.context_history[:, i, :] = context
66 |                 probs_step = self.generator(hidden[0])
67 |                 probs[:, i, :] = probs_step
68 |                 _, next_input = probs_step.max(1)
69 |                 targets = next_input
70 | 
71 |         return probs  # batch_size x num_steps x num_classes
72 | 
73 | class AttentionCell(nn.Module):
74 | 
75 |     def __init__(self, input_size, hidden_size, num_embeddings):
76 |         super(AttentionCell, self).__init__()
77 |         self.i2h = nn.Linear(input_size, hidden_size, bias=False)
78 |         self.h2h = nn.Linear(hidden_size, hidden_size)  # either i2i or h2h should have bias
79 |         self.score = nn.Linear(hidden_size, 1, bias=False)
80 |         self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size)
81 |         self.hidden_size = hidden_size
82 | 
83 |     def forward(self, prev_hidden, batch_H, char_onehots, is_domain=False):
84 |         '''
85 |         input:
86 |             prev_hidden: a tuple of two tensors, each of shape [batch_size, hidden_size]; element [0] is the hidden state of the previous time step and element [1] is the cell state
87 |             batch_H: the encoder outputs, [batch_size, num_steps, contextual_feature_channels], i.e. the encoder hidden states at all time steps
88 |         '''
89 |         # 1. project the encoder outputs of all time steps through a linear layer
90 |         batch_H_proj = self.i2h(batch_H)  # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size]
91 |         # 2. project the previous decoder hidden state through a linear layer
92 |         prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1)  # [batch_size, 1, self.hidden_size]
93 |         # 3. add the vectors from steps 1 and 2; note that the addition relies on broadcasting:
94 |         #    adding A of shape [batch_size, 1, hidden_size] to B of shape [batch_size, num_encoder_step, hidden_size]
95 |         #    replicates A num_encoder_step times per sample before adding it to B
96 |         e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj))  # batch_size x num_encoder_step x 1
97 |         # 4. the two steps above implement tanh(Ws * s + Wh * h), i.e. the attention score e; next, softmax e over the time axis
98 |         alpha = F.softmax(e, dim=1)  # batch_size x num_encoder_step x 1
99 |         # 5. use the weights from step 4 to combine the encoder hidden states of all time steps into a context vector for the current step;
100 |         #    after the permute, alpha has shape batch_size x 1 x num_encoder_step and batch_H has shape batch_size x num_encoder_step x contextual_feature_channels,
101 |         #    so the bmm yields a weighted average of the encoder outputs over time;
102 |         #    in other words, the decoder attends directly to the encoder outputs
103 |         context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1)  # batch_size x num_channel
104 |         # 6. the context is then concatenated with the one-hot encoding of the previous step's ground-truth character ([GO] is prepended as the start symbol, so this is teacher forcing), passed through an LSTM cell once more, and only then are the probabilities computed
105 |         concat_context = torch.cat([context, char_onehots], 1)  # batch_size x (num_channel + num_embedding)
106 |         cur_hidden = self.rnn(concat_context, prev_hidden)
107 |         if is_domain: return cur_hidden, alpha, context
108 |         else: return cur_hidden, alpha
109 | 
110 | def _get_clones(module, N):
111 |     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
112 | 
113 | class TransformerEncoder(nn.Module):
114 |     __constants__ = ['norm']
115 | 
116 |     def __init__(self, encoder_layer, num_layers, norm=None):
117 |         super(TransformerEncoder, self).__init__()
118 |         self.layers = _get_clones(encoder_layer, num_layers)
119 |         self.num_layers = num_layers
120 |         self.norm = norm
121 | 
122 |     def forward(self, src, mask=None, src_key_padding_mask=None):
123 |         # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
124 |         r"""Pass the input through the encoder layers in turn.
125 |         Args:
126 |             src: the sequence to the encoder (required).
127 |             mask: the mask for the src sequence (optional).
128 |             src_key_padding_mask: the mask for the src keys per batch (optional).
129 |         Shape:
130 |             see the docs in Transformer class.
131 |         """
132 |         output = src
133 | 
134 |         for mod in self.layers:
135 |             output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
136 | 
137 |         if self.norm is not None:
138 |             output = self.norm(output)
139 | 
140 |         return output
141 | 
142 | 
143 | def _get_activation_fn(activation):
144 |     if activation == "relu":
145 |         return F.relu
146 |     elif activation == "gelu":
147 |         return F.gelu
148 | 
149 |     raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
150 | 
151 | 
152 | class MultiheadAttention(nn.Module):
153 |     __annotations__ = {
154 |         'bias_k': torch._jit_internal.Optional[torch.Tensor],
155 |         'bias_v': torch._jit_internal.Optional[torch.Tensor],
156 |     }
157 |     __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']
158 | 
159 |     def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
160 |         super(MultiheadAttention, self).__init__()
161 |         self.embed_dim = embed_dim
162 |         self.kdim = kdim if kdim is not None else embed_dim
163 |         self.vdim = vdim if vdim is not None else embed_dim
164 |         self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
165 | 
166 |         self.num_heads = num_heads
167 |         self.dropout = dropout
168 |         self.head_dim = embed_dim // num_heads
169 |         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
170 | 
171 |         if self._qkv_same_embed_dim is False:
172 |             self.q_proj_weight = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
173 |             self.k_proj_weight = nn.Parameter(torch.Tensor(embed_dim, self.kdim))
174 |             self.v_proj_weight = nn.Parameter(torch.Tensor(embed_dim, self.vdim))
175 |             self.register_parameter('in_proj_weight', None)
176 |         else:
177 |             self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
178 |             self.register_parameter('q_proj_weight', None)
179 |             self.register_parameter('k_proj_weight', None)
180 |             self.register_parameter('v_proj_weight', None)
181 | 
182 |         if bias:
183 |             self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
184 |         else:
185 |             self.register_parameter('in_proj_bias', None)
186 |         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
187 | 
188 |         if add_bias_kv:
189 |             self.bias_k = nn.Parameter(torch.empty(1, 1, 
embed_dim)) 190 | self.bias_v = nn.Parameter(torch.empty(1, 1, embed_dim)) 191 | else: 192 | self.bias_k = self.bias_v = None 193 | 194 | self.add_zero_attn = add_zero_attn 195 | 196 | self._reset_parameters() 197 | 198 | def _reset_parameters(self): 199 | if self._qkv_same_embed_dim: 200 | xavier_uniform_(self.in_proj_weight) 201 | else: 202 | xavier_uniform_(self.q_proj_weight) 203 | xavier_uniform_(self.k_proj_weight) 204 | xavier_uniform_(self.v_proj_weight) 205 | 206 | if self.in_proj_bias is not None: 207 | constant_(self.in_proj_bias, 0.) 208 | constant_(self.out_proj.bias, 0.) 209 | if self.bias_k is not None: 210 | xavier_normal_(self.bias_k) 211 | if self.bias_v is not None: 212 | xavier_normal_(self.bias_v) 213 | 214 | def __setstate__(self, state): 215 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0 216 | if '_qkv_same_embed_dim' not in state: 217 | state['_qkv_same_embed_dim'] = True 218 | 219 | super(MultiheadAttention, self).__setstate__(state) 220 | 221 | def forward(self, query, key, value, key_padding_mask=None, 222 | need_weights=True, attn_mask=None): 223 | if not self._qkv_same_embed_dim: 224 | return F.multi_head_attention_forward( 225 | query, key, value, self.embed_dim, self.num_heads, 226 | self.in_proj_weight, self.in_proj_bias, 227 | self.bias_k, self.bias_v, self.add_zero_attn, 228 | self.dropout, self.out_proj.weight, self.out_proj.bias, 229 | training=self.training, 230 | key_padding_mask=key_padding_mask, need_weights=need_weights, 231 | attn_mask=attn_mask, use_separate_proj_weight=True, 232 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 233 | v_proj_weight=self.v_proj_weight) 234 | else: 235 | return F.multi_head_attention_forward( 236 | query, key, value, self.embed_dim, self.num_heads, 237 | self.in_proj_weight, self.in_proj_bias, 238 | self.bias_k, self.bias_v, self.add_zero_attn, 239 | self.dropout, self.out_proj.weight, self.out_proj.bias, 240 | training=self.training, 241 | key_padding_mask=key_padding_mask, need_weights=need_weights, 242 | attn_mask=attn_mask) 243 | 244 | 245 | class TransformerEncoderLayer(nn.Module): 246 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"): 247 | super(TransformerEncoderLayer, self).__init__() 248 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout) 249 | # Implementation of Feedforward model 250 | self.linear1 = nn.Linear(d_model, dim_feedforward) 251 | self.dropout = nn.Dropout(dropout) 252 | self.linear2 = nn.Linear(dim_feedforward, d_model) 253 | 254 | self.norm1 = nn.LayerNorm(d_model) 255 | self.norm2 = nn.LayerNorm(d_model) 256 | self.dropout1 = nn.Dropout(dropout) 257 | self.dropout2 = nn.Dropout(dropout) 258 | 259 | self.activation = _get_activation_fn(activation) 260 | 261 | def __setstate__(self, state): 262 | if 'activation' not in state: 263 | state['activation'] = F.relu 264 | super(TransformerEncoderLayer, self).__setstate__(state) 265 | 266 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 267 | src2 = self.self_attn(src, src, src, attn_mask=src_mask, 268 | key_padding_mask=src_key_padding_mask)[0] 269 | src = src + self.dropout1(src2) 270 | src = self.norm1(src) 271 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 272 | src = src + self.dropout2(src2) 273 | src = self.norm2(src) 274 | return src 275 | 276 | 277 | class PositionalEncoding(nn.Module): 278 | 279 | def __init__(self, d_model, dropout=0.1, max_len=5000): 280 | super(PositionalEncoding, 
self).__init__() 281 | self.dropout = nn.Dropout(p=dropout) 282 | 283 | pe = torch.zeros(max_len, d_model) 284 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 285 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 286 | pe[:, 0::2] = torch.sin(position * div_term) 287 | pe[:, 1::2] = torch.cos(position * div_term) 288 | pe = pe.unsqueeze(0).transpose(0, 1) 289 | self.register_buffer('pe', pe) 290 | 291 | def forward(self, x): 292 | x = x + self.pe[:x.size(0), :] 293 | return self.dropout(x) 294 | 295 | 296 | class Transformer(nn.Module): 297 | 298 | def __init__(self, ntoken, ninp, nhid=256, nhead=2, nlayers=2, dropout=0.2): 299 | super(Transformer, self).__init__() 300 | self.src_mask = None 301 | self.pos_encoder = PositionalEncoding(ninp, dropout) 302 | encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) 303 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) 304 | self.ninp = ninp 305 | self.decoder = nn.Linear(ninp, ntoken) 306 | 307 | self.init_weights() 308 | 309 | def _generate_square_subsequent_mask(self, sz): 310 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) 311 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) 312 | return mask 313 | 314 | def init_weights(self): 315 | initrange = 0.1 316 | self.decoder.bias.data.zero_() 317 | self.decoder.weight.data.uniform_(-initrange, initrange) 318 | 319 | def forward(self, src): 320 | if self.src_mask is None or self.src_mask.size(0) != len(src): 321 | mask = self._generate_square_subsequent_mask(len(src)).to(src.device) 322 | self.src_mask = mask 323 | 324 | src = src * math.sqrt(self.ninp) 325 | src = self.pos_encoder(src) 326 | output = self.transformer_encoder(src, self.src_mask) 327 | output = self.decoder(output) 328 | return output 329 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import os 3 | import time 4 | import string 5 | import argparse 6 | import re 7 | 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.utils.data 11 | import torch.nn.functional as F 12 | import numpy as np 13 | from nltk.metrics.distance import edit_distance 14 | 15 | from utils import CTCLabelConverter, AttnLabelConverter, Averager 16 | from dataset import hierarchical_dataset, AlignCollate 17 | from model import Model 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | 20 | 21 | def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False): 22 | """ evaluation with 10 benchmark evaluation datasets """ 23 | # The evaluation datasets, dataset order is same with Table 1 in our paper. 24 | eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 25 | 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'] 26 | 27 | # # To easily compute the total accuracy of our paper. 28 | # eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_867', 29 | # 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80'] 30 | 31 | if calculate_infer_time: 32 | evaluation_batch_size = 1 # batch_size should be 1 to calculate the GPU inference time per image. 
33 | else: 34 | evaluation_batch_size = opt.batch_size 35 | 36 | list_accuracy = [] 37 | total_forward_time = 0 38 | total_evaluation_data_number = 0 39 | total_correct_number = 0 40 | log = open(f'./result/{opt.exp_name}/log_all_evaluation.txt', 'a') 41 | dashed_line = '-' * 80 42 | print(dashed_line) 43 | log.write(dashed_line + '\n') 44 | for eval_data in eval_data_list: 45 | eval_data_path = os.path.join(opt.eval_data, eval_data) 46 | AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 47 | eval_data, eval_data_log = hierarchical_dataset(root=eval_data_path, opt=opt) 48 | evaluation_loader = torch.utils.data.DataLoader( 49 | eval_data, batch_size=evaluation_batch_size, 50 | shuffle=False, 51 | num_workers=int(opt.workers), 52 | collate_fn=AlignCollate_evaluation, pin_memory=True) 53 | 54 | _, accuracy_by_best_model, norm_ED_by_best_model, _, _, _, infer_time, length_of_data = validation( 55 | model, criterion, evaluation_loader, converter, opt) 56 | list_accuracy.append(f'{accuracy_by_best_model:0.3f}') 57 | total_forward_time += infer_time 58 | total_evaluation_data_number += len(eval_data) 59 | total_correct_number += accuracy_by_best_model * length_of_data 60 | log.write(eval_data_log) 61 | print(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}') 62 | log.write(f'Acc {accuracy_by_best_model:0.3f}\t normalized_ED {norm_ED_by_best_model:0.3f}\n') 63 | print(dashed_line) 64 | log.write(dashed_line + '\n') 65 | 66 | averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000 67 | total_accuracy = total_correct_number / total_evaluation_data_number 68 | params_num = sum([np.prod(p.size()) for p in model.parameters()]) 69 | 70 | evaluation_log = 'accuracy: ' 71 | for name, accuracy in zip(eval_data_list, list_accuracy): 72 | evaluation_log += f'{name}: {accuracy}\t' 73 | evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t' 74 | evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}' 75 | print(evaluation_log) 76 | log.write(evaluation_log + '\n') 77 | log.close() 78 | 79 | return None 80 | 81 | 82 | def validation(model, criterion, evaluation_loader, converter, opt, self_training=False): 83 | """ validation or evaluation """ 84 | n_correct = 0 85 | norm_ED = 0 86 | length_of_data = 0 87 | infer_time = 0 88 | valid_loss_avg = Averager() 89 | 90 | all_imgs = [] 91 | all_gts = [] 92 | all_pseudo_labels = [] 93 | all_confidences = [] 94 | 95 | for i, (image_tensors, labels) in enumerate(evaluation_loader): 96 | if self_training: 97 | for img in image_tensors: all_imgs.append(img) 98 | all_gts.extend(labels) 99 | batch_size = image_tensors.size(0) 100 | length_of_data = length_of_data + batch_size 101 | image = image_tensors.to(device) 102 | # For max length prediction 103 | length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) 104 | text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) 105 | 106 | text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length) 107 | 108 | start_time = time.time() 109 | if 'CTC' in opt.Prediction: 110 | preds = model(image, text_for_pred) 111 | forward_time = time.time() - start_time 112 | 113 | # Calculate evaluation loss for CTC deocder. 
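# (A note on the two CTC criteria handled below, as a rough guide: torch.nn.CTCLoss
#  expects log-probabilities of shape (T, batch, num_class) plus per-sample input and
#  target lengths, hence the .log_softmax(2).permute(1, 0, 2) branch; the Baidu
#  warp-ctc binding selected by opt.baiduCTC takes unnormalized activations and sums
#  the loss over the batch, which appears to be why that branch divides the cost by
#  batch_size. preds itself is (batch_size, T, num_class), so preds_size simply
#  repeats the time dimension T once per sample.)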
114 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 115 | # permute 'preds' to use CTCloss format 116 | if opt.baiduCTC: 117 | cost = criterion(preds.permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) / batch_size 118 | else: 119 | cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) 120 | 121 | # Select max probabilty (greedy decoding) then decode index to character 122 | if opt.baiduCTC: 123 | _, preds_index = preds.max(2) 124 | preds_index = preds_index.view(-1) 125 | else: 126 | _, preds_index = preds.max(2) 127 | preds_str = converter.decode(preds_index.data, preds_size.data) 128 | 129 | else: 130 | preds = model(image, text_for_pred, is_train=False) 131 | forward_time = time.time() - start_time 132 | 133 | preds = preds[:, :text_for_loss.shape[1] - 1, :] 134 | target = text_for_loss[:, 1:] # without [GO] Symbol 135 | cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1)) 136 | 137 | # select max probabilty (greedy decoding) then decode index to character 138 | _, preds_index = preds.max(2) 139 | preds_str = converter.decode(preds_index, length_for_pred) 140 | labels = converter.decode(text_for_loss[:, 1:], length_for_loss) 141 | 142 | infer_time += forward_time 143 | valid_loss_avg.add(cost) 144 | 145 | all_pseudo_labels.extend(preds_str) 146 | # calculate accuracy & confidence score 147 | preds_prob = F.softmax(preds, dim=2) 148 | preds_max_prob, _ = preds_prob.max(dim=2) 149 | confidence_score_list = [] 150 | # print(len(labels), len(preds_str), len(preds_max_prob)) 151 | for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob): 152 | if 'Attn' in opt.Prediction: 153 | gt = gt[:gt.find('[s]')] 154 | pred_EOS = pred.find('[s]') 155 | pred = pred[:pred_EOS] # prune after "end of sentence" token ([s]) 156 | pred_max_prob = pred_max_prob[:pred_EOS] 157 | 158 | # To evaluate 'case sensitive model' with alphanumeric and case insensitve setting. 159 | if opt.sensitive and opt.data_filtering_off: 160 | pred = pred.lower() 161 | gt = gt.lower() 162 | alphanumeric_case_insensitve = '0123456789abcdefghijklmnopqrstuvwxyz' 163 | out_of_alphanumeric_case_insensitve = f'[^{alphanumeric_case_insensitve}]' 164 | pred = re.sub(out_of_alphanumeric_case_insensitve, '', pred) 165 | gt = re.sub(out_of_alphanumeric_case_insensitve, '', gt) 166 | 167 | if pred == gt: 168 | n_correct += 1 169 | 170 | ''' 171 | (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks 172 | "For each word we calculate the normalized edit distance to the length of the ground truth transcription." 
173 | if len(gt) == 0: 174 | norm_ED += 1 175 | else: 176 | norm_ED += edit_distance(pred, gt) / len(gt) 177 | ''' 178 | 179 | # ICDAR2019 Normalized Edit Distance 180 | if len(gt) == 0 or len(pred) == 0: 181 | norm_ED += 0 182 | elif len(gt) > len(pred): 183 | norm_ED += 1 - edit_distance(pred, gt) / len(gt) 184 | else: 185 | norm_ED += 1 - edit_distance(pred, gt) / len(pred) 186 | 187 | # calculate confidence score (= multiply of pred_max_prob) 188 | try: 189 | confidence_score = pred_max_prob.cumprod(dim=0)[-1] 190 | except: 191 | confidence_score = 0 # for empty pred case, when prune after "end of sentence" token ([s]) 192 | confidence_score_list.append(confidence_score) 193 | all_confidences.append(confidence_score) 194 | # print(pred, gt, pred==gt, confidence_score) 195 | 196 | accuracy = n_correct / float(length_of_data) * 100 197 | norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance 198 | 199 | if self_training: 200 | return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data, all_imgs, all_pseudo_labels, all_confidences, all_gts 201 | return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data 202 | 203 | 204 | def test(opt): 205 | """ model configuration """ 206 | if 'CTC' in opt.Prediction: 207 | converter = CTCLabelConverter(opt.character) 208 | else: 209 | converter = AttnLabelConverter(opt.character) 210 | opt.num_class = len(converter.character) 211 | 212 | if opt.rgb: 213 | opt.input_channel = 3 214 | model = Model(opt) 215 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 216 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 217 | opt.SequenceModeling, opt.Prediction) 218 | model = torch.nn.DataParallel(model).to(device) 219 | 220 | # load model 221 | print('loading pretrained model from %s' % opt.saved_model) 222 | model.load_state_dict(torch.load(opt.saved_model, map_location=device)) 223 | opt.exp_name = '_'.join(opt.saved_model.split('/')[1:]) 224 | # print(model) 225 | 226 | """ keep evaluation model and result logs """ 227 | os.makedirs(f'./result/{opt.exp_name}', exist_ok=True) 228 | os.system(f'cp {opt.saved_model} ./result/{opt.exp_name}/') 229 | 230 | """ setup loss """ 231 | if 'CTC' in opt.Prediction: 232 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 233 | else: 234 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 235 | 236 | """ evaluation """ 237 | model.eval() 238 | with torch.no_grad(): 239 | if opt.benchmark_all_eval: # evaluation with 10 benchmark evaluation datasets 240 | benchmark_all_eval(model, criterion, converter, opt) 241 | else: 242 | log = open(f'./result/{opt.exp_name}/log_evaluation.txt', 'a') 243 | AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 244 | eval_data, eval_data_log = hierarchical_dataset(root=opt.eval_data, opt=opt) 245 | evaluation_loader = torch.utils.data.DataLoader( 246 | eval_data, batch_size=opt.batch_size, 247 | shuffle=False, 248 | num_workers=int(opt.workers), 249 | collate_fn=AlignCollate_evaluation, pin_memory=True) 250 | _, accuracy_by_best_model, _, _, _, _, _, _ = validation( 251 | model, criterion, evaluation_loader, converter, opt) 252 | log.write(eval_data_log) 253 | print(f'{accuracy_by_best_model:0.3f}') 254 | 
log.write(f'{accuracy_by_best_model:0.3f}\n') 255 | log.close() 256 | 257 | 258 | if __name__ == '__main__': 259 | parser = argparse.ArgumentParser() 260 | parser.add_argument('--eval_data', required=True, help='path to evaluation dataset') 261 | parser.add_argument('--benchmark_all_eval', action='store_true', help='evaluate 10 benchmark evaluation datasets') 262 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 263 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 264 | parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluate") 265 | """ Data processing """ 266 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 267 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 268 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 269 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 270 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', help='character label') 271 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 272 | parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') 273 | parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') 274 | parser.add_argument('--baiduCTC', action='store_true', help='use Baidu warp-ctc for the CTC loss') 275 | """ Model Architecture """ 276 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS') 277 | parser.add_argument('--FeatureExtraction', type=str, required=True, help='FeatureExtraction stage. VGG|RCNN|ResNet') 278 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') 279 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn') 280 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 281 | parser.add_argument('--input_channel', type=int, default=1, help='the number of input channel of Feature extractor') 282 | parser.add_argument('--output_channel', type=int, default=512, 283 | help='the number of output channel of Feature extractor') 284 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 285 | 286 | opt = parser.parse_args() 287 | 288 | """ vocab / character number configuration """ 289 | if opt.sensitive: 290 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char).
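Whatever string finally lands in opt.character becomes the recognizer's output alphabet: test() builds a CTCLabelConverter or AttnLabelConverter from it and sets opt.num_class = len(converter.character), and the giant assignment that follows simply widens that alphabet to several thousand Chinese characters plus digits. A minimal sketch of that mapping (hypothetical variable names, not the utils.py implementation; the clovaai deep-text-recognition-benchmark code this repository builds on typically prepends [GO]/[s] for the attention decoder and a blank token for CTC):

charset = '0123456789abcdefghijklmnopqrstuvwxyz'
attn_classes = ['[GO]', '[s]'] + list(charset)   # attention decoder: num_class = len(charset) + 2
ctc_classes = ['[CTCblank]'] + list(charset)     # CTC decoder: num_class = len(charset) + 1
print(len(attn_classes), len(ctc_classes))       # -> 38 37

This is also why the attention branch builds CrossEntropyLoss with ignore_index=0: index 0 is the [GO] token.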
291 | opt.character = u"啊阿埃挨哎唉哀皑癌蔼矮艾碍爱隘鞍氨安俺按暗岸胺案肮昂盎凹敖熬翱袄傲奥懊澳芭捌扒叭吧笆八疤巴拔跋靶把耙坝霸罢爸白柏百摆佰败拜稗斑班搬扳般颁板版扮拌伴瓣半办绊邦帮梆榜膀绑棒磅蚌镑傍谤苞胞包褒剥薄雹保堡饱宝抱报暴豹鲍爆杯碑悲卑北辈背贝钡倍狈备惫焙被奔苯本笨崩绷甭泵蹦迸逼鼻比鄙笔彼碧蓖蔽毕毙毖币庇痹闭敝弊必辟壁臂避陛鞭边编贬扁便变卞辨辩辫遍标彪膘表鳖憋别瘪彬斌濒滨宾摈兵冰柄丙秉饼炳病并玻菠播拨钵波博勃搏铂箔伯帛舶脖膊渤泊驳捕卜哺补埠不布步簿部怖擦猜裁材才财睬踩采彩菜蔡餐参蚕残惭惨灿苍舱仓沧藏操糙槽曹草厕策侧册测层蹭插叉茬茶查碴搽察岔差诧拆柴豺搀掺蝉馋谗缠铲产阐颤昌猖场尝常长偿肠厂敞畅唱倡超抄钞朝嘲潮巢吵炒车扯撤掣彻澈郴臣辰尘晨忱沉陈趁衬撑称城橙成呈乘程惩澄诚承逞骋秤吃痴持匙池迟弛驰耻齿侈尺赤翅斥炽充冲虫崇宠抽酬畴踌稠愁筹仇绸瞅丑臭初出橱厨躇锄雏滁除楚础储矗搐触处揣川穿椽传船喘串疮窗幢床闯创吹炊捶锤垂春椿醇唇淳纯蠢戳绰疵茨磁雌辞慈瓷词此刺赐次聪葱囱匆从丛凑粗醋簇促蹿篡窜摧崔催脆瘁粹淬翠村存寸磋撮搓措挫错搭达答瘩打大呆歹傣戴带殆代贷袋待逮怠耽担丹单郸掸胆旦氮但惮淡诞弹蛋当挡党荡档刀捣蹈倒岛祷导到稻悼道盗德得的蹬灯登等瞪凳邓堤低滴迪敌笛狄涤翟嫡抵底地蒂第帝弟递缔颠掂滇碘点典靛垫电佃甸店惦奠淀殿碉叼雕凋刁掉吊钓调跌爹碟蝶迭谍叠丁盯叮钉顶鼎锭定订丢东冬董懂动栋侗恫冻洞兜抖斗陡豆逗痘都督毒犊独读堵睹赌杜镀肚度渡妒端短锻段断缎堆兑队对墩吨蹲敦顿囤钝盾遁掇哆多夺垛躲朵跺舵剁惰堕蛾峨鹅俄额讹娥恶厄扼遏鄂饿恩而儿耳尔饵洱二贰发罚筏伐乏阀法珐藩帆番翻樊矾钒繁凡烦反返范贩犯饭泛坊芳方肪房防妨仿访纺放菲非啡飞肥匪诽吠肺废沸费芬酚吩氛分纷坟焚汾粉奋份忿愤粪丰封枫蜂峰锋风疯烽逢冯缝讽奉凤佛否夫敷肤孵扶拂辐幅氟符伏俘服浮涪福袱弗甫抚辅俯釜斧脯腑府腐赴副覆赋复傅付阜父腹负富讣附妇缚咐噶嘎该改概钙盖溉干甘杆柑竿肝赶感秆敢赣冈刚钢缸肛纲岗港杠篙皋高膏羔糕搞镐稿告哥歌搁戈鸽胳疙割革葛格蛤阁隔铬个各给根跟耕更庚羹埂耿梗工攻功恭龚供躬公宫弓巩汞拱贡共钩勾沟苟狗垢构购够辜菇咕箍估沽孤姑鼓古蛊骨谷股故顾固雇刮瓜剐寡挂褂乖拐怪棺关官冠观管馆罐惯灌贯光广逛瑰规圭硅归龟闺轨鬼诡癸桂柜跪贵刽辊滚棍锅郭国果裹过哈骸孩海氦亥害骇酣憨邯韩含涵寒函喊罕翰撼捍旱憾悍焊汗汉夯杭航壕嚎豪毫郝好耗号浩呵喝荷菏核禾和何合盒貉阂河涸赫褐鹤贺嘿黑痕很狠恨哼亨横衡恒轰哄烘虹鸿洪宏弘红喉侯猴吼厚候后呼乎忽瑚壶葫胡蝴狐糊湖弧虎唬护互沪户花哗华猾滑画划化话槐徊怀淮坏欢环桓还缓换患唤痪豢焕涣宦幻荒慌黄磺蝗簧皇凰惶煌晃幌恍谎灰挥辉徽恢蛔回毁悔慧卉惠晦贿秽会烩汇讳诲绘荤昏婚魂浑混豁活伙火获或惑霍货祸击圾基机畸稽积箕肌饥迹激讥鸡姬绩缉吉极棘辑籍集及急疾汲即嫉级挤几脊己蓟技冀季伎祭剂悸济寄寂计记既忌际妓继纪嘉枷夹佳家加荚颊贾甲钾假稼价架驾嫁歼监坚尖笺间煎兼肩艰奸缄茧检柬碱碱拣捡简俭剪减荐槛鉴践贱见键箭件健舰剑饯渐溅涧建僵姜将浆江疆蒋桨奖讲匠酱降蕉椒礁焦胶交郊浇骄娇嚼搅铰矫侥脚狡角饺缴绞剿教酵轿较叫窖揭接皆秸街阶截劫节桔杰捷睫竭洁结解姐戒藉芥界借介疥诫届巾筋斤金今津襟紧锦仅谨进靳晋禁近烬浸尽劲荆兢茎睛晶鲸京惊精粳经井警景颈静境敬镜径痉靖竟竞净炯窘揪究纠玖韭久灸九酒厩救旧臼舅咎就疚鞠拘狙疽居驹菊局咀矩举沮聚拒据巨具距踞锯俱句惧炬剧捐鹃娟倦眷卷绢撅攫抉掘倔爵觉决诀绝均菌钧军君峻俊竣浚郡骏喀咖卡咯开揩楷凯慨刊堪勘坎砍看康慷糠扛抗亢炕考拷烤靠坷苛柯棵磕颗科壳咳可渴克刻客课肯啃垦恳坑吭空恐孔控抠口扣寇枯哭窟苦酷库裤夸垮挎跨胯块筷侩快宽款匡筐狂框矿眶旷况亏盔岿窥葵奎魁傀馈愧溃坤昆捆困括扩廓阔垃拉喇蜡腊辣啦莱来赖蓝婪栏拦篮阑兰澜谰揽览懒缆烂滥琅榔狼廊郎朗浪捞劳牢老佬姥酪烙涝勒乐雷镭蕾磊累儡垒擂肋类泪棱楞冷厘梨犁黎篱狸离漓理李里鲤礼莉荔吏栗丽厉励砾历利僳例俐痢立粒沥隶力璃哩俩联莲连镰廉怜涟帘敛脸链恋炼练粮凉梁粱良两辆量晾亮谅撩聊僚疗燎寥辽潦了撂镣廖料列裂烈劣猎琳林磷霖临邻鳞淋凛赁吝拎玲菱零龄铃伶羚凌灵陵岭领另令溜琉榴硫馏留刘瘤流柳六龙聋咙笼窿隆垄拢陇楼娄搂篓漏陋芦卢颅庐炉掳卤虏鲁麓碌露路赂鹿潞禄录陆戮驴吕铝侣旅履屡缕虑氯律率滤绿峦挛孪滦卵乱掠略抡轮伦仑沦纶论萝螺罗逻锣箩骡裸落洛骆络妈麻玛码蚂马骂嘛吗埋买麦卖迈脉瞒馒蛮满蔓曼慢漫谩芒茫盲氓忙莽猫茅锚毛矛铆卯茂冒帽貌贸么玫枚梅酶霉煤没眉媒镁每美昧寐妹媚门闷们萌蒙檬盟锰猛梦孟眯醚靡糜迷谜弥米秘觅泌蜜密幂棉眠绵冕免勉娩缅面苗描瞄藐秒渺庙妙蔑灭民抿皿敏悯闽明螟鸣铭名命谬摸摹蘑模膜磨摩魔抹末莫墨默沫漠寞陌谋牟某拇牡亩姆母墓暮幕募慕木目睦牧穆拿哪呐钠那娜纳氖乃奶耐奈南男难囊挠脑恼闹淖呢馁内嫩能妮霓倪泥尼拟你匿腻逆溺蔫拈年碾撵捻念娘酿鸟尿捏聂孽啮镊镍涅您柠狞凝宁拧泞牛扭钮纽脓浓农弄奴努怒女暖虐疟挪懦糯诺哦欧鸥殴藕呕偶沤啪趴爬帕怕琶拍排牌徘湃派攀潘盘磐盼畔判叛乓庞旁耪胖抛咆刨炮袍跑泡呸胚培裴赔陪配佩沛喷盆砰抨烹澎彭蓬棚硼篷膨朋鹏捧碰坯砒霹批披劈琵毗啤脾疲皮匹痞僻屁譬篇偏片骗飘漂瓢票撇瞥拼频贫品聘乒坪苹萍平凭瓶评屏坡泼颇婆破魄迫粕剖扑铺仆莆葡菩蒲埔朴圃普浦谱曝瀑期欺栖戚妻七凄漆柒沏其棋奇歧畦崎脐齐旗祈祁骑起岂乞企启契砌器气迄弃汽泣讫掐洽牵扦钎铅千迁签仟谦乾黔钱钳前潜遣浅谴堑嵌欠歉枪呛腔羌墙蔷强抢橇锹敲悄桥瞧乔侨巧鞘撬翘峭俏窍切茄且怯窃钦侵亲秦琴勤芹擒禽寝沁青轻氢倾卿清擎晴氰情顷请庆琼穷秋丘邱球求囚酋泅趋区蛆曲躯屈驱渠取娶龋趣去圈颧权醛泉全痊拳犬券劝缺炔瘸却鹊榷确雀裙群然燃冉染瓤壤攘嚷让饶扰绕惹热壬仁人忍韧任认刃妊纫扔仍日戎茸蓉荣融熔溶容绒冗揉柔肉茹蠕儒孺如辱乳汝入褥软阮蕊瑞锐闰润若弱撒洒萨腮鳃塞赛三叁伞散桑嗓丧搔骚扫嫂瑟色涩森僧莎砂杀刹沙纱傻啥煞筛晒珊苫杉山删煽衫闪陕擅赡膳善汕扇缮墒伤商赏晌上尚裳梢捎稍烧芍勺韶少哨邵绍奢赊蛇舌舍赦摄射慑涉社设砷申呻伸身深娠绅神沈审婶甚肾慎渗声生甥牲升绳省盛剩胜圣师失狮施湿诗尸虱十石拾时什食蚀实识史矢使屎驶始式示士世柿事拭誓逝势是嗜噬适仕侍释饰氏市恃室视试收手首守寿授售受瘦兽蔬枢梳殊抒输叔舒淑疏书赎孰熟薯暑曙署蜀黍鼠属术述树束戍竖墅庶数漱恕刷耍摔衰甩帅栓拴霜双爽谁水睡税吮瞬顺舜说硕朔烁斯撕嘶思私司丝死肆寺嗣四伺似饲巳松耸怂颂送宋讼诵搜艘擞嗽苏酥俗素速粟僳塑溯宿诉肃酸蒜算虽隋随绥髓碎岁穗遂隧祟孙损笋蓑梭唆缩琐索锁所塌他它她塔獭挞蹋踏胎苔抬台泰酞太态汰坍摊贪瘫滩坛檀痰潭谭谈坦毯袒碳探叹炭汤塘搪堂棠膛唐糖倘躺淌趟烫掏涛滔绦萄桃逃淘陶讨套特藤腾疼誊梯剔踢锑提题蹄啼体替嚏惕涕剃屉天添填田甜恬舔腆挑条迢眺跳贴铁帖厅听烃汀廷停亭庭艇通桐酮瞳同铜彤童桶捅筒统痛偷投头透凸秃突图徒途涂屠土吐兔湍团推颓腿蜕褪退吞屯臀拖托脱鸵陀驮驼椭妥拓唾挖哇蛙洼娃瓦袜歪外豌弯湾玩顽丸烷完碗挽晚皖惋宛婉万腕汪王亡枉网往旺望忘妄威巍微危圩韦违桅围唯惟为潍维苇萎委伟伪尾纬未蔚味畏胃喂魏位渭谓尉慰卫瘟温蚊文闻纹吻稳紊问嗡翁瓮挝蜗涡窝我斡卧握沃巫呜钨乌污诬屋无芜梧吾吴毋武五捂午舞伍侮坞戊雾晤物勿务悟误昔熙析西硒矽晰嘻吸锡牺稀息希悉膝夕惜熄烯溪汐犀檄袭席习媳喜铣洗系隙戏细瞎虾匣霞辖暇峡侠狭下厦夏吓掀锨先仙鲜纤咸贤衔舷闲涎弦嫌显险现献县腺馅羡宪陷限线相厢镶香箱襄湘乡翔祥详想响享项巷橡像向象萧硝霄削哮嚣销消宵淆晓小孝校肖啸笑效楔些歇蝎鞋协挟携邪斜胁谐写械卸蟹懈泄泻谢屑薪芯锌欣辛新忻心信衅星腥猩惺兴刑型形邢行醒幸杏性姓兄凶胸匈汹雄熊休修羞朽嗅锈秀袖绣墟戌需虚嘘须徐许蓄酗叙旭序畜恤絮婿绪续轩喧宣悬旋玄选癣眩绚靴薛学穴雪血勋熏循旬询寻驯巡殉汛训讯逊迅压押鸦鸭呀丫芽牙蚜崖衙涯雅哑亚讶焉咽阉烟淹盐严研蜒岩延言颜阎炎沿奄掩眼衍演艳堰燕厌砚雁唁彦焰宴谚验殃央鸯秧杨扬佯疡羊洋阳氧仰痒养样漾邀腰妖瑶摇尧遥窑谣姚咬舀药要耀椰噎耶爷野冶也页掖业叶曳腋夜液一壹医揖铱依伊衣颐夷遗移仪胰疑沂宜姨彝椅蚁倚已乙矣以艺抑易邑屹亿役臆逸肄疫亦裔意毅忆义益溢诣议谊译异翼翌绎茵荫因殷音阴姻吟银淫寅饮尹引隐印英樱婴鹰应缨莹萤营荧蝇迎赢盈影颖硬映哟拥佣臃痈庸雍踊蛹咏泳涌永恿勇用幽优悠忧尤由邮铀犹油游酉有友右佑釉诱又幼迂淤于盂榆虞愚舆余俞逾鱼愉渝渔隅予娱雨与屿禹宇语羽玉域芋郁吁遇喻峪御愈欲狱育誉浴寓裕预豫驭鸳渊冤元垣袁原援辕园员圆猿源缘远苑愿怨院曰约越跃钥岳粤月悦阅耘云郧匀陨允运蕴酝晕韵孕匝砸杂栽哉灾宰载再在咱攒暂赞赃脏葬遭糟凿藻枣早澡蚤躁噪造皂灶燥责择则泽贼怎增憎曾赠扎喳渣札轧铡闸眨栅榨咋乍炸诈摘斋宅窄债寨瞻毡詹粘沾盏斩辗崭展蘸栈占战站湛绽樟章彰漳张掌涨杖丈帐账仗胀瘴障招昭找沼赵照罩
兆肇召遮折哲蛰辙者锗蔗这浙珍斟真甄砧臻贞针侦枕疹诊震振镇阵蒸挣睁征狰争怔整拯正政帧症郑证芝枝支吱蜘知肢脂汁之织职直植殖执值侄址指止趾只旨纸志挚掷至致置帜峙制智秩稚质炙痔滞治窒中盅忠钟衷终种肿重仲众舟周州洲诌粥轴肘帚咒皱宙昼骤珠株蛛朱猪诸诛逐竹烛煮拄瞩嘱主著柱助蛀贮铸筑住注祝驻抓爪拽专砖转撰赚篆桩庄装妆撞壮状椎锥追赘坠缀谆准捉拙卓桌琢茁酌啄着灼浊兹咨资姿滋淄孜紫仔籽滓子自渍字鬃棕踪宗综总纵邹走奏揍租足卒族祖诅阻组钻纂嘴醉最罪尊遵昨左佐柞做作坐座1234567890" 292 | cudnn.benchmark = True 293 | cudnn.deterministic = True 294 | opt.num_gpu = torch.cuda.device_count() 295 | 296 | test(opt) 297 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | 3 | import os 4 | import sys 5 | import time 6 | import random 7 | import string 8 | import argparse 9 | import visdom 10 | 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | import torch.nn.init as init 14 | import torch.optim as optim 15 | import torch.utils.data 16 | import numpy as np 17 | 18 | from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager 19 | from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset 20 | from model import Model 21 | from test import validation 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | 24 | 25 | def create_vis_plot(_xlabel, _ylabel, _title, _legend, viz): 26 | return viz.line( 27 | X=torch.zeros((1,)).cpu(), 28 | Y=torch.zeros((1, 1)).cpu(), 29 | opts=dict( 30 | xlabel=_xlabel, 31 | ylabel=_ylabel, 32 | title=_title, 33 | legend=_legend 34 | ) 35 | ) 36 | 37 | def update_vis_plot(iteration, loc, window1, update_type, viz): 38 | viz.line( 39 | X=torch.ones((1, 1)).cpu() * iteration, 40 | Y=torch.Tensor([loc]).unsqueeze(0).cpu(), 41 | win=window1, 42 | update=update_type 43 | ) 44 | 45 | def train(opt): 46 | 47 | """ dataset preparation """ 48 | if not opt.data_filtering_off: 49 | print('Filtering the images containing characters which are not in opt.character') 50 | print('Filtering the images whose label is longer than opt.batch_max_length') 51 | # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 52 | 53 | opt.select_data = opt.select_data.split('-') 54 | opt.batch_ratio = opt.batch_ratio.split('-') 55 | train_dataset = Batch_Balanced_Dataset(opt) # 组合数据集 56 | 57 | log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') 58 | AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 59 | valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt) 60 | valid_loader = torch.utils.data.DataLoader( 61 | valid_dataset, batch_size=opt.batch_size, 62 | shuffle=True, # 'True' to check training progress with validation function. 
63 | num_workers=int(opt.workers), 64 | collate_fn=AlignCollate_valid, pin_memory=False) 65 | log.write(valid_dataset_log) 66 | print('-' * 80) 67 | log.write('-' * 80 + '\n') 68 | log.close() 69 | 70 | """ model configuration """ 71 | if 'CTC' in opt.Prediction: 72 | if opt.baiduCTC: 73 | converter = CTCLabelConverterForBaiduWarpctc(opt.character) 74 | else: 75 | converter = CTCLabelConverter(opt.character) 76 | else: 77 | converter = AttnLabelConverter(opt.character) 78 | opt.num_class = len(converter.character) 79 | 80 | if opt.rgb: 81 | opt.input_channel = 3 82 | model = Model(opt) 83 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 84 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 85 | opt.SequenceModeling, opt.Prediction) 86 | 87 | # weight initialization 88 | for name, param in model.named_parameters(): 89 | if 'localization_fc2' in name: 90 | print(f'Skip {name} as it is already initialized') 91 | continue 92 | try: 93 | if 'bias' in name: 94 | init.constant_(param, 0.0) 95 | elif 'weight' in name: 96 | init.kaiming_normal_(param) 97 | except Exception as e: # for batchnorm. 98 | if 'weight' in name: 99 | param.data.fill_(1) 100 | continue 101 | 102 | # data parallel for multi-GPU 103 | model = torch.nn.DataParallel(model).to(device) 104 | model.train() 105 | if opt.saved_model != '': 106 | print(f'loading pretrained model from {opt.saved_model}') 107 | if opt.FT: 108 | model.load_state_dict(torch.load(opt.saved_model), strict=False) 109 | else: 110 | model.load_state_dict(torch.load(opt.saved_model)) 111 | print("Model:") 112 | print(model) 113 | 114 | """ setup loss """ 115 | if 'CTC' in opt.Prediction: 116 | if opt.baiduCTC: 117 | # need to install warpctc. see our guideline. 
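# (warpctc_pytorch is a third-party binding of Baidu's warp-ctc and is not shipped
#  with torch; it has to be built and installed separately, as the guideline above
#  notes. If it is not available, leave --baiduCTC unset and the torch.nn.CTCLoss
#  branch below is used instead.)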
118 | from warpctc_pytorch import CTCLoss 119 | criterion = CTCLoss() 120 | else: 121 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 122 | else: 123 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 124 | # loss averager 125 | loss_avg = Averager() 126 | 127 | # filter that only require gradient decent 128 | filtered_parameters = [] 129 | params_num = [] 130 | 131 | for p in filter(lambda p: p.requires_grad, model.parameters()): 132 | filtered_parameters.append(p) 133 | params_num.append(np.prod(p.size())) 134 | print('Trainable params num : ', sum(params_num)) 135 | 136 | # setup optimizer 137 | if opt.adam: 138 | optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999)) 139 | else: 140 | optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps) 141 | print("Optimizer:") 142 | print(optimizer) 143 | 144 | """ final options """ 145 | # print(opt) 146 | with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file: 147 | opt_log = '------------ Options -------------\n' 148 | args = vars(opt) 149 | for k, v in args.items(): 150 | opt_log += f'{str(k)}: {str(v)}\n' 151 | opt_log += '---------------------------------------\n' 152 | print(opt_log) 153 | opt_file.write(opt_log) 154 | 155 | """ start training """ 156 | start_iter = 0 157 | if opt.saved_model != '': 158 | try: 159 | start_iter = int(opt.saved_model.split('_')[-1].split('.')[0]) 160 | print(f'continue to train, start_iter: {start_iter}') 161 | except: 162 | if opt.resume_epoch != -1: 163 | start_iter = opt.resume_epoch 164 | print(f'continue to train, start_iter: {start_iter}') 165 | 166 | 167 | start_time = time.time() 168 | best_accuracy = -1 169 | best_norm_ED = -1 170 | test_best_accuracy = -1 171 | test_best_norm_ED = -1 172 | iteration = start_iter 173 | 174 | while(True): 175 | image_tensors, labels = train_dataset.get_batch() 176 | image = image_tensors.to(device) 177 | text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) 178 | batch_size = image.size(0) 179 | 180 | if 'CTC' in opt.Prediction: 181 | preds = model(image, text) 182 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 183 | if opt.baiduCTC: 184 | preds = preds.permute(1, 0, 2) # to use CTCLoss format 185 | cost = criterion(preds, text, preds_size, length) / batch_size 186 | else: 187 | preds = preds.log_softmax(2).permute(1, 0, 2) 188 | cost = criterion(preds, text, preds_size, length) 189 | 190 | else: 191 | preds = model(image, text[:, :-1]) # align with Attention.forward 192 | target = text[:, 1:] # without [GO] Symbol 193 | cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 194 | 195 | model.zero_grad() 196 | cost.backward() 197 | torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) 198 | optimizer.step() 199 | 200 | loss_avg.add(cost) 201 | # validation part 202 | if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 203 | elapsed_time = time.time() - start_time 204 | # for log 205 | with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log: 206 | model.eval() 207 | with torch.no_grad(): 208 | valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation( 209 | model, criterion, valid_loader, converter, opt) 210 | model.train() 211 | 212 | # training loss and validation 
loss 213 | loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}' 214 | loss_avg.reset() 215 | 216 | current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}' 217 | # keep best accuracy model (on valid dataset) 218 | if current_accuracy > best_accuracy: 219 | best_accuracy = current_accuracy 220 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth') 221 | if current_norm_ED > best_norm_ED: 222 | best_norm_ED = current_norm_ED 223 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth') 224 | 225 | best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}' 226 | # test_best_model_log = f'{"Test_Best_accuracy":17s}: {test_best_accuracy:0.3f}, {"Test_Best_norm_ED":17s}: {test_best_norm_ED:0.2f}' 227 | 228 | loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}' 229 | print(loss_model_log) 230 | log.write(loss_model_log + '\n') 231 | 232 | # show some predicted results 233 | dashed_line = '-' * 80 234 | head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' 235 | predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' 236 | # print(len(valid_dataset), len(labels)) 237 | for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]): 238 | if 'Attn' in opt.Prediction: 239 | gt = gt[:gt.find('[s]')] 240 | pred = pred[:pred.find('[s]')] 241 | 242 | predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n' 243 | predicted_result_log += f'{dashed_line}' 244 | print(predicted_result_log) 245 | log.write(predicted_result_log + '\n') 246 | 247 | # save model per 1e+5 iter. 248 | if (iteration + 1) % 1e+5 == 0: 249 | torch.save( 250 | model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth') 251 | 252 | if (iteration + 1) == opt.num_iter: 253 | print('end the training') 254 | sys.exit() 255 | iteration += 1 256 | 257 | 258 | if __name__ == '__main__': 259 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 260 | parser = argparse.ArgumentParser() 261 | parser.add_argument('--exp_name', help='Where to store logs and models') 262 | parser.add_argument('--train_data', required=True, help='path to training dataset') 263 | parser.add_argument('--valid_data', required=True, help='path to validation dataset') 264 | parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting') 265 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2) 266 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 267 | parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for') 268 | parser.add_argument('--valInterval', type=int, default=5000, help='Interval between each validation') 269 | parser.add_argument('--saved_model', default='', help="path to model to continue training") 270 | parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning') 271 | parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)') 272 | parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta') 273 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. 
default=0.9') 274 | parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95') 275 | parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8') 276 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. default=5') 277 | parser.add_argument('--baiduCTC', action='store_true', help='for data_filtering_off mode') 278 | """ Data processing """ 279 | parser.add_argument('--select_data', type=str, default='MJ-ST', 280 | help='select training data (default is MJ-ST, which means MJ and ST used as training data)') 281 | parser.add_argument('--batch_ratio', type=str, default='1', 282 | help='assign ratio for each selected data in the batch 0.5-0.5') 283 | parser.add_argument('--total_data_usage_ratio', type=str, default='1.0', 284 | help='total data usage ratio, this ratio is multiplied to total number of data.') 285 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 286 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 287 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 288 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 289 | parser.add_argument('--character', type=str, 290 | default='0123456789一下东丰云亚亨亮会佑佛佳俊信八关兴凯利力勤华南县发号君和圩城壹大天宁宇安宏宾富封州工帆平年广庆建开德恒惠成新昌明机权来柳桂梧森横永江沙河油泰泽洋浮海润清港湖源滨珠田盈益盛石祥福程穗粤翔肇航良英藤行衡西诚诺谢谷货贵辉达运远途通都金长阳雄韶顺颜风飞香鸿鼎龙', help='character label 0123456789abcdefghijklmnopqrstuvwxyz') 291 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 292 | parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') 293 | parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') 294 | """ Model Architecture """ 295 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS') 296 | parser.add_argument('--FeatureExtraction', type=str, required=True, 297 | help='FeatureExtraction stage. VGG|RCNN|ResNet') 298 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') 299 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. 
CTC|Attn') 300 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 301 | parser.add_argument('--input_channel', type=int, default=1, 302 | help='the number of input channel of Feature extractor') 303 | parser.add_argument('--output_channel', type=int, default=512, 304 | help='the number of output channel of Feature extractor') 305 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 306 | parser.add_argument('--expsuffix', type=str, default='') 307 | parser.add_argument('--resume_epoch', type=int, default=-1) 308 | opt = parser.parse_args() 309 | 310 | if not opt.exp_name: 311 | opt.exp_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 312 | opt.exp_name += f'-Seed{opt.manualSeed}' 313 | assert opt.expsuffix, 'You should specify the exp suffix' 314 | opt.exp_name += f'-{opt.expsuffix}' 315 | # print(opt.exp_name) 316 | 317 | os.makedirs(f'./saved_models/{opt.exp_name}', exist_ok=True) 318 | 319 | """ vocab / character number configuration """ 320 | if opt.sensitive: 321 | # opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 322 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 323 | 324 | opt.character = CHARS = u"啊阿埃挨哎唉哀皑癌蔼矮艾碍爱隘鞍氨安俺按暗岸胺案肮昂盎凹敖熬翱袄傲奥懊澳芭捌扒叭吧笆八疤巴拔跋靶把耙坝霸罢爸白柏百摆佰败拜稗斑班搬扳般颁板版扮拌伴瓣半办绊邦帮梆榜膀绑棒磅蚌镑傍谤苞胞包褒剥薄雹保堡饱宝抱报暴豹鲍爆杯碑悲卑北辈背贝钡倍狈备惫焙被奔苯本笨崩绷甭泵蹦迸逼鼻比鄙笔彼碧蓖蔽毕毙毖币庇痹闭敝弊必辟壁臂避陛鞭边编贬扁便变卞辨辩辫遍标彪膘表鳖憋别瘪彬斌濒滨宾摈兵冰柄丙秉饼炳病并玻菠播拨钵波博勃搏铂箔伯帛舶脖膊渤泊驳捕卜哺补埠不布步簿部怖擦猜裁材才财睬踩采彩菜蔡餐参蚕残惭惨灿苍舱仓沧藏操糙槽曹草厕策侧册测层蹭插叉茬茶查碴搽察岔差诧拆柴豺搀掺蝉馋谗缠铲产阐颤昌猖场尝常长偿肠厂敞畅唱倡超抄钞朝嘲潮巢吵炒车扯撤掣彻澈郴臣辰尘晨忱沉陈趁衬撑称城橙成呈乘程惩澄诚承逞骋秤吃痴持匙池迟弛驰耻齿侈尺赤翅斥炽充冲虫崇宠抽酬畴踌稠愁筹仇绸瞅丑臭初出橱厨躇锄雏滁除楚础储矗搐触处揣川穿椽传船喘串疮窗幢床闯创吹炊捶锤垂春椿醇唇淳纯蠢戳绰疵茨磁雌辞慈瓷词此刺赐次聪葱囱匆从丛凑粗醋簇促蹿篡窜摧崔催脆瘁粹淬翠村存寸磋撮搓措挫错搭达答瘩打大呆歹傣戴带殆代贷袋待逮怠耽担丹单郸掸胆旦氮但惮淡诞弹蛋当挡党荡档刀捣蹈倒岛祷导到稻悼道盗德得的蹬灯登等瞪凳邓堤低滴迪敌笛狄涤翟嫡抵底地蒂第帝弟递缔颠掂滇碘点典靛垫电佃甸店惦奠淀殿碉叼雕凋刁掉吊钓调跌爹碟蝶迭谍叠丁盯叮钉顶鼎锭定订丢东冬董懂动栋侗恫冻洞兜抖斗陡豆逗痘都督毒犊独读堵睹赌杜镀肚度渡妒端短锻段断缎堆兑队对墩吨蹲敦顿囤钝盾遁掇哆多夺垛躲朵跺舵剁惰堕蛾峨鹅俄额讹娥恶厄扼遏鄂饿恩而儿耳尔饵洱二贰发罚筏伐乏阀法珐藩帆番翻樊矾钒繁凡烦反返范贩犯饭泛坊芳方肪房防妨仿访纺放菲非啡飞肥匪诽吠肺废沸费芬酚吩氛分纷坟焚汾粉奋份忿愤粪丰封枫蜂峰锋风疯烽逢冯缝讽奉凤佛否夫敷肤孵扶拂辐幅氟符伏俘服浮涪福袱弗甫抚辅俯釜斧脯腑府腐赴副覆赋复傅付阜父腹负富讣附妇缚咐噶嘎该改概钙盖溉干甘杆柑竿肝赶感秆敢赣冈刚钢缸肛纲岗港杠篙皋高膏羔糕搞镐稿告哥歌搁戈鸽胳疙割革葛格蛤阁隔铬个各给根跟耕更庚羹埂耿梗工攻功恭龚供躬公宫弓巩汞拱贡共钩勾沟苟狗垢构购够辜菇咕箍估沽孤姑鼓古蛊骨谷股故顾固雇刮瓜剐寡挂褂乖拐怪棺关官冠观管馆罐惯灌贯光广逛瑰规圭硅归龟闺轨鬼诡癸桂柜跪贵刽辊滚棍锅郭国果裹过哈骸孩海氦亥害骇酣憨邯韩含涵寒函喊罕翰撼捍旱憾悍焊汗汉夯杭航壕嚎豪毫郝好耗号浩呵喝荷菏核禾和何合盒貉阂河涸赫褐鹤贺嘿黑痕很狠恨哼亨横衡恒轰哄烘虹鸿洪宏弘红喉侯猴吼厚候后呼乎忽瑚壶葫胡蝴狐糊湖弧虎唬护互沪户花哗华猾滑画划化话槐徊怀淮坏欢环桓还缓换患唤痪豢焕涣宦幻荒慌黄磺蝗簧皇凰惶煌晃幌恍谎灰挥辉徽恢蛔回毁悔慧卉惠晦贿秽会烩汇讳诲绘荤昏婚魂浑混豁活伙火获或惑霍货祸击圾基机畸稽积箕肌饥迹激讥鸡姬绩缉吉极棘辑籍集及急疾汲即嫉级挤几脊己蓟技冀季伎祭剂悸济寄寂计记既忌际妓继纪嘉枷夹佳家加荚颊贾甲钾假稼价架驾嫁歼监坚尖笺间煎兼肩艰奸缄茧检柬碱碱拣捡简俭剪减荐槛鉴践贱见键箭件健舰剑饯渐溅涧建僵姜将浆江疆蒋桨奖讲匠酱降蕉椒礁焦胶交郊浇骄娇嚼搅铰矫侥脚狡角饺缴绞剿教酵轿较叫窖揭接皆秸街阶截劫节桔杰捷睫竭洁结解姐戒藉芥界借介疥诫届巾筋斤金今津襟紧锦仅谨进靳晋禁近烬浸尽劲荆兢茎睛晶鲸京惊精粳经井警景颈静境敬镜径痉靖竟竞净炯窘揪究纠玖韭久灸九酒厩救旧臼舅咎就疚鞠拘狙疽居驹菊局咀矩举沮聚拒据巨具距踞锯俱句惧炬剧捐鹃娟倦眷卷绢撅攫抉掘倔爵觉决诀绝均菌钧军君峻俊竣浚郡骏喀咖卡咯开揩楷凯慨刊堪勘坎砍看康慷糠扛抗亢炕考拷烤靠坷苛柯棵磕颗科壳咳可渴克刻客课肯啃垦恳坑吭空恐孔控抠口扣寇枯哭窟苦酷库裤夸垮挎跨胯块筷侩快宽款匡筐狂框矿眶旷况亏盔岿窥葵奎魁傀馈愧溃坤昆捆困括扩廓阔垃拉喇蜡腊辣啦莱来赖蓝婪栏拦篮阑兰澜谰揽览懒缆烂滥琅榔狼廊郎朗浪捞劳牢老佬姥酪烙涝勒乐雷镭蕾磊累儡垒擂肋类泪棱楞冷厘梨犁黎篱狸离漓理李里鲤礼莉荔吏栗丽厉励砾历利僳例俐痢立粒沥隶力璃哩俩联莲连镰廉怜涟帘敛脸链恋炼练粮凉梁粱良两辆量晾亮谅撩聊僚疗燎寥辽潦了撂镣廖料列裂烈劣猎琳林磷霖临邻鳞淋凛赁吝拎玲菱零龄铃伶羚凌灵陵岭领另令溜琉榴硫馏留刘瘤流柳六龙聋咙笼窿隆垄拢陇楼娄搂篓漏陋芦卢颅庐炉掳卤虏鲁麓碌露路赂鹿潞禄录陆戮驴吕铝侣旅履屡缕虑氯律率滤绿峦挛孪滦卵乱掠略抡轮伦仑沦纶论萝螺罗逻锣箩骡裸落洛骆络妈麻玛码蚂马骂嘛吗埋买麦卖迈脉瞒馒蛮满蔓曼慢漫谩芒茫盲氓忙莽猫茅锚毛矛铆卯茂冒帽貌贸么玫枚梅酶霉煤没眉媒镁每美昧寐妹媚门闷们萌蒙檬盟锰猛梦孟眯醚靡糜迷谜弥米秘觅泌蜜密幂棉眠绵冕免勉娩缅面苗描瞄藐秒渺庙妙蔑灭民抿皿敏悯闽明螟鸣铭名命谬摸摹蘑模膜磨摩魔抹末莫墨默沫漠寞陌谋牟某拇牡亩姆母墓暮幕募慕木目睦牧穆拿哪呐钠那娜纳氖乃奶耐奈南男难囊挠脑恼闹淖呢馁内嫩能妮霓倪泥尼拟你匿腻逆溺蔫拈年碾撵捻念娘酿鸟尿捏聂孽啮镊镍涅您柠狞凝宁拧泞牛扭钮纽脓浓农弄奴努怒女暖虐疟挪懦糯诺哦欧鸥殴藕呕偶沤啪趴爬帕怕琶拍排牌徘湃派攀潘盘磐盼畔判叛乓庞旁耪胖抛咆刨炮袍跑泡呸胚培裴赔陪配佩沛喷盆砰抨烹澎彭蓬棚硼篷膨朋鹏捧碰坯砒霹批披劈琵毗啤脾疲皮匹痞僻屁譬篇偏片骗飘漂瓢票撇瞥拼频贫品聘乒坪苹萍平凭瓶评屏坡泼颇婆破魄迫粕剖扑铺仆莆葡菩蒲埔朴圃普浦谱曝瀑期欺栖戚妻七凄漆柒沏其棋奇歧畦崎脐齐旗祈祁骑起岂乞企启契砌器气迄弃汽泣讫掐洽牵扦钎铅千迁签仟谦乾黔钱钳前潜遣浅谴堑嵌欠歉枪呛腔羌墙蔷强抢橇锹敲悄桥瞧乔侨巧鞘撬翘峭俏窍切茄且怯窃钦侵亲秦琴勤芹擒禽寝沁青轻氢倾卿清擎晴氰情顷请庆琼穷秋丘邱球求囚酋
泅趋区蛆曲躯屈驱渠取娶龋趣去圈颧权醛泉全痊拳犬券劝缺炔瘸却鹊榷确雀裙群然燃冉染瓤壤攘嚷让饶扰绕惹热壬仁人忍韧任认刃妊纫扔仍日戎茸蓉荣融熔溶容绒冗揉柔肉茹蠕儒孺如辱乳汝入褥软阮蕊瑞锐闰润若弱撒洒萨腮鳃塞赛三叁伞散桑嗓丧搔骚扫嫂瑟色涩森僧莎砂杀刹沙纱傻啥煞筛晒珊苫杉山删煽衫闪陕擅赡膳善汕扇缮墒伤商赏晌上尚裳梢捎稍烧芍勺韶少哨邵绍奢赊蛇舌舍赦摄射慑涉社设砷申呻伸身深娠绅神沈审婶甚肾慎渗声生甥牲升绳省盛剩胜圣师失狮施湿诗尸虱十石拾时什食蚀实识史矢使屎驶始式示士世柿事拭誓逝势是嗜噬适仕侍释饰氏市恃室视试收手首守寿授售受瘦兽蔬枢梳殊抒输叔舒淑疏书赎孰熟薯暑曙署蜀黍鼠属术述树束戍竖墅庶数漱恕刷耍摔衰甩帅栓拴霜双爽谁水睡税吮瞬顺舜说硕朔烁斯撕嘶思私司丝死肆寺嗣四伺似饲巳松耸怂颂送宋讼诵搜艘擞嗽苏酥俗素速粟僳塑溯宿诉肃酸蒜算虽隋随绥髓碎岁穗遂隧祟孙损笋蓑梭唆缩琐索锁所塌他它她塔獭挞蹋踏胎苔抬台泰酞太态汰坍摊贪瘫滩坛檀痰潭谭谈坦毯袒碳探叹炭汤塘搪堂棠膛唐糖倘躺淌趟烫掏涛滔绦萄桃逃淘陶讨套特藤腾疼誊梯剔踢锑提题蹄啼体替嚏惕涕剃屉天添填田甜恬舔腆挑条迢眺跳贴铁帖厅听烃汀廷停亭庭艇通桐酮瞳同铜彤童桶捅筒统痛偷投头透凸秃突图徒途涂屠土吐兔湍团推颓腿蜕褪退吞屯臀拖托脱鸵陀驮驼椭妥拓唾挖哇蛙洼娃瓦袜歪外豌弯湾玩顽丸烷完碗挽晚皖惋宛婉万腕汪王亡枉网往旺望忘妄威巍微危圩韦违桅围唯惟为潍维苇萎委伟伪尾纬未蔚味畏胃喂魏位渭谓尉慰卫瘟温蚊文闻纹吻稳紊问嗡翁瓮挝蜗涡窝我斡卧握沃巫呜钨乌污诬屋无芜梧吾吴毋武五捂午舞伍侮坞戊雾晤物勿务悟误昔熙析西硒矽晰嘻吸锡牺稀息希悉膝夕惜熄烯溪汐犀檄袭席习媳喜铣洗系隙戏细瞎虾匣霞辖暇峡侠狭下厦夏吓掀锨先仙鲜纤咸贤衔舷闲涎弦嫌显险现献县腺馅羡宪陷限线相厢镶香箱襄湘乡翔祥详想响享项巷橡像向象萧硝霄削哮嚣销消宵淆晓小孝校肖啸笑效楔些歇蝎鞋协挟携邪斜胁谐写械卸蟹懈泄泻谢屑薪芯锌欣辛新忻心信衅星腥猩惺兴刑型形邢行醒幸杏性姓兄凶胸匈汹雄熊休修羞朽嗅锈秀袖绣墟戌需虚嘘须徐许蓄酗叙旭序畜恤絮婿绪续轩喧宣悬旋玄选癣眩绚靴薛学穴雪血勋熏循旬询寻驯巡殉汛训讯逊迅压押鸦鸭呀丫芽牙蚜崖衙涯雅哑亚讶焉咽阉烟淹盐严研蜒岩延言颜阎炎沿奄掩眼衍演艳堰燕厌砚雁唁彦焰宴谚验殃央鸯秧杨扬佯疡羊洋阳氧仰痒养样漾邀腰妖瑶摇尧遥窑谣姚咬舀药要耀椰噎耶爷野冶也页掖业叶曳腋夜液一壹医揖铱依伊衣颐夷遗移仪胰疑沂宜姨彝椅蚁倚已乙矣以艺抑易邑屹亿役臆逸肄疫亦裔意毅忆义益溢诣议谊译异翼翌绎茵荫因殷音阴姻吟银淫寅饮尹引隐印英樱婴鹰应缨莹萤营荧蝇迎赢盈影颖硬映哟拥佣臃痈庸雍踊蛹咏泳涌永恿勇用幽优悠忧尤由邮铀犹油游酉有友右佑釉诱又幼迂淤于盂榆虞愚舆余俞逾鱼愉渝渔隅予娱雨与屿禹宇语羽玉域芋郁吁遇喻峪御愈欲狱育誉浴寓裕预豫驭鸳渊冤元垣袁原援辕园员圆猿源缘远苑愿怨院曰约越跃钥岳粤月悦阅耘云郧匀陨允运蕴酝晕韵孕匝砸杂栽哉灾宰载再在咱攒暂赞赃脏葬遭糟凿藻枣早澡蚤躁噪造皂灶燥责择则泽贼怎增憎曾赠扎喳渣札轧铡闸眨栅榨咋乍炸诈摘斋宅窄债寨瞻毡詹粘沾盏斩辗崭展蘸栈占战站湛绽樟章彰漳张掌涨杖丈帐账仗胀瘴障招昭找沼赵照罩兆肇召遮折哲蛰辙者锗蔗这浙珍斟真甄砧臻贞针侦枕疹诊震振镇阵蒸挣睁征狰争怔整拯正政帧症郑证芝枝支吱蜘知肢脂汁之织职直植殖执值侄址指止趾只旨纸志挚掷至致置帜峙制智秩稚质炙痔滞治窒中盅忠钟衷终种肿重仲众舟周州洲诌粥轴肘帚咒皱宙昼骤珠株蛛朱猪诸诛逐竹烛煮拄瞩嘱主著柱助蛀贮铸筑住注祝驻抓爪拽专砖转撰赚篆桩庄装妆撞壮状椎锥追赘坠缀谆准捉拙卓桌琢茁酌啄着灼浊兹咨资姿滋淄孜紫仔籽滓子自渍字鬃棕踪宗综总纵邹走奏揍租足卒族祖诅阻组钻纂嘴醉最罪尊遵昨左佐柞做作坐座1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 325 | # opt.character = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' 326 | """ Seed and GPU setting """ 327 | # print("Random Seed: ", opt.manualSeed) 328 | random.seed(opt.manualSeed) 329 | np.random.seed(opt.manualSeed) 330 | torch.manual_seed(opt.manualSeed) 331 | torch.cuda.manual_seed(opt.manualSeed) 332 | 333 | cudnn.benchmark = True 334 | cudnn.deterministic = True 335 | opt.num_gpu = torch.cuda.device_count() 336 | print('device count', opt.num_gpu) 337 | opt.num_gpu = 1 338 | 339 | if opt.num_gpu > 1: 340 | print('------ Use multi-GPU setting ------') 341 | print('if you stuck too long time with multi-GPU setting, try to set --workers 0') 342 | # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1 343 | opt.workers = opt.workers * opt.num_gpu 344 | opt.batch_size = opt.batch_size * opt.num_gpu 345 | 346 | """ previous version 347 | print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size) 348 | opt.batch_size = opt.batch_size * opt.num_gpu 349 | print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.') 350 | If you dont care about it, just commnet out these line.) 
351 | opt.num_iter = int(opt.num_iter / opt.num_gpu) 352 | """ 353 | 354 | train(opt) 355 | 356 | -------------------------------------------------------------------------------- /meta_train.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import os 3 | import sys 4 | import time 5 | import random 6 | import string 7 | import argparse 8 | from copy import deepcopy 9 | from collections import OrderedDict 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn.init as init 13 | import torch.optim as optim 14 | import torch.utils.data 15 | import numpy as np 16 | import visdom 17 | 18 | from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager 19 | from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset 20 | from model import Model 21 | from test import validation 22 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | 24 | def create_vis_plot(_xlabel, _ylabel, _title, _legend, viz): 25 | return viz.line( 26 | X=torch.zeros((1,)).cpu(), 27 | Y=torch.zeros((1, 1)).cpu(), 28 | opts=dict( 29 | xlabel=_xlabel, 30 | ylabel=_ylabel, 31 | title=_title, 32 | legend=_legend 33 | ) 34 | ) 35 | 36 | def update_vis_plot(iteration, loc, window1, update_type, viz): 37 | viz.line( 38 | X=torch.ones((1, 1)).cpu() * iteration, 39 | Y=torch.Tensor([loc]).unsqueeze(0).cpu(), 40 | win=window1, 41 | update=update_type 42 | ) 43 | 44 | 45 | def train(opt): 46 | 47 | """ dataset preparation """ 48 | if not opt.data_filtering_off: 49 | print('Filtering the images containing characters which are not in opt.character') 50 | print('Filtering the images whose label is longer than opt.batch_max_length') 51 | # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 52 | 53 | opt.select_data = opt.select_data.split('-') 54 | opt.batch_ratio = opt.batch_ratio.split('-') 55 | train_dataset = Batch_Balanced_Dataset(opt) # 组合数据集 56 | 57 | log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') 58 | AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 59 | AlignCollate_test = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 60 | valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt) 61 | valid_loader = torch.utils.data.DataLoader( 62 | valid_dataset, batch_size=opt.batch_size, 63 | shuffle=True, # 'True' to check training progress with validation function. 
64 | num_workers=int(opt.workers), 65 | collate_fn=AlignCollate_valid, pin_memory=True) 66 | log.write(valid_dataset_log) 67 | print('-' * 80) 68 | log.write('-' * 80 + '\n') 69 | log.close() 70 | 71 | """ model configuration """ 72 | if 'CTC' in opt.Prediction: 73 | if opt.baiduCTC: 74 | converter = CTCLabelConverterForBaiduWarpctc(opt.character) 75 | else: 76 | converter = CTCLabelConverter(opt.character) 77 | else: 78 | converter = AttnLabelConverter(opt.character) 79 | opt.num_class = len(converter.character) 80 | 81 | if opt.rgb: 82 | opt.input_channel = 3 83 | model = Model(opt) 84 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 85 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 86 | opt.SequenceModeling, opt.Prediction) 87 | 88 | # weight initialization 89 | for name, param in model.named_parameters(): 90 | if 'localization_fc2' in name: 91 | print(f'Skip {name} as it is already initialized') 92 | continue 93 | try: 94 | if 'bias' in name: 95 | init.constant_(param, 0.0) 96 | elif 'weight' in name: 97 | init.kaiming_normal_(param) 98 | except Exception as e: # for batchnorm. 99 | if 'weight' in name: 100 | param.data.fill_(1) 101 | continue 102 | 103 | # data parallel for multi-GPU 104 | model = torch.nn.DataParallel(model).to(device) 105 | model.train() 106 | if opt.saved_model != '': 107 | print(f'loading pretrained model from {opt.saved_model}') 108 | if opt.FT: 109 | model.load_state_dict(torch.load(opt.saved_model), strict=False) 110 | else: 111 | model.load_state_dict(torch.load(opt.saved_model)) 112 | print("Model:") 113 | print(model) 114 | """ setup loss """ 115 | if 'CTC' in opt.Prediction: 116 | if opt.baiduCTC: 117 | # need to install warpctc. see our guideline. 
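The training loop further below (after the loss and optimizer setup) is where meta_train.py departs from train.py: each outer iteration samples a meta-target source domain, snapshots the weights, runs opt.inner_loop_iter SGD steps on batches drawn via train_dataset.get_batch(meta_target_index), and then takes one more gradient step on a meta-test batch from get_meta_test_batch(meta_target_index). The snapshot is restored, the meta-test gradients are attached to the restored parameters (the script stores them on old.grads, while PyTorch optimizers read p.grad), the BN running statistics are copied from the inner-loop state because the rewound copy never ran a forward pass and its tracked mean/variance would otherwise stay at their initial values, and the outer optimizer then steps, followed by one ordinary supervised step on a further get_batch(meta_target_index) batch. A schematic of that first-order meta-update, with hypothetical names and the CTC/Attn loss details, gradient clipping, and buffer copying omitted:

import copy
import torch
from torch import optim

def meta_update(model, outer_opt, inner_batches, meta_test_batch, loss_fn, inner_lr=1e-3):
    snapshot = copy.deepcopy(model.state_dict())        # remember the pre-adaptation weights
    inner_opt = optim.SGD(model.parameters(), lr=inner_lr, momentum=0.9)
    for batch in inner_batches:                          # meta-train: adapt on the other source domains
        inner_opt.zero_grad()
        loss_fn(model, batch).backward()
        inner_opt.step()
    model.zero_grad()
    loss_fn(model, meta_test_batch).backward()           # meta-test: gradient on the held-out domain
    meta_grads = [None if p.grad is None else p.grad.detach().clone()
                  for p in model.parameters()]
    model.load_state_dict(snapshot)                      # rewind to the snapshot
    for p, g in zip(model.parameters(), meta_grads):
        p.grad = g                                       # apply meta-test grads at the old weights
    outer_opt.step()                                     # outer update
    outer_opt.zero_grad()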
118 | from warpctc_pytorch import CTCLoss 119 | criterion = CTCLoss() 120 | else: 121 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 122 | else: 123 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 124 | # loss averager 125 | loss_avg = Averager() 126 | 127 | # filter that only require gradient decent 128 | filtered_parameters = [] 129 | params_num = [] 130 | 131 | for p in model.parameters(): 132 | if not p.requires_grad: raise ValueError("The parameter must require grads") 133 | 134 | for p in filter(lambda p: p.requires_grad, model.parameters()): 135 | filtered_parameters.append(p) 136 | params_num.append(np.prod(p.size())) 137 | print('Trainable params num : ', sum(params_num)) 138 | 139 | # setup optimizer 140 | if opt.adam: 141 | optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999)) 142 | else: 143 | optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps) 144 | print("Optimizer:") 145 | print(optimizer) 146 | 147 | """ final options """ 148 | with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file: 149 | opt_log = '------------ Options -------------\n' 150 | args = vars(opt) 151 | for k, v in args.items(): 152 | opt_log += f'{str(k)}: {str(v)}\n' 153 | opt_log += '---------------------------------------\n' 154 | print(opt_log) 155 | opt_file.write(opt_log) 156 | 157 | """ start training """ 158 | start_iter = 0 159 | if opt.saved_model != '': 160 | try: 161 | start_iter = int(opt.saved_model.split('_')[-1].split('.')[0]) 162 | print(f'continue to train, start_iter: {start_iter}') 163 | except: 164 | pass 165 | 166 | start_time = time.time() 167 | best_accuracy = -1 168 | best_norm_ED = -1 169 | iteration = start_iter 170 | print('start training') 171 | while(True): 172 | # train part 173 | meta_target_index = random.randint(0,opt.source_num - 1) 174 | old_state_dict = deepcopy(model.state_dict()) 175 | weight_buffer, buffer_buffer = [], [] 176 | inner_optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4) 177 | for mix_iter in range(opt.inner_loop_iter + 1): 178 | if mix_iter == opt.inner_loop_iter: 179 | image_tensors, labels = train_dataset.get_meta_test_batch(meta_target_index) 180 | else: image_tensors, labels = train_dataset.get_batch(meta_target_index) 181 | image = image_tensors.to(device) 182 | text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) 183 | batch_size = image.size(0) 184 | if 'CTC' in opt.Prediction: 185 | preds = model(image, text) 186 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 187 | if opt.baiduCTC: 188 | preds = preds.permute(1, 0, 2) # to use CTCLoss format 189 | cost = criterion(preds, text, preds_size, length) / batch_size 190 | else: 191 | preds = preds.log_softmax(2).permute(1, 0, 2) 192 | cost = criterion(preds, text, preds_size, length) 193 | 194 | else: 195 | preds = model(image, text[:, :-1]) # align with Attention.forward 196 | target = text[:, 1:] # without [GO] Symbol 197 | cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 198 | 199 | 200 | model.zero_grad() 201 | cost.backward() 202 | torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) 203 | inner_optimizer.step() 204 | loss_avg.add(cost) 205 | 206 | if mix_iter == opt.inner_loop_iter: 207 | weight_buffer.extend([deepcopy(p) for p in model.parameters()]) 208 | buffer_buffer.extend([deepcopy(p) for p in 
model.buffers()]) 209 | break 210 | 211 | 212 | model.load_state_dict(old_state_dict) 213 | 214 | old_param = list(map(lambda p: p, model.parameters())) 215 | for old, new in zip(old_param, weight_buffer): old.grads = new.grad 216 | 217 | index = 0 218 | buffer_index = 0 219 | new_weights_dict = OrderedDict() 220 | 221 | for key, param in model.state_dict().items(): 222 | if key.find('running') != -1 or key.find('tracked') != -1 or key.find('delta') != -1 or key.find('hat') != -1: 223 | new_weights_dict[key] = buffer_buffer[buffer_index] # buffer要用inner loop中的状态,因为外部BN层没有被调用过,tracking的均值和方差都为0 224 | buffer_index += 1 225 | continue 226 | new_weights_dict[key] = old_param[index] 227 | index += 1 228 | 229 | assert index == len(old_param) 230 | assert buffer_index == len(buffer_buffer) 231 | 232 | model.load_state_dict(new_weights_dict) 233 | optimizer.step() 234 | optimizer.zero_grad() 235 | 236 | 237 | image_tensors, labels = train_dataset.get_batch(meta_target_index) 238 | image = image_tensors.to(device) 239 | text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) 240 | batch_size = image.size(0) 241 | 242 | if 'CTC' in opt.Prediction: 243 | preds = model(image, text) 244 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 245 | if opt.baiduCTC: 246 | preds = preds.permute(1, 0, 2) # to use CTCLoss format 247 | cost = criterion(preds, text, preds_size, length) / batch_size 248 | else: 249 | preds = preds.log_softmax(2).permute(1, 0, 2) 250 | cost = criterion(preds, text, preds_size, length) 251 | 252 | else: 253 | preds = model(image, text[:, :-1]) 254 | target = text[:, 1:] 255 | cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 256 | 257 | model.zero_grad() 258 | cost.backward() 259 | torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) 260 | optimizer.step() 261 | 262 | 263 | if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 264 | elapsed_time = time.time() - start_time 265 | # for log 266 | with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log: 267 | model.eval() 268 | with torch.no_grad(): 269 | valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation( 270 | model, criterion, valid_loader, converter, opt) 271 | 272 | model.train() 273 | 274 | # training loss and validation loss 275 | loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}' 276 | loss_avg.reset() 277 | 278 | current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}' 279 | # keep best accuracy model (on valid dataset) 280 | if current_accuracy > best_accuracy: 281 | best_accuracy = current_accuracy 282 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth') 283 | if current_norm_ED > best_norm_ED: 284 | best_norm_ED = current_norm_ED 285 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth') 286 | 287 | best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}' 288 | 289 | loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}' 290 | print(loss_model_log) 291 | log.write(loss_model_log + '\n') 292 | 293 | # show some predicted results 294 | dashed_line = '-' * 80 295 | head = 
f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' 296 | predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' 297 | for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]): 298 | if 'Attn' in opt.Prediction: 299 | gt = gt[:gt.find('[s]')] 300 | pred = pred[:pred.find('[s]')] 301 | 302 | predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n' 303 | predicted_result_log += f'{dashed_line}' 304 | print(predicted_result_log) 305 | log.write(predicted_result_log + '\n') 306 | 307 | # save model per 1e+5 iter. 308 | if (iteration + 1) % 1e+5 == 0: 309 | torch.save( 310 | model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth') 311 | 312 | if (iteration + 1) == opt.num_iter: 313 | print('end the training') 314 | sys.exit() 315 | iteration += 1 316 | 317 | 318 | if __name__ == '__main__': 319 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 320 | parser = argparse.ArgumentParser() 321 | parser.add_argument('--exp_name', help='Where to store logs and models') 322 | parser.add_argument('--train_data', required=True, help='path to training dataset') 323 | parser.add_argument('--valid_data', required=True, help='path to validation dataset') 324 | parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting') 325 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=3) 326 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 327 | parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for') 328 | parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation') 329 | parser.add_argument('--saved_model', default='', help="path to model to continue training") 330 | parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning') 331 | parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)') 332 | parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta') 333 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9') 334 | parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95') 335 | parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8') 336 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. 
default=5') 337 | parser.add_argument('--baiduCTC', action='store_true', help='for data_filtering_off mode') 338 | """ Data processing """ 339 | parser.add_argument('--select_data', type=str, default='MJ-ST', 340 | help='select training data (default is MJ-ST, which means MJ and ST used as training data)') 341 | parser.add_argument('--batch_ratio', type=str, default='1', 342 | help='assign ratio for each selected data in the batch 0.5-0.5') 343 | parser.add_argument('--total_data_usage_ratio', type=str, default='1.0', 344 | help='total data usage ratio, this ratio is multiplied to total number of data.') 345 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 346 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 347 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 348 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 349 | parser.add_argument('--character', type=str, 350 | default='0123456789一下东丰云亚亨亮会佑佛佳俊信八关兴凯利力勤华南县发号君和圩城壹大天宁宇安宏宾富封州工帆平年广庆建开德恒惠成新昌明机权来柳桂梧森横永江沙河油泰泽洋浮海润清港湖源滨珠田盈益盛石祥福程穗粤翔肇航良英藤行衡西诚诺谢谷货贵辉达运远途通都金长阳雄韶顺颜风飞香鸿鼎龙', help='character label 0123456789abcdefghijklmnopqrstuvwxyz') 351 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 352 | parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') 353 | parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') 354 | """ Model Architecture """ 355 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS') 356 | parser.add_argument('--FeatureExtraction', type=str, required=True, 357 | help='FeatureExtraction stage. VGG|RCNN|ResNet') 358 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') 359 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. 
CTC|Attn') 360 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 361 | parser.add_argument('--input_channel', type=int, default=1, 362 | help='the number of input channel of Feature extractor') 363 | parser.add_argument('--output_channel', type=int, default=512, 364 | help='the number of output channel of Feature extractor') 365 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 366 | parser.add_argument('--source_num', type=int, default=4, help='the number of source domain dataset') 367 | parser.add_argument('--inner_loop_iter', type=int, default=2, help='the iteration of inner loop') 368 | parser.add_argument('--fix_dataset_num', type=int, default=-1, help='fix the number of imgs per dataset') 369 | parser.add_argument('--pseudo_dataset_num', type=int, default=-1, help='fix the number of imgs per dataset') 370 | parser.add_argument('--expsuffix', type=str, default='') 371 | 372 | opt = parser.parse_args() 373 | 374 | if not opt.exp_name: 375 | opt.exp_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 376 | opt.exp_name += f'-Seed{opt.manualSeed}' 377 | opt.exp_name += f'-Inneriter{opt.inner_loop_iter}' 378 | assert opt.expsuffix, 'You should specify the exp suffix' 379 | opt.exp_name += f'-{opt.expsuffix}' 380 | 381 | os.makedirs(f'./saved_models/{opt.exp_name}', exist_ok=True) 382 | 383 | """ vocab / character number configuration """ 384 | if opt.sensitive: 385 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 386 | 387 | opt.character = u"啊阿埃挨哎唉哀皑癌蔼矮艾碍爱隘鞍氨安俺按暗岸胺案肮昂盎凹敖熬翱袄傲奥懊澳芭捌扒叭吧笆八疤巴拔跋靶把耙坝霸罢爸白柏百摆佰败拜稗斑班搬扳般颁板版扮拌伴瓣半办绊邦帮梆榜膀绑棒磅蚌镑傍谤苞胞包褒剥薄雹保堡饱宝抱报暴豹鲍爆杯碑悲卑北辈背贝钡倍狈备惫焙被奔苯本笨崩绷甭泵蹦迸逼鼻比鄙笔彼碧蓖蔽毕毙毖币庇痹闭敝弊必辟壁臂避陛鞭边编贬扁便变卞辨辩辫遍标彪膘表鳖憋别瘪彬斌濒滨宾摈兵冰柄丙秉饼炳病并玻菠播拨钵波博勃搏铂箔伯帛舶脖膊渤泊驳捕卜哺补埠不布步簿部怖擦猜裁材才财睬踩采彩菜蔡餐参蚕残惭惨灿苍舱仓沧藏操糙槽曹草厕策侧册测层蹭插叉茬茶查碴搽察岔差诧拆柴豺搀掺蝉馋谗缠铲产阐颤昌猖场尝常长偿肠厂敞畅唱倡超抄钞朝嘲潮巢吵炒车扯撤掣彻澈郴臣辰尘晨忱沉陈趁衬撑称城橙成呈乘程惩澄诚承逞骋秤吃痴持匙池迟弛驰耻齿侈尺赤翅斥炽充冲虫崇宠抽酬畴踌稠愁筹仇绸瞅丑臭初出橱厨躇锄雏滁除楚础储矗搐触处揣川穿椽传船喘串疮窗幢床闯创吹炊捶锤垂春椿醇唇淳纯蠢戳绰疵茨磁雌辞慈瓷词此刺赐次聪葱囱匆从丛凑粗醋簇促蹿篡窜摧崔催脆瘁粹淬翠村存寸磋撮搓措挫错搭达答瘩打大呆歹傣戴带殆代贷袋待逮怠耽担丹单郸掸胆旦氮但惮淡诞弹蛋当挡党荡档刀捣蹈倒岛祷导到稻悼道盗德得的蹬灯登等瞪凳邓堤低滴迪敌笛狄涤翟嫡抵底地蒂第帝弟递缔颠掂滇碘点典靛垫电佃甸店惦奠淀殿碉叼雕凋刁掉吊钓调跌爹碟蝶迭谍叠丁盯叮钉顶鼎锭定订丢东冬董懂动栋侗恫冻洞兜抖斗陡豆逗痘都督毒犊独读堵睹赌杜镀肚度渡妒端短锻段断缎堆兑队对墩吨蹲敦顿囤钝盾遁掇哆多夺垛躲朵跺舵剁惰堕蛾峨鹅俄额讹娥恶厄扼遏鄂饿恩而儿耳尔饵洱二贰发罚筏伐乏阀法珐藩帆番翻樊矾钒繁凡烦反返范贩犯饭泛坊芳方肪房防妨仿访纺放菲非啡飞肥匪诽吠肺废沸费芬酚吩氛分纷坟焚汾粉奋份忿愤粪丰封枫蜂峰锋风疯烽逢冯缝讽奉凤佛否夫敷肤孵扶拂辐幅氟符伏俘服浮涪福袱弗甫抚辅俯釜斧脯腑府腐赴副覆赋复傅付阜父腹负富讣附妇缚咐噶嘎该改概钙盖溉干甘杆柑竿肝赶感秆敢赣冈刚钢缸肛纲岗港杠篙皋高膏羔糕搞镐稿告哥歌搁戈鸽胳疙割革葛格蛤阁隔铬个各给根跟耕更庚羹埂耿梗工攻功恭龚供躬公宫弓巩汞拱贡共钩勾沟苟狗垢构购够辜菇咕箍估沽孤姑鼓古蛊骨谷股故顾固雇刮瓜剐寡挂褂乖拐怪棺关官冠观管馆罐惯灌贯光广逛瑰规圭硅归龟闺轨鬼诡癸桂柜跪贵刽辊滚棍锅郭国果裹过哈骸孩海氦亥害骇酣憨邯韩含涵寒函喊罕翰撼捍旱憾悍焊汗汉夯杭航壕嚎豪毫郝好耗号浩呵喝荷菏核禾和何合盒貉阂河涸赫褐鹤贺嘿黑痕很狠恨哼亨横衡恒轰哄烘虹鸿洪宏弘红喉侯猴吼厚候后呼乎忽瑚壶葫胡蝴狐糊湖弧虎唬护互沪户花哗华猾滑画划化话槐徊怀淮坏欢环桓还缓换患唤痪豢焕涣宦幻荒慌黄磺蝗簧皇凰惶煌晃幌恍谎灰挥辉徽恢蛔回毁悔慧卉惠晦贿秽会烩汇讳诲绘荤昏婚魂浑混豁活伙火获或惑霍货祸击圾基机畸稽积箕肌饥迹激讥鸡姬绩缉吉极棘辑籍集及急疾汲即嫉级挤几脊己蓟技冀季伎祭剂悸济寄寂计记既忌际妓继纪嘉枷夹佳家加荚颊贾甲钾假稼价架驾嫁歼监坚尖笺间煎兼肩艰奸缄茧检柬碱碱拣捡简俭剪减荐槛鉴践贱见键箭件健舰剑饯渐溅涧建僵姜将浆江疆蒋桨奖讲匠酱降蕉椒礁焦胶交郊浇骄娇嚼搅铰矫侥脚狡角饺缴绞剿教酵轿较叫窖揭接皆秸街阶截劫节桔杰捷睫竭洁结解姐戒藉芥界借介疥诫届巾筋斤金今津襟紧锦仅谨进靳晋禁近烬浸尽劲荆兢茎睛晶鲸京惊精粳经井警景颈静境敬镜径痉靖竟竞净炯窘揪究纠玖韭久灸九酒厩救旧臼舅咎就疚鞠拘狙疽居驹菊局咀矩举沮聚拒据巨具距踞锯俱句惧炬剧捐鹃娟倦眷卷绢撅攫抉掘倔爵觉决诀绝均菌钧军君峻俊竣浚郡骏喀咖卡咯开揩楷凯慨刊堪勘坎砍看康慷糠扛抗亢炕考拷烤靠坷苛柯棵磕颗科壳咳可渴克刻客课肯啃垦恳坑吭空恐孔控抠口扣寇枯哭窟苦酷库裤夸垮挎跨胯块筷侩快宽款匡筐狂框矿眶旷况亏盔岿窥葵奎魁傀馈愧溃坤昆捆困括扩廓阔垃拉喇蜡腊辣啦莱来赖蓝婪栏拦篮阑兰澜谰揽览懒缆烂滥琅榔狼廊郎朗浪捞劳牢老佬姥酪烙涝勒乐雷镭蕾磊累儡垒擂肋类泪棱楞冷厘梨犁黎篱狸离漓理李里鲤礼莉荔吏栗丽厉励砾历利僳例俐痢立粒沥隶力璃哩俩联莲连镰廉怜涟帘敛脸链恋炼练粮凉梁粱良两辆量晾亮谅撩聊僚疗燎寥辽潦了撂镣廖料列裂烈劣猎琳林磷霖临邻鳞淋凛赁吝拎玲菱零龄铃伶羚凌灵陵岭领另令溜琉榴硫馏留刘瘤流柳六龙聋咙笼窿隆垄拢陇楼娄搂篓漏陋芦卢颅庐炉掳卤虏鲁麓碌露路赂鹿潞禄录陆戮驴吕铝侣旅履屡缕虑氯律率滤绿峦挛孪滦卵乱掠略抡轮伦仑沦纶论萝螺罗逻锣箩骡裸落洛骆络妈麻玛码蚂马骂嘛吗埋买麦卖迈脉瞒馒蛮满蔓曼慢漫谩芒茫盲氓忙莽猫茅锚毛矛铆卯茂冒帽貌贸么玫枚梅酶霉煤没眉媒镁每美昧寐妹媚门闷们萌蒙檬盟锰猛梦孟眯醚靡糜迷谜弥米秘觅泌蜜密幂棉眠绵冕免勉娩缅面苗描瞄藐秒渺庙妙蔑灭民抿皿敏悯闽明螟鸣铭名命谬摸摹蘑模膜磨摩魔抹末莫墨默沫漠寞陌谋牟某拇牡
亩姆母墓暮幕募慕木目睦牧穆拿哪呐钠那娜纳氖乃奶耐奈南男难囊挠脑恼闹淖呢馁内嫩能妮霓倪泥尼拟你匿腻逆溺蔫拈年碾撵捻念娘酿鸟尿捏聂孽啮镊镍涅您柠狞凝宁拧泞牛扭钮纽脓浓农弄奴努怒女暖虐疟挪懦糯诺哦欧鸥殴藕呕偶沤啪趴爬帕怕琶拍排牌徘湃派攀潘盘磐盼畔判叛乓庞旁耪胖抛咆刨炮袍跑泡呸胚培裴赔陪配佩沛喷盆砰抨烹澎彭蓬棚硼篷膨朋鹏捧碰坯砒霹批披劈琵毗啤脾疲皮匹痞僻屁譬篇偏片骗飘漂瓢票撇瞥拼频贫品聘乒坪苹萍平凭瓶评屏坡泼颇婆破魄迫粕剖扑铺仆莆葡菩蒲埔朴圃普浦谱曝瀑期欺栖戚妻七凄漆柒沏其棋奇歧畦崎脐齐旗祈祁骑起岂乞企启契砌器气迄弃汽泣讫掐洽牵扦钎铅千迁签仟谦乾黔钱钳前潜遣浅谴堑嵌欠歉枪呛腔羌墙蔷强抢橇锹敲悄桥瞧乔侨巧鞘撬翘峭俏窍切茄且怯窃钦侵亲秦琴勤芹擒禽寝沁青轻氢倾卿清擎晴氰情顷请庆琼穷秋丘邱球求囚酋泅趋区蛆曲躯屈驱渠取娶龋趣去圈颧权醛泉全痊拳犬券劝缺炔瘸却鹊榷确雀裙群然燃冉染瓤壤攘嚷让饶扰绕惹热壬仁人忍韧任认刃妊纫扔仍日戎茸蓉荣融熔溶容绒冗揉柔肉茹蠕儒孺如辱乳汝入褥软阮蕊瑞锐闰润若弱撒洒萨腮鳃塞赛三叁伞散桑嗓丧搔骚扫嫂瑟色涩森僧莎砂杀刹沙纱傻啥煞筛晒珊苫杉山删煽衫闪陕擅赡膳善汕扇缮墒伤商赏晌上尚裳梢捎稍烧芍勺韶少哨邵绍奢赊蛇舌舍赦摄射慑涉社设砷申呻伸身深娠绅神沈审婶甚肾慎渗声生甥牲升绳省盛剩胜圣师失狮施湿诗尸虱十石拾时什食蚀实识史矢使屎驶始式示士世柿事拭誓逝势是嗜噬适仕侍释饰氏市恃室视试收手首守寿授售受瘦兽蔬枢梳殊抒输叔舒淑疏书赎孰熟薯暑曙署蜀黍鼠属术述树束戍竖墅庶数漱恕刷耍摔衰甩帅栓拴霜双爽谁水睡税吮瞬顺舜说硕朔烁斯撕嘶思私司丝死肆寺嗣四伺似饲巳松耸怂颂送宋讼诵搜艘擞嗽苏酥俗素速粟僳塑溯宿诉肃酸蒜算虽隋随绥髓碎岁穗遂隧祟孙损笋蓑梭唆缩琐索锁所塌他它她塔獭挞蹋踏胎苔抬台泰酞太态汰坍摊贪瘫滩坛檀痰潭谭谈坦毯袒碳探叹炭汤塘搪堂棠膛唐糖倘躺淌趟烫掏涛滔绦萄桃逃淘陶讨套特藤腾疼誊梯剔踢锑提题蹄啼体替嚏惕涕剃屉天添填田甜恬舔腆挑条迢眺跳贴铁帖厅听烃汀廷停亭庭艇通桐酮瞳同铜彤童桶捅筒统痛偷投头透凸秃突图徒途涂屠土吐兔湍团推颓腿蜕褪退吞屯臀拖托脱鸵陀驮驼椭妥拓唾挖哇蛙洼娃瓦袜歪外豌弯湾玩顽丸烷完碗挽晚皖惋宛婉万腕汪王亡枉网往旺望忘妄威巍微危圩韦违桅围唯惟为潍维苇萎委伟伪尾纬未蔚味畏胃喂魏位渭谓尉慰卫瘟温蚊文闻纹吻稳紊问嗡翁瓮挝蜗涡窝我斡卧握沃巫呜钨乌污诬屋无芜梧吾吴毋武五捂午舞伍侮坞戊雾晤物勿务悟误昔熙析西硒矽晰嘻吸锡牺稀息希悉膝夕惜熄烯溪汐犀檄袭席习媳喜铣洗系隙戏细瞎虾匣霞辖暇峡侠狭下厦夏吓掀锨先仙鲜纤咸贤衔舷闲涎弦嫌显险现献县腺馅羡宪陷限线相厢镶香箱襄湘乡翔祥详想响享项巷橡像向象萧硝霄削哮嚣销消宵淆晓小孝校肖啸笑效楔些歇蝎鞋协挟携邪斜胁谐写械卸蟹懈泄泻谢屑薪芯锌欣辛新忻心信衅星腥猩惺兴刑型形邢行醒幸杏性姓兄凶胸匈汹雄熊休修羞朽嗅锈秀袖绣墟戌需虚嘘须徐许蓄酗叙旭序畜恤絮婿绪续轩喧宣悬旋玄选癣眩绚靴薛学穴雪血勋熏循旬询寻驯巡殉汛训讯逊迅压押鸦鸭呀丫芽牙蚜崖衙涯雅哑亚讶焉咽阉烟淹盐严研蜒岩延言颜阎炎沿奄掩眼衍演艳堰燕厌砚雁唁彦焰宴谚验殃央鸯秧杨扬佯疡羊洋阳氧仰痒养样漾邀腰妖瑶摇尧遥窑谣姚咬舀药要耀椰噎耶爷野冶也页掖业叶曳腋夜液一壹医揖铱依伊衣颐夷遗移仪胰疑沂宜姨彝椅蚁倚已乙矣以艺抑易邑屹亿役臆逸肄疫亦裔意毅忆义益溢诣议谊译异翼翌绎茵荫因殷音阴姻吟银淫寅饮尹引隐印英樱婴鹰应缨莹萤营荧蝇迎赢盈影颖硬映哟拥佣臃痈庸雍踊蛹咏泳涌永恿勇用幽优悠忧尤由邮铀犹油游酉有友右佑釉诱又幼迂淤于盂榆虞愚舆余俞逾鱼愉渝渔隅予娱雨与屿禹宇语羽玉域芋郁吁遇喻峪御愈欲狱育誉浴寓裕预豫驭鸳渊冤元垣袁原援辕园员圆猿源缘远苑愿怨院曰约越跃钥岳粤月悦阅耘云郧匀陨允运蕴酝晕韵孕匝砸杂栽哉灾宰载再在咱攒暂赞赃脏葬遭糟凿藻枣早澡蚤躁噪造皂灶燥责择则泽贼怎增憎曾赠扎喳渣札轧铡闸眨栅榨咋乍炸诈摘斋宅窄债寨瞻毡詹粘沾盏斩辗崭展蘸栈占战站湛绽樟章彰漳张掌涨杖丈帐账仗胀瘴障招昭找沼赵照罩兆肇召遮折哲蛰辙者锗蔗这浙珍斟真甄砧臻贞针侦枕疹诊震振镇阵蒸挣睁征狰争怔整拯正政帧症郑证芝枝支吱蜘知肢脂汁之织职直植殖执值侄址指止趾只旨纸志挚掷至致置帜峙制智秩稚质炙痔滞治窒中盅忠钟衷终种肿重仲众舟周州洲诌粥轴肘帚咒皱宙昼骤珠株蛛朱猪诸诛逐竹烛煮拄瞩嘱主著柱助蛀贮铸筑住注祝驻抓爪拽专砖转撰赚篆桩庄装妆撞壮状椎锥追赘坠缀谆准捉拙卓桌琢茁酌啄着灼浊兹咨资姿滋淄孜紫仔籽滓子自渍字鬃棕踪宗综总纵邹走奏揍租足卒族祖诅阻组钻纂嘴醉最罪尊遵昨左佐柞做作坐座1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 388 | """ Seed and GPU setting """ 389 | random.seed(opt.manualSeed) 390 | np.random.seed(opt.manualSeed) 391 | torch.manual_seed(opt.manualSeed) 392 | torch.cuda.manual_seed(opt.manualSeed) 393 | 394 | cudnn.benchmark = False 395 | cudnn.deterministic = True 396 | opt.num_gpu = torch.cuda.device_count() 397 | print('device count', opt.num_gpu) 398 | opt.num_gpu = 1 399 | 400 | if opt.num_gpu > 1: 401 | print('------ Use multi-GPU setting ------') 402 | print('if you stuck too long time with multi-GPU setting, try to set --workers 0') 403 | # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1 404 | opt.workers = opt.workers * opt.num_gpu 405 | opt.batch_size = opt.batch_size * opt.num_gpu 406 | 407 | """ previous version 408 | print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size) 409 | opt.batch_size = opt.batch_size * opt.num_gpu 410 | print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.') 411 | If you dont care about it, just commnet out these line.) 
412 | opt.num_iter = int(opt.num_iter / opt.num_gpu) 413 | """ 414 | 415 | train(opt) 416 | -------------------------------------------------------------------------------- /self_training.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import os 3 | import sys 4 | import time 5 | import random 6 | import string 7 | import argparse 8 | import visdom 9 | 10 | import torch 11 | import torch.backends.cudnn as cudnn 12 | import torch.nn.init as init 13 | import torch.optim as optim 14 | import torch.utils.data 15 | import numpy as np 16 | 17 | from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager 18 | from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset, SelfTrainingDataset, self_training_collate 19 | from model import Model 20 | from test import validation 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | 24 | def create_vis_plot(_xlabel, _ylabel, _title, _legend, viz): 25 | return viz.line( 26 | X=torch.zeros((1,)).cpu(), 27 | Y=torch.zeros((1, 1)).cpu(), 28 | opts=dict( 29 | xlabel=_xlabel, 30 | ylabel=_ylabel, 31 | title=_title, 32 | legend=_legend 33 | ) 34 | ) 35 | 36 | def update_vis_plot(iteration, loc, window1, update_type, viz): 37 | viz.line( 38 | X=torch.ones((1, 1)).cpu() * iteration, 39 | Y=torch.Tensor([loc]).unsqueeze(0).cpu(), 40 | win=window1, 41 | update=update_type 42 | ) 43 | 44 | def train(opt): 45 | 46 | """ dataset preparation """ 47 | if not opt.data_filtering_off: 48 | print('Filtering the images containing characters which are not in opt.character') 49 | print('Filtering the images whose label is longer than opt.batch_max_length') 50 | # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 51 | 52 | opt.select_data = opt.select_data.split('-') 53 | opt.batch_ratio = opt.batch_ratio.split('-') 54 | train_dataset = Batch_Balanced_Dataset(opt) 55 | 56 | log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') 57 | AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 58 | AlignCollate_test = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 59 | test_dataset, test_dataset_log = hierarchical_dataset(root=opt.test_data, opt=opt) 60 | test_loader = torch.utils.data.DataLoader( 61 | test_dataset, batch_size=opt.batch_size, 62 | shuffle=True, # 'True' to check training progress with validation function. 63 | num_workers=int(opt.workers), 64 | collate_fn=AlignCollate_test, pin_memory=True) 65 | valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt, pseudo=True) 66 | valid_loader = torch.utils.data.DataLoader( 67 | valid_dataset, batch_size=opt.batch_size, 68 | shuffle=True, # 'True' to check training progress with validation function. 
69 | num_workers=int(opt.workers), 70 | collate_fn=AlignCollate_valid, pin_memory=True) 71 | log.write(valid_dataset_log) 72 | print('-' * 80) 73 | log.write('-' * 80 + '\n') 74 | log.close() 75 | 76 | 77 | """ model configuration """ 78 | if 'CTC' in opt.Prediction: 79 | if opt.baiduCTC: 80 | converter = CTCLabelConverterForBaiduWarpctc(opt.character) 81 | else: 82 | converter = CTCLabelConverter(opt.character) 83 | else: 84 | converter = AttnLabelConverter(opt.character) 85 | opt.num_class = len(converter.character) 86 | 87 | if opt.rgb: 88 | opt.input_channel = 3 89 | model = Model(opt) 90 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 91 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 92 | opt.SequenceModeling, opt.Prediction) 93 | 94 | # weight initialization 95 | for name, param in model.named_parameters(): 96 | if 'localization_fc2' in name: 97 | print(f'Skip {name} as it is already initialized') 98 | continue 99 | try: 100 | if 'bias' in name: 101 | init.constant_(param, 0.0) 102 | elif 'weight' in name: 103 | init.kaiming_normal_(param) 104 | except Exception as e: # for batchnorm. 105 | if 'weight' in name: 106 | param.data.fill_(1) 107 | continue 108 | 109 | # data parallel for multi-GPU 110 | model = torch.nn.DataParallel(model).to(device) 111 | model.train() 112 | if opt.saved_model != '': 113 | print(f'loading pretrained model from {opt.saved_model}') 114 | if opt.FT: 115 | model.load_state_dict(torch.load(opt.saved_model), strict=False) 116 | else: 117 | model.load_state_dict(torch.load(opt.saved_model)) 118 | print("Model:") 119 | print(model) 120 | 121 | """ setup loss """ 122 | if 'CTC' in opt.Prediction: 123 | if opt.baiduCTC: 124 | # need to install warpctc. see our guideline. 
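            # warpctc_pytorch is an external package; it is only imported when --baiduCTC is passed, otherwise the torch.nn.CTCLoss branch below is used.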
125 | from warpctc_pytorch import CTCLoss 126 | criterion = CTCLoss() 127 | else: 128 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 129 | else: 130 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 131 | # loss averager 132 | loss_avg = Averager() 133 | 134 | # filter that only require gradient decent 135 | filtered_parameters = [] 136 | params_num = [] 137 | 138 | for p in filter(lambda p: p.requires_grad, model.parameters()): 139 | filtered_parameters.append(p) 140 | params_num.append(np.prod(p.size())) 141 | print('Trainable params num : ', sum(params_num)) 142 | 143 | # setup optimizer 144 | if opt.adam: 145 | optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999)) 146 | else: 147 | optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps) 148 | print("Optimizer:") 149 | print(optimizer) 150 | 151 | """ final options """ 152 | # print(opt) 153 | with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file: 154 | opt_log = '------------ Options -------------\n' 155 | args = vars(opt) 156 | for k, v in args.items(): 157 | opt_log += f'{str(k)}: {str(v)}\n' 158 | opt_log += '---------------------------------------\n' 159 | print(opt_log) 160 | opt_file.write(opt_log) 161 | 162 | """ start training """ 163 | start_iter = 0 164 | if opt.saved_model != '': 165 | try: 166 | start_iter = int(opt.saved_model.split('_')[-1].split('.')[0]) 167 | print(f'continue to train, start_iter: {start_iter}') 168 | except: 169 | pass 170 | 171 | start_time = time.time() 172 | best_accuracy = -1 173 | best_norm_ED = -1 174 | test_best_accuracy = -1 175 | test_best_norm_ED = -1 176 | iteration = start_iter 177 | 178 | while(True): 179 | # train part 180 | if train_dataset.has_pseudo_label_dataset(): 181 | meta_target_index = opt.source_num 182 | image_tensors, labels = train_dataset.get_meta_test_batch(meta_target_index) 183 | else: 184 | image_tensors, labels = train_dataset.get_batch() 185 | image = image_tensors.to(device) 186 | text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) 187 | batch_size = image.size(0) 188 | 189 | if 'CTC' in opt.Prediction: 190 | preds = model(image, text) 191 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 192 | if opt.baiduCTC: 193 | preds = preds.permute(1, 0, 2) # to use CTCLoss format 194 | cost = criterion(preds, text, preds_size, length) / batch_size 195 | else: 196 | preds = preds.log_softmax(2).permute(1, 0, 2) 197 | cost = criterion(preds, text, preds_size, length) 198 | 199 | else: 200 | preds = model(image, text[:, :-1]) # align with Attention.forward 201 | target = text[:, 1:] # without [GO] Symbol 202 | cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 203 | 204 | model.zero_grad() 205 | cost.backward() 206 | torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) 207 | optimizer.step() 208 | 209 | loss_avg.add(cost) 210 | # validation part 211 | if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 212 | elapsed_time = time.time() - start_time 213 | # for log 214 | with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log: 215 | model.eval() 216 | with torch.no_grad(): 217 | valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data, all_imgs, all_labels, all_confidences, all_gts 
= validation( 218 | model, criterion, valid_loader, converter, opt, self_training=True) 219 | test_loss, test_current_accuracy, test_current_norm_ED, test_preds, test_confidence_score, test_labels, test_infer_time, test_length_of_data = validation( 220 | model, criterion, test_loader, converter, opt) 221 | model.train() 222 | print(len(all_imgs), len(all_labels), len(all_confidences)) 223 | print(all_imgs[0].shape) 224 | # training loss and validation loss 225 | loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}' 226 | current_test_log = f'{"Test_Current_accuracy":17s}: {test_current_accuracy:0.3f}, {"Test_Current_norm_ED":17s}: {test_current_norm_ED:0.2f}' 227 | loss_avg.reset() 228 | 229 | current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}' 230 | 231 | # keep best accuracy model (on valid dataset) 232 | if current_accuracy > best_accuracy: 233 | best_accuracy = current_accuracy 234 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth') 235 | if current_norm_ED > best_norm_ED: 236 | best_norm_ED = current_norm_ED 237 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth') 238 | if test_current_accuracy > test_best_accuracy: 239 | test_best_accuracy = test_current_accuracy 240 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/test_best_accuracy.pth') 241 | if test_current_norm_ED > test_best_norm_ED: 242 | test_best_norm_ED = test_current_norm_ED 243 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/test_best_norm_ED.pth') 244 | best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}' 245 | test_best_model_log = f'{"Test_Best_accuracy":17s}: {test_best_accuracy:0.3f}, {"Test_Best_norm_ED":17s}: {test_best_norm_ED:0.2f}' 246 | loss_model_log = f'{loss_log}\n{current_model_log}\n{current_test_log}\n{best_model_log}\n{test_best_model_log}' 247 | print(loss_model_log) 248 | log.write(loss_model_log + '\n') 249 | 250 | # show some predicted results 251 | dashed_line = '-' * 80 252 | head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' 253 | predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' 254 | # print(len(valid_dataset), len(labels)) 255 | 256 | self_training_imgs = [] 257 | self_training_labels = [] 258 | right_pseudo_label_counts = 0 259 | 260 | for conf, img, label, gt in zip(all_confidences, all_imgs, all_labels, all_gts): 261 | if conf < opt.pseudo_threshold: continue 262 | label = label[: label.find('[s]')] 263 | self_training_imgs.append(img) 264 | self_training_labels.append(label) 265 | if gt == label: right_pseudo_label_counts += 1 266 | 267 | self_training_dataset_length = len(self_training_imgs) 268 | if self_training_dataset_length > 0 and current_accuracy >= opt.warmup_threshold: 269 | pseudo_log = 'Generating Self-Training Dataset, contains {} imgs, {} of them have right labels'.format(self_training_dataset_length, right_pseudo_label_counts) 270 | print(pseudo_log) 271 | log.write(pseudo_log + '\n') 272 | self_training_dataset = SelfTrainingDataset(self_training_imgs, self_training_labels) 273 | train_dataset.add_pseudo_label_dataset(self_training_dataset, opt) 274 | 275 | 276 | for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]): 277 | if 'Attn' in opt.Prediction: 278 | gt = gt[:gt.find('[s]')] 279 | pred = 
pred[:pred.find('[s]')] 280 | 281 | predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n' 282 | predicted_result_log += f'{dashed_line}' 283 | print(predicted_result_log) 284 | log.write(predicted_result_log + '\n') 285 | 286 | # save model per 1e+5 iter. 287 | if (iteration + 1) % 1e+5 == 0: 288 | torch.save( 289 | model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth') 290 | 291 | if (iteration + 1) == opt.num_iter: 292 | print('end the training') 293 | sys.exit() 294 | iteration += 1 295 | 296 | 297 | if __name__ == '__main__': 298 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 299 | parser = argparse.ArgumentParser() 300 | parser.add_argument('--exp_name', help='Where to store logs and models') 301 | parser.add_argument('--train_data', required=True, help='path to training dataset') 302 | parser.add_argument('--valid_data', required=True, help='path to validation dataset') 303 | parser.add_argument('--test_data', required=True, help='path to test dataset') 304 | parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting') 305 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2) 306 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 307 | parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for') 308 | parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation') 309 | parser.add_argument('--saved_model', default='', help="path to model to continue training") 310 | parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning') 311 | parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)') 312 | parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta') 313 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9') 314 | parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95') 315 | parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8') 316 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. 
default=5') 317 | parser.add_argument('--baiduCTC', action='store_true', help='use Baidu warpctc instead of torch.nn.CTCLoss for the CTC criterion') 318 | """ Data processing """ 319 | parser.add_argument('--select_data', type=str, default='MJ-ST', 320 | help='select training data (default is MJ-ST, which means MJ and ST used as training data)') 321 | parser.add_argument('--batch_ratio', type=str, default='1', 322 | help='assign ratio for each selected data in the batch, e.g. 0.5-0.5') 323 | parser.add_argument('--total_data_usage_ratio', type=str, default='1.0', 324 | help='total data usage ratio; this ratio is multiplied by the total number of data.') 325 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 326 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 327 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 328 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 329 | parser.add_argument('--character', type=str, 330 | default='0123456789一下东丰云亚亨亮会佑佛佳俊信八关兴凯利力勤华南县发号君和圩城壹大天宁宇安宏宾富封州工帆平年广庆建开德恒惠成新昌明机权来柳桂梧森横永江沙河油泰泽洋浮海润清港湖源滨珠田盈益盛石祥福程穗粤翔肇航良英藤行衡西诚诺谢谷货贵辉达运远途通都金长阳雄韶顺颜风飞香鸿鼎龙', help='character label 0123456789abcdefghijklmnopqrstuvwxyz') 331 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 332 | parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') 333 | parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') 334 | """ Model Architecture """ 335 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS') 336 | parser.add_argument('--FeatureExtraction', type=str, required=True, 337 | help='FeatureExtraction stage. VGG|RCNN|ResNet') 338 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') 339 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. 
CTC|Attn') 340 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 341 | parser.add_argument('--input_channel', type=int, default=1, 342 | help='the number of input channels of the feature extractor') 343 | parser.add_argument('--output_channel', type=int, default=512, 344 | help='the number of output channels of the feature extractor') 345 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 346 | parser.add_argument('--fix_dataset_num', type=int, default=-1, help='fix the number of imgs per dataset') 347 | parser.add_argument('--pseudo_dataset_num', type=int, default=-1, help='fix the number of pseudo-labeled imgs per dataset') 348 | parser.add_argument('--source_num', type=int, default=4, help='the number of source domain datasets') 349 | parser.add_argument('--pseudo_threshold', type=float, default=0.98, help='confidence threshold for accepting pseudo labels') 350 | parser.add_argument('--warmup_threshold', type=float, default=27.5, help='validation accuracy (%) that must be reached before pseudo labels are generated') 351 | parser.add_argument('--expsuffix', type=str, default='') 352 | opt = parser.parse_args() 353 | 354 | if not opt.exp_name: 355 | opt.exp_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 356 | opt.exp_name += f'-Seed{opt.manualSeed}' 357 | assert opt.expsuffix, 'You should specify the exp suffix' 358 | opt.exp_name += f'-{opt.expsuffix}' 359 | # print(opt.exp_name) 360 | 361 | os.makedirs(f'./saved_models/{opt.exp_name}', exist_ok=True) 362 | 363 | """ vocab / character number configuration """ 364 | if opt.sensitive: 365 | # opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 366 | opt.character = string.printable[:-6] # same as the ASTER setting (94 characters). 367 | 368 | opt.character = 
u"啊阿埃挨哎唉哀皑癌蔼矮艾碍爱隘鞍氨安俺按暗岸胺案肮昂盎凹敖熬翱袄傲奥懊澳芭捌扒叭吧笆八疤巴拔跋靶把耙坝霸罢爸白柏百摆佰败拜稗斑班搬扳般颁板版扮拌伴瓣半办绊邦帮梆榜膀绑棒磅蚌镑傍谤苞胞包褒剥薄雹保堡饱宝抱报暴豹鲍爆杯碑悲卑北辈背贝钡倍狈备惫焙被奔苯本笨崩绷甭泵蹦迸逼鼻比鄙笔彼碧蓖蔽毕毙毖币庇痹闭敝弊必辟壁臂避陛鞭边编贬扁便变卞辨辩辫遍标彪膘表鳖憋别瘪彬斌濒滨宾摈兵冰柄丙秉饼炳病并玻菠播拨钵波博勃搏铂箔伯帛舶脖膊渤泊驳捕卜哺补埠不布步簿部怖擦猜裁材才财睬踩采彩菜蔡餐参蚕残惭惨灿苍舱仓沧藏操糙槽曹草厕策侧册测层蹭插叉茬茶查碴搽察岔差诧拆柴豺搀掺蝉馋谗缠铲产阐颤昌猖场尝常长偿肠厂敞畅唱倡超抄钞朝嘲潮巢吵炒车扯撤掣彻澈郴臣辰尘晨忱沉陈趁衬撑称城橙成呈乘程惩澄诚承逞骋秤吃痴持匙池迟弛驰耻齿侈尺赤翅斥炽充冲虫崇宠抽酬畴踌稠愁筹仇绸瞅丑臭初出橱厨躇锄雏滁除楚础储矗搐触处揣川穿椽传船喘串疮窗幢床闯创吹炊捶锤垂春椿醇唇淳纯蠢戳绰疵茨磁雌辞慈瓷词此刺赐次聪葱囱匆从丛凑粗醋簇促蹿篡窜摧崔催脆瘁粹淬翠村存寸磋撮搓措挫错搭达答瘩打大呆歹傣戴带殆代贷袋待逮怠耽担丹单郸掸胆旦氮但惮淡诞弹蛋当挡党荡档刀捣蹈倒岛祷导到稻悼道盗德得的蹬灯登等瞪凳邓堤低滴迪敌笛狄涤翟嫡抵底地蒂第帝弟递缔颠掂滇碘点典靛垫电佃甸店惦奠淀殿碉叼雕凋刁掉吊钓调跌爹碟蝶迭谍叠丁盯叮钉顶鼎锭定订丢东冬董懂动栋侗恫冻洞兜抖斗陡豆逗痘都督毒犊独读堵睹赌杜镀肚度渡妒端短锻段断缎堆兑队对墩吨蹲敦顿囤钝盾遁掇哆多夺垛躲朵跺舵剁惰堕蛾峨鹅俄额讹娥恶厄扼遏鄂饿恩而儿耳尔饵洱二贰发罚筏伐乏阀法珐藩帆番翻樊矾钒繁凡烦反返范贩犯饭泛坊芳方肪房防妨仿访纺放菲非啡飞肥匪诽吠肺废沸费芬酚吩氛分纷坟焚汾粉奋份忿愤粪丰封枫蜂峰锋风疯烽逢冯缝讽奉凤佛否夫敷肤孵扶拂辐幅氟符伏俘服浮涪福袱弗甫抚辅俯釜斧脯腑府腐赴副覆赋复傅付阜父腹负富讣附妇缚咐噶嘎该改概钙盖溉干甘杆柑竿肝赶感秆敢赣冈刚钢缸肛纲岗港杠篙皋高膏羔糕搞镐稿告哥歌搁戈鸽胳疙割革葛格蛤阁隔铬个各给根跟耕更庚羹埂耿梗工攻功恭龚供躬公宫弓巩汞拱贡共钩勾沟苟狗垢构购够辜菇咕箍估沽孤姑鼓古蛊骨谷股故顾固雇刮瓜剐寡挂褂乖拐怪棺关官冠观管馆罐惯灌贯光广逛瑰规圭硅归龟闺轨鬼诡癸桂柜跪贵刽辊滚棍锅郭国果裹过哈骸孩海氦亥害骇酣憨邯韩含涵寒函喊罕翰撼捍旱憾悍焊汗汉夯杭航壕嚎豪毫郝好耗号浩呵喝荷菏核禾和何合盒貉阂河涸赫褐鹤贺嘿黑痕很狠恨哼亨横衡恒轰哄烘虹鸿洪宏弘红喉侯猴吼厚候后呼乎忽瑚壶葫胡蝴狐糊湖弧虎唬护互沪户花哗华猾滑画划化话槐徊怀淮坏欢环桓还缓换患唤痪豢焕涣宦幻荒慌黄磺蝗簧皇凰惶煌晃幌恍谎灰挥辉徽恢蛔回毁悔慧卉惠晦贿秽会烩汇讳诲绘荤昏婚魂浑混豁活伙火获或惑霍货祸击圾基机畸稽积箕肌饥迹激讥鸡姬绩缉吉极棘辑籍集及急疾汲即嫉级挤几脊己蓟技冀季伎祭剂悸济寄寂计记既忌际妓继纪嘉枷夹佳家加荚颊贾甲钾假稼价架驾嫁歼监坚尖笺间煎兼肩艰奸缄茧检柬碱碱拣捡简俭剪减荐槛鉴践贱见键箭件健舰剑饯渐溅涧建僵姜将浆江疆蒋桨奖讲匠酱降蕉椒礁焦胶交郊浇骄娇嚼搅铰矫侥脚狡角饺缴绞剿教酵轿较叫窖揭接皆秸街阶截劫节桔杰捷睫竭洁结解姐戒藉芥界借介疥诫届巾筋斤金今津襟紧锦仅谨进靳晋禁近烬浸尽劲荆兢茎睛晶鲸京惊精粳经井警景颈静境敬镜径痉靖竟竞净炯窘揪究纠玖韭久灸九酒厩救旧臼舅咎就疚鞠拘狙疽居驹菊局咀矩举沮聚拒据巨具距踞锯俱句惧炬剧捐鹃娟倦眷卷绢撅攫抉掘倔爵觉决诀绝均菌钧军君峻俊竣浚郡骏喀咖卡咯开揩楷凯慨刊堪勘坎砍看康慷糠扛抗亢炕考拷烤靠坷苛柯棵磕颗科壳咳可渴克刻客课肯啃垦恳坑吭空恐孔控抠口扣寇枯哭窟苦酷库裤夸垮挎跨胯块筷侩快宽款匡筐狂框矿眶旷况亏盔岿窥葵奎魁傀馈愧溃坤昆捆困括扩廓阔垃拉喇蜡腊辣啦莱来赖蓝婪栏拦篮阑兰澜谰揽览懒缆烂滥琅榔狼廊郎朗浪捞劳牢老佬姥酪烙涝勒乐雷镭蕾磊累儡垒擂肋类泪棱楞冷厘梨犁黎篱狸离漓理李里鲤礼莉荔吏栗丽厉励砾历利僳例俐痢立粒沥隶力璃哩俩联莲连镰廉怜涟帘敛脸链恋炼练粮凉梁粱良两辆量晾亮谅撩聊僚疗燎寥辽潦了撂镣廖料列裂烈劣猎琳林磷霖临邻鳞淋凛赁吝拎玲菱零龄铃伶羚凌灵陵岭领另令溜琉榴硫馏留刘瘤流柳六龙聋咙笼窿隆垄拢陇楼娄搂篓漏陋芦卢颅庐炉掳卤虏鲁麓碌露路赂鹿潞禄录陆戮驴吕铝侣旅履屡缕虑氯律率滤绿峦挛孪滦卵乱掠略抡轮伦仑沦纶论萝螺罗逻锣箩骡裸落洛骆络妈麻玛码蚂马骂嘛吗埋买麦卖迈脉瞒馒蛮满蔓曼慢漫谩芒茫盲氓忙莽猫茅锚毛矛铆卯茂冒帽貌贸么玫枚梅酶霉煤没眉媒镁每美昧寐妹媚门闷们萌蒙檬盟锰猛梦孟眯醚靡糜迷谜弥米秘觅泌蜜密幂棉眠绵冕免勉娩缅面苗描瞄藐秒渺庙妙蔑灭民抿皿敏悯闽明螟鸣铭名命谬摸摹蘑模膜磨摩魔抹末莫墨默沫漠寞陌谋牟某拇牡亩姆母墓暮幕募慕木目睦牧穆拿哪呐钠那娜纳氖乃奶耐奈南男难囊挠脑恼闹淖呢馁内嫩能妮霓倪泥尼拟你匿腻逆溺蔫拈年碾撵捻念娘酿鸟尿捏聂孽啮镊镍涅您柠狞凝宁拧泞牛扭钮纽脓浓农弄奴努怒女暖虐疟挪懦糯诺哦欧鸥殴藕呕偶沤啪趴爬帕怕琶拍排牌徘湃派攀潘盘磐盼畔判叛乓庞旁耪胖抛咆刨炮袍跑泡呸胚培裴赔陪配佩沛喷盆砰抨烹澎彭蓬棚硼篷膨朋鹏捧碰坯砒霹批披劈琵毗啤脾疲皮匹痞僻屁譬篇偏片骗飘漂瓢票撇瞥拼频贫品聘乒坪苹萍平凭瓶评屏坡泼颇婆破魄迫粕剖扑铺仆莆葡菩蒲埔朴圃普浦谱曝瀑期欺栖戚妻七凄漆柒沏其棋奇歧畦崎脐齐旗祈祁骑起岂乞企启契砌器气迄弃汽泣讫掐洽牵扦钎铅千迁签仟谦乾黔钱钳前潜遣浅谴堑嵌欠歉枪呛腔羌墙蔷强抢橇锹敲悄桥瞧乔侨巧鞘撬翘峭俏窍切茄且怯窃钦侵亲秦琴勤芹擒禽寝沁青轻氢倾卿清擎晴氰情顷请庆琼穷秋丘邱球求囚酋泅趋区蛆曲躯屈驱渠取娶龋趣去圈颧权醛泉全痊拳犬券劝缺炔瘸却鹊榷确雀裙群然燃冉染瓤壤攘嚷让饶扰绕惹热壬仁人忍韧任认刃妊纫扔仍日戎茸蓉荣融熔溶容绒冗揉柔肉茹蠕儒孺如辱乳汝入褥软阮蕊瑞锐闰润若弱撒洒萨腮鳃塞赛三叁伞散桑嗓丧搔骚扫嫂瑟色涩森僧莎砂杀刹沙纱傻啥煞筛晒珊苫杉山删煽衫闪陕擅赡膳善汕扇缮墒伤商赏晌上尚裳梢捎稍烧芍勺韶少哨邵绍奢赊蛇舌舍赦摄射慑涉社设砷申呻伸身深娠绅神沈审婶甚肾慎渗声生甥牲升绳省盛剩胜圣师失狮施湿诗尸虱十石拾时什食蚀实识史矢使屎驶始式示士世柿事拭誓逝势是嗜噬适仕侍释饰氏市恃室视试收手首守寿授售受瘦兽蔬枢梳殊抒输叔舒淑疏书赎孰熟薯暑曙署蜀黍鼠属术述树束戍竖墅庶数漱恕刷耍摔衰甩帅栓拴霜双爽谁水睡税吮瞬顺舜说硕朔烁斯撕嘶思私司丝死肆寺嗣四伺似饲巳松耸怂颂送宋讼诵搜艘擞嗽苏酥俗素速粟僳塑溯宿诉肃酸蒜算虽隋随绥髓碎岁穗遂隧祟孙损笋蓑梭唆缩琐索锁所塌他它她塔獭挞蹋踏胎苔抬台泰酞太态汰坍摊贪瘫滩坛檀痰潭谭谈坦毯袒碳探叹炭汤塘搪堂棠膛唐糖倘躺淌趟烫掏涛滔绦萄桃逃淘陶讨套特藤腾疼誊梯剔踢锑提题蹄啼体替嚏惕涕剃屉天添填田甜恬舔腆挑条迢眺跳贴铁帖厅听烃汀廷停亭庭艇通桐酮瞳同铜彤童桶捅筒统痛偷投头透凸秃突图徒途涂屠土吐兔湍团推颓腿蜕褪退吞屯臀拖托脱鸵陀驮驼椭妥拓唾挖哇蛙洼娃瓦袜歪外豌弯湾玩顽丸烷完碗挽晚皖惋宛婉万腕汪王亡枉网往旺望忘妄威巍微危圩韦违桅围唯惟为潍维苇萎委伟伪尾纬未蔚味畏胃喂魏位渭谓尉慰卫瘟温蚊文闻纹吻稳紊问嗡翁瓮挝蜗涡窝我斡卧握沃巫呜钨乌污诬屋无芜梧吾吴毋武五捂午舞伍侮坞戊雾晤物勿务悟误昔熙析西硒矽晰嘻吸锡牺稀息希悉膝夕惜熄烯溪汐犀檄袭席习媳喜铣洗系隙戏细瞎虾匣霞辖暇峡侠狭下厦夏吓掀锨先仙鲜纤咸贤衔舷闲涎弦嫌显险现献县腺馅羡宪陷限线相厢镶香箱襄湘乡翔祥详想响享项巷橡像向象萧硝霄削哮嚣销消宵淆晓小孝校肖啸笑效楔些歇蝎鞋协挟携邪斜胁谐写械卸蟹懈泄泻谢屑薪芯锌欣辛新忻心信衅星腥猩惺兴刑型形邢行醒幸杏性姓兄凶胸匈汹雄熊休修羞朽嗅锈秀袖绣墟戌需虚嘘须徐许蓄酗叙旭序畜恤絮婿绪续轩喧宣悬旋玄选癣眩绚靴薛学穴雪血勋熏循旬询寻驯巡殉汛训讯逊迅压押鸦鸭呀丫芽牙蚜崖衙涯雅哑亚讶焉咽阉烟淹盐严研蜒岩延言颜阎炎沿奄掩眼衍演艳堰燕厌砚雁唁彦焰宴谚验殃央鸯秧杨扬佯疡羊洋阳氧仰痒养样漾邀腰妖瑶摇尧遥窑谣姚咬舀药要耀椰噎耶爷野冶也页掖业叶曳腋夜液一壹医揖铱依伊衣颐夷遗移仪胰疑沂宜姨彝椅蚁倚已乙矣以艺抑易邑屹亿役臆逸肄疫亦裔意毅忆义益溢诣议谊译异翼翌绎茵荫因殷音阴姻吟银淫寅饮尹引隐印英樱婴鹰应缨莹萤营荧蝇迎赢盈影颖硬映哟拥佣臃痈庸雍踊蛹咏泳涌永恿勇用幽优悠忧尤由邮铀犹油游酉有友右佑釉诱又幼迂淤于盂榆虞愚舆余俞逾鱼愉渝渔隅予娱雨与屿禹宇语羽玉域芋郁吁遇喻峪御愈欲狱育誉浴寓裕预豫驭鸳渊冤元垣袁原援辕园员圆猿源缘远苑愿怨院曰约越跃钥岳粤月悦阅耘云郧匀陨允运蕴酝晕韵孕匝砸杂栽哉灾宰载再在咱攒暂赞赃脏葬遭糟凿藻枣早澡蚤躁噪造皂灶燥责择则泽贼怎增憎曾赠扎喳渣札轧铡闸眨栅榨咋乍炸诈摘斋宅窄债寨瞻毡詹粘沾盏斩辗崭展蘸栈占战站湛绽樟章彰漳张掌涨杖丈帐账仗胀瘴障招昭找沼赵照罩兆肇召遮折哲蛰辙者锗蔗这浙珍斟真甄砧臻贞针侦
枕疹诊震振镇阵蒸挣睁征狰争怔整拯正政帧症郑证芝枝支吱蜘知肢脂汁之织职直植殖执值侄址指止趾只旨纸志挚掷至致置帜峙制智秩稚质炙痔滞治窒中盅忠钟衷终种肿重仲众舟周州洲诌粥轴肘帚咒皱宙昼骤珠株蛛朱猪诸诛逐竹烛煮拄瞩嘱主著柱助蛀贮铸筑住注祝驻抓爪拽专砖转撰赚篆桩庄装妆撞壮状椎锥追赘坠缀谆准捉拙卓桌琢茁酌啄着灼浊兹咨资姿滋淄孜紫仔籽滓子自渍字鬃棕踪宗综总纵邹走奏揍租足卒族祖诅阻组钻纂嘴醉最罪尊遵昨左佐柞做作坐座1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 369 | # opt.character = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' 370 | """ Seed and GPU setting """ 371 | # print("Random Seed: ", opt.manualSeed) 372 | random.seed(opt.manualSeed) 373 | np.random.seed(opt.manualSeed) 374 | torch.manual_seed(opt.manualSeed) 375 | torch.cuda.manual_seed(opt.manualSeed) 376 | 377 | cudnn.benchmark = True 378 | cudnn.deterministic = True 379 | opt.num_gpu = torch.cuda.device_count() 380 | print('device count', opt.num_gpu) 381 | opt.num_gpu = 1 382 | 383 | if opt.num_gpu > 1: 384 | print('------ Use multi-GPU setting ------') 385 | print('if you stuck too long time with multi-GPU setting, try to set --workers 0') 386 | # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1 387 | opt.workers = opt.workers * opt.num_gpu 388 | opt.batch_size = opt.batch_size * opt.num_gpu 389 | 390 | """ previous version 391 | print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size) 392 | opt.batch_size = opt.batch_size * opt.num_gpu 393 | print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.') 394 | If you dont care about it, just commnet out these line.) 395 | opt.num_iter = int(opt.num_iter / opt.num_gpu) 396 | """ 397 | 398 | train(opt) 399 | -------------------------------------------------------------------------------- /meta_self_learning.py: -------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import os 3 | import sys 4 | import time 5 | import random 6 | import string 7 | import argparse 8 | import visdom 9 | from copy import deepcopy 10 | from collections import OrderedDict 11 | 12 | import torch 13 | import torch.backends.cudnn as cudnn 14 | import torch.nn.init as init 15 | import torch.optim as optim 16 | import torch.utils.data 17 | import numpy as np 18 | 19 | from utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager 20 | from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset, Batch_Balanced_Dataset0, SelfTrainingDataset, self_training_collate 21 | from model import Model 22 | from test import validation 23 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | 25 | 26 | def create_vis_plot(_xlabel, _ylabel, _title, _legend, viz): 27 | return viz.line( 28 | X=torch.zeros((1,)).cpu(), 29 | Y=torch.zeros((1, 1)).cpu(), 30 | opts=dict( 31 | xlabel=_xlabel, 32 | ylabel=_ylabel, 33 | title=_title, 34 | legend=_legend 35 | ) 36 | ) 37 | 38 | def update_vis_plot(iteration, loc, window1, update_type, viz): 39 | viz.line( 40 | X=torch.ones((1, 1)).cpu() * iteration, 41 | Y=torch.Tensor([loc]).unsqueeze(0).cpu(), 42 | win=window1, 43 | update=update_type 44 | ) 45 | 46 | def train(opt): 47 | 48 | """ dataset preparation """ 49 | if not opt.data_filtering_off: 50 | print('Filtering the images containing characters which are not in opt.character') 51 | print('Filtering the images whose label is longer than opt.batch_max_length') 52 | # see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L130 53 | 54 | opt.select_data = 
opt.select_data.split('-') 55 | opt.batch_ratio = opt.batch_ratio.split('-') 56 | train_dataset = Batch_Balanced_Dataset(opt) 57 | 58 | log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') 59 | AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 60 | AlignCollate_test = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 61 | test_dataset, test_dataset_log = hierarchical_dataset(root=opt.test_data, opt=opt) 62 | test_loader = torch.utils.data.DataLoader( 63 | test_dataset, batch_size=opt.batch_size, 64 | shuffle=True, # 'True' to check training progress with validation function. 65 | num_workers=int(opt.workers), 66 | collate_fn=AlignCollate_test, pin_memory=True) 67 | valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt, pseudo=True) 68 | valid_loader = torch.utils.data.DataLoader( 69 | valid_dataset, batch_size=opt.batch_size, 70 | shuffle=True, # 'True' to check training progress with validation function. 71 | num_workers=int(opt.workers), 72 | collate_fn=AlignCollate_valid, pin_memory=True) 73 | log.write(valid_dataset_log) 74 | print('-' * 80) 75 | log.write('-' * 80 + '\n') 76 | log.close() 77 | 78 | self_training_dataset = SelfTrainingDataset([], []) 79 | self_training_loader = None 80 | 81 | """ model configuration """ 82 | if 'CTC' in opt.Prediction: 83 | if opt.baiduCTC: 84 | converter = CTCLabelConverterForBaiduWarpctc(opt.character) 85 | else: 86 | converter = CTCLabelConverter(opt.character) 87 | else: 88 | converter = AttnLabelConverter(opt.character) 89 | opt.num_class = len(converter.character) 90 | 91 | if opt.rgb: 92 | opt.input_channel = 3 93 | model = Model(opt) 94 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, 95 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, 96 | opt.SequenceModeling, opt.Prediction) 97 | 98 | # weight initialization 99 | for name, param in model.named_parameters(): 100 | if 'localization_fc2' in name: 101 | print(f'Skip {name} as it is already initialized') 102 | continue 103 | try: 104 | if 'bias' in name: 105 | init.constant_(param, 0.0) 106 | elif 'weight' in name: 107 | init.kaiming_normal_(param) 108 | except Exception as e: # for batchnorm. 109 | if 'weight' in name: 110 | param.data.fill_(1) 111 | continue 112 | 113 | # data parallel for multi-GPU 114 | model = torch.nn.DataParallel(model).to(device) 115 | model.train() 116 | if opt.saved_model != '': 117 | print(f'loading pretrained model from {opt.saved_model}') 118 | if opt.FT: 119 | model.load_state_dict(torch.load(opt.saved_model), strict=False) 120 | else: 121 | model.load_state_dict(torch.load(opt.saved_model)) 122 | print("Model:") 123 | print(model) 124 | 125 | """ setup loss """ 126 | if 'CTC' in opt.Prediction: 127 | if opt.baiduCTC: 128 | # need to install warpctc. see our guideline. 
129 | from warpctc_pytorch import CTCLoss 130 | criterion = CTCLoss() 131 | else: 132 | criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 133 | else: 134 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 忽略0类别,所有长度不足25的标签均用0进行填充 135 | # loss averager 136 | loss_avg = Averager() 137 | 138 | # filter that only require gradient decent 139 | filtered_parameters = [] 140 | params_num = [] 141 | 142 | for p in filter(lambda p: p.requires_grad, model.parameters()): 143 | filtered_parameters.append(p) 144 | params_num.append(np.prod(p.size())) 145 | print('Trainable params num : ', sum(params_num)) 146 | 147 | # setup optimizer 148 | if opt.adam: 149 | optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999)) 150 | else: 151 | optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps) 152 | print("Optimizer:") 153 | print(optimizer) 154 | 155 | """ final options """ 156 | with open(f'./saved_models/{opt.exp_name}/opt.txt', 'a') as opt_file: 157 | opt_log = '------------ Options -------------\n' 158 | args = vars(opt) 159 | for k, v in args.items(): 160 | opt_log += f'{str(k)}: {str(v)}\n' 161 | opt_log += '---------------------------------------\n' 162 | print(opt_log) 163 | opt_file.write(opt_log) 164 | 165 | """ start training """ 166 | start_iter = 0 167 | if opt.saved_model != '': 168 | try: 169 | start_iter = int(opt.saved_model.split('_')[-1].split('.')[0]) 170 | print(f'continue to train, start_iter: {start_iter}') 171 | except: 172 | if opt.continue_epoch != -1: start_iter=opt.continue_epoch 173 | pass 174 | 175 | start_time = time.time() 176 | best_accuracy = -1 177 | best_norm_ED = -1 178 | test_best_accuracy = -1 179 | test_best_norm_ED = -1 180 | iteration = start_iter 181 | beta = opt.beta 182 | 183 | while(True): 184 | if train_dataset.has_pseudo_label_dataset(): meta_target_index = random.randint(0, opt.source_num) 185 | else: meta_target_index = random.randint(0, opt.source_num - 1) 186 | 187 | meta_target_index = random.randint(0, opt.source_num) 188 | if train_dataset.has_pseudo_label_dataset(): meta_target_index = opt.source_num 189 | 190 | 191 | old_state_dict = deepcopy(model.state_dict()) 192 | weight_buffer, buffer_buffer = [], [] 193 | first_order_weight_buffer, first_order_buffer_buffer = [], [] 194 | inner_optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4) 195 | 196 | for mix_iter in range(opt.inner_loop_iter + 1): 197 | if mix_iter == opt.inner_loop_iter: 198 | image_tensors, labels = train_dataset.get_meta_test_batch(meta_target_index) 199 | else: image_tensors, labels = train_dataset.get_batch(meta_target_index) 200 | image = image_tensors.to(device) 201 | text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) 202 | batch_size = image.size(0) 203 | if 'CTC' in opt.Prediction: 204 | preds = model(image, text) 205 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 206 | if opt.baiduCTC: 207 | preds = preds.permute(1, 0, 2) # to use CTCLoss format 208 | cost = criterion(preds, text, preds_size, length) / batch_size 209 | else: 210 | preds = preds.log_softmax(2).permute(1, 0, 2) 211 | cost = criterion(preds, text, preds_size, length) 212 | 213 | else: 214 | preds = model(image, text[:, :-1]) # align with Attention.forward 215 | target = text[:, 1:] # without [GO] Symbol 216 | cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 217 | 218 | model.zero_grad() 219 | 
cost.backward() 220 | torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) 221 | inner_optimizer.step() 222 | loss_avg.add(cost) 223 | 224 | if mix_iter == opt.inner_loop_iter - 1: 225 | first_order_weight_buffer.extend([deepcopy(p.grad) for p in model.parameters()]) 226 | first_order_buffer_buffer.extend([deepcopy(p) for p in model.buffers()]) 227 | 228 | if mix_iter == opt.inner_loop_iter: 229 | weight_buffer.extend([deepcopy(p.grad) for p in model.parameters()]) 230 | buffer_buffer.extend([deepcopy(p) for p in model.buffers()]) 231 | break 232 | model.load_state_dict(old_state_dict) 233 | 234 | old_param = list(map(lambda p: p, model.parameters())) # the gradient recorded on the meta-test batch is copied onto these restored pre-inner-loop weights below and then applied by the outer optimizer (a first-order meta update) 235 | for old, new, first_order in zip(old_param, weight_buffer, first_order_weight_buffer): 236 | old.grad = new 237 | 238 | index = 0 239 | buffer_index = 0 240 | new_weights_dict = OrderedDict() 241 | 242 | for key, param in model.state_dict().items(): 243 | if key.find('running') != -1 or key.find('tracked') != -1 or key.find('delta') != -1 or key.find('hat') != -1: 244 | new_weights_dict[key] = buffer_buffer[buffer_index] # use the buffers saved in the inner loop: the outer model's BN layers were never run, so their tracked running mean/var would still be 0 245 | buffer_index += 1 246 | continue 247 | new_weights_dict[key] = old_param[index] 248 | index += 1 249 | 250 | assert index == len(old_param) 251 | assert buffer_index == len(buffer_buffer) 252 | 253 | model.load_state_dict(new_weights_dict) 254 | optimizer.step() 255 | optimizer.zero_grad() 256 | 257 | if train_dataset.has_pseudo_label_dataset(): 258 | image_tensors, labels = train_dataset.get_meta_test_batch(opt.source_num) 259 | else: 260 | image_tensors, labels = train_dataset.get_batch() 261 | image = image_tensors.to(device) 262 | text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) 263 | batch_size = image.size(0) 264 | 265 | if 'CTC' in opt.Prediction: 266 | preds = model(image, text) 267 | preds_size = torch.IntTensor([preds.size(1)] * batch_size) 268 | if opt.baiduCTC: 269 | preds = preds.permute(1, 0, 2) # to use CTCLoss format 270 | cost = criterion(preds, text, preds_size, length) / batch_size 271 | else: 272 | preds = preds.log_softmax(2).permute(1, 0, 2) 273 | cost = criterion(preds, text, preds_size, length) 274 | 275 | else: 276 | preds = model(image, text[:, :-1]) # align with Attention.forward 277 | target = text[:, 1:] # without [GO] Symbol 278 | cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) 279 | 280 | model.zero_grad() 281 | cost.backward() 282 | torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) # gradient clipping with 5 (Default) 283 | optimizer.step() 284 | 285 | if (iteration + 1) % opt.valInterval == 0 or iteration == 0: # To see training progress, we also conduct validation when 'iteration == 0' 286 | elapsed_time = time.time() - start_time 287 | # for log 288 | with open(f'./saved_models/{opt.exp_name}/log_train.txt', 'a') as log: 289 | model.eval() 290 | with torch.no_grad(): 291 | valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data, all_imgs, all_labels, all_confidences, all_gts = validation( 292 | model, criterion, valid_loader, converter, opt, self_training=True) 293 | test_loss, test_current_accuracy, test_current_norm_ED, test_preds, test_confidence_score, test_labels, test_infer_time, test_length_of_data = validation( 294 | model, criterion, test_loader, converter, opt) 295 | model.train() 296 | print(len(all_imgs), len(all_labels), 
len(all_confidences)) 297 | print(all_imgs[0].shape) 298 | # training loss and validation loss 299 | loss_log = f'[{iteration+1}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}' 300 | current_test_log = f'{"Test_Current_accuracy":17s}: {test_current_accuracy:0.3f}, {"Test_Current_norm_ED":17s}: {test_current_norm_ED:0.2f}' 301 | loss_avg.reset() 302 | 303 | current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}' 304 | 305 | if current_accuracy > best_accuracy: 306 | best_accuracy = current_accuracy 307 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_accuracy.pth') 308 | if current_norm_ED > best_norm_ED: 309 | best_norm_ED = current_norm_ED 310 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/best_norm_ED.pth') 311 | if test_current_accuracy > test_best_accuracy: 312 | test_best_accuracy = test_current_accuracy 313 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/test_best_accuracy.pth') 314 | if test_current_norm_ED > test_best_norm_ED: 315 | test_best_norm_ED = test_current_norm_ED 316 | torch.save(model.state_dict(), f'./saved_models/{opt.exp_name}/test_best_norm_ED.pth') 317 | best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}' 318 | test_best_model_log = f'{"Test_Best_accuracy":17s}: {test_best_accuracy:0.3f}, {"Test_Best_norm_ED":17s}: {test_best_norm_ED:0.2f}' 319 | loss_model_log = f'{loss_log}\n{current_model_log}\n{current_test_log}\n{best_model_log}\n{test_best_model_log}' 320 | print(loss_model_log) 321 | log.write(loss_model_log + '\n') 322 | 323 | 324 | dashed_line = '-' * 80 325 | head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' 326 | predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' 327 | 328 | self_training_imgs = [] 329 | self_training_labels = [] 330 | right_pseudo_label_counts = 0 331 | 332 | for conf, img, label, gt in zip(all_confidences, all_imgs, all_labels, all_gts): 333 | if conf < opt.pseudo_threshold: continue 334 | label = label[: label.find('[s]')] 335 | self_training_imgs.append(img) 336 | self_training_labels.append(label) 337 | if gt == label: right_pseudo_label_counts += 1 338 | 339 | self_training_dataset_length = len(self_training_imgs) 340 | if self_training_dataset_length > 0 and current_accuracy >= opt.warmup_threshold: 341 | pseudo_log = 'Generating Self-Training Dataset, contains {} imgs, {} of them have right labels'.format(self_training_dataset_length, right_pseudo_label_counts) 342 | print(pseudo_log) 343 | log.write(pseudo_log + '\n') 344 | self_training_dataset = SelfTrainingDataset(self_training_imgs, self_training_labels) 345 | train_dataset.add_pseudo_label_dataset(self_training_dataset, opt) 346 | 347 | 348 | for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]): 349 | if 'Attn' in opt.Prediction: 350 | gt = gt[:gt.find('[s]')] 351 | pred = pred[:pred.find('[s]')] 352 | 353 | predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n' 354 | predicted_result_log += f'{dashed_line}' 355 | print(predicted_result_log) 356 | log.write(predicted_result_log + '\n') 357 | 358 | # save model per 1e+5 iter. 
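        # (i.e. a checkpoint named iter_{iteration+1}.pth is written every 100,000 iterations)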
359 | if (iteration + 1) % 1e+5 == 0: 360 | torch.save( 361 | model.state_dict(), f'./saved_models/{opt.exp_name}/iter_{iteration+1}.pth') 362 | 363 | if (iteration + 1) == opt.num_iter: 364 | print('end the training') 365 | sys.exit() 366 | iteration += 1 367 | 368 | 369 | if __name__ == '__main__': 370 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 371 | parser = argparse.ArgumentParser() 372 | parser.add_argument('--exp_name', help='Where to store logs and models') 373 | parser.add_argument('--train_data', required=True, help='path to training dataset') 374 | parser.add_argument('--valid_data', required=True, help='path to validation dataset') 375 | parser.add_argument('--test_data', required=True, help='path to test dataset') 376 | parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting') 377 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=2) 378 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 379 | parser.add_argument('--num_iter', type=int, default=300000, help='number of iterations to train for') 380 | parser.add_argument('--valInterval', type=int, default=2000, help='Interval between each validation') 381 | parser.add_argument('--saved_model', default='', help="path to model to continue training") 382 | parser.add_argument('--FT', action='store_true', help='whether to do fine-tuning') 383 | parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is Adadelta)') 384 | parser.add_argument('--lr', type=float, default=1, help='learning rate, default=1.0 for Adadelta') 385 | parser.add_argument('--beta', type=float, default=1, help='beta to weigh the second order derivative loss') 386 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9') 387 | parser.add_argument('--rho', type=float, default=0.95, help='decay rate rho for Adadelta. default=0.95') 388 | parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8') 389 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping value. 
default=5') 390 | parser.add_argument('--baiduCTC', action='store_true', help='use Baidu warpctc instead of torch.nn.CTCLoss for the CTC criterion') 391 | """ Data processing """ 392 | parser.add_argument('--select_data', type=str, default='MJ-ST', 393 | help='select training data (default is MJ-ST, which means MJ and ST used as training data)') 394 | parser.add_argument('--batch_ratio', type=str, default='1', 395 | help='assign ratio for each selected data in the batch, e.g. 0.5-0.5') 396 | parser.add_argument('--total_data_usage_ratio', type=str, default='1.0', 397 | help='total data usage ratio; this ratio is multiplied by the total number of data.') 398 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 399 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 400 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 401 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 402 | parser.add_argument('--character', type=str, 403 | default='0123456789一下东丰云亚亨亮会佑佛佳俊信八关兴凯利力勤华南县发号君和圩城壹大天宁宇安宏宾富封州工帆平年广庆建开德恒惠成新昌明机权来柳桂梧森横永江沙河油泰泽洋浮海润清港湖源滨珠田盈益盛石祥福程穗粤翔肇航良英藤行衡西诚诺谢谷货贵辉达运远途通都金长阳雄韶顺颜风飞香鸿鼎龙', help='character label 0123456789abcdefghijklmnopqrstuvwxyz') 404 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 405 | parser.add_argument('--PAD', action='store_true', help='whether to keep ratio then pad for image resize') 406 | parser.add_argument('--data_filtering_off', action='store_true', help='for data_filtering_off mode') 407 | """ Model Architecture """ 408 | parser.add_argument('--Transformation', type=str, required=True, help='Transformation stage. None|TPS') 409 | parser.add_argument('--FeatureExtraction', type=str, required=True, 410 | help='FeatureExtraction stage. VGG|RCNN|ResNet') 411 | parser.add_argument('--SequenceModeling', type=str, required=True, help='SequenceModeling stage. None|BiLSTM') 412 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. 
CTC|Attn') 413 | parser.add_argument('--num_fiducial', type=int, default=20, help='number of fiducial points of TPS-STN') 414 | parser.add_argument('--input_channel', type=int, default=1, 415 | help='the number of input channels of the feature extractor') 416 | parser.add_argument('--output_channel', type=int, default=512, 417 | help='the number of output channels of the feature extractor') 418 | parser.add_argument('--hidden_size', type=int, default=256, help='the size of the LSTM hidden state') 419 | parser.add_argument('--fix_dataset_num', type=int, default=-1, help='fix the number of imgs per dataset') 420 | parser.add_argument('--pseudo_dataset_num', type=int, default=-1, help='fix the number of pseudo-labeled imgs per dataset') 421 | parser.add_argument('--pseudo_threshold', type=float, default=0.8, help='confidence threshold for accepting pseudo labels') 422 | parser.add_argument('--warmup_threshold', type=float, default=27.5, help='validation accuracy (%) that must be reached before pseudo labels are generated') 423 | parser.add_argument('--source_num', type=int, default=3, help='the number of source domain datasets') 424 | parser.add_argument('--inner_loop_iter', type=int, default=2, help='the number of inner loop iterations') 425 | parser.add_argument('--expsuffix', type=str, default='') 426 | parser.add_argument('--continue_epoch', type=int, default=-1) 427 | opt = parser.parse_args() 428 | 429 | if not opt.exp_name: 430 | opt.exp_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 431 | opt.exp_name += f'-Seed{opt.manualSeed}' 432 | assert opt.expsuffix, 'You should specify the exp suffix' 433 | opt.exp_name += f'-{opt.expsuffix}' 434 | 435 | os.makedirs(f'./saved_models/{opt.exp_name}', exist_ok=True) 436 | 437 | """ vocab / character number configuration """ 438 | if opt.sensitive: 439 | opt.character = string.printable[:-6] # same as the ASTER setting (94 characters). 
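    # NOTE: the assignment below overrides any --character / --sensitive setting with the fixed Chinese + alphanumeric vocabulary used by this project.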
440 | 441 | opt.character = u"啊阿埃挨哎唉哀皑癌蔼矮艾碍爱隘鞍氨安俺按暗岸胺案肮昂盎凹敖熬翱袄傲奥懊澳芭捌扒叭吧笆八疤巴拔跋靶把耙坝霸罢爸白柏百摆佰败拜稗斑班搬扳般颁板版扮拌伴瓣半办绊邦帮梆榜膀绑棒磅蚌镑傍谤苞胞包褒剥薄雹保堡饱宝抱报暴豹鲍爆杯碑悲卑北辈背贝钡倍狈备惫焙被奔苯本笨崩绷甭泵蹦迸逼鼻比鄙笔彼碧蓖蔽毕毙毖币庇痹闭敝弊必辟壁臂避陛鞭边编贬扁便变卞辨辩辫遍标彪膘表鳖憋别瘪彬斌濒滨宾摈兵冰柄丙秉饼炳病并玻菠播拨钵波博勃搏铂箔伯帛舶脖膊渤泊驳捕卜哺补埠不布步簿部怖擦猜裁材才财睬踩采彩菜蔡餐参蚕残惭惨灿苍舱仓沧藏操糙槽曹草厕策侧册测层蹭插叉茬茶查碴搽察岔差诧拆柴豺搀掺蝉馋谗缠铲产阐颤昌猖场尝常长偿肠厂敞畅唱倡超抄钞朝嘲潮巢吵炒车扯撤掣彻澈郴臣辰尘晨忱沉陈趁衬撑称城橙成呈乘程惩澄诚承逞骋秤吃痴持匙池迟弛驰耻齿侈尺赤翅斥炽充冲虫崇宠抽酬畴踌稠愁筹仇绸瞅丑臭初出橱厨躇锄雏滁除楚础储矗搐触处揣川穿椽传船喘串疮窗幢床闯创吹炊捶锤垂春椿醇唇淳纯蠢戳绰疵茨磁雌辞慈瓷词此刺赐次聪葱囱匆从丛凑粗醋簇促蹿篡窜摧崔催脆瘁粹淬翠村存寸磋撮搓措挫错搭达答瘩打大呆歹傣戴带殆代贷袋待逮怠耽担丹单郸掸胆旦氮但惮淡诞弹蛋当挡党荡档刀捣蹈倒岛祷导到稻悼道盗德得的蹬灯登等瞪凳邓堤低滴迪敌笛狄涤翟嫡抵底地蒂第帝弟递缔颠掂滇碘点典靛垫电佃甸店惦奠淀殿碉叼雕凋刁掉吊钓调跌爹碟蝶迭谍叠丁盯叮钉顶鼎锭定订丢东冬董懂动栋侗恫冻洞兜抖斗陡豆逗痘都督毒犊独读堵睹赌杜镀肚度渡妒端短锻段断缎堆兑队对墩吨蹲敦顿囤钝盾遁掇哆多夺垛躲朵跺舵剁惰堕蛾峨鹅俄额讹娥恶厄扼遏鄂饿恩而儿耳尔饵洱二贰发罚筏伐乏阀法珐藩帆番翻樊矾钒繁凡烦反返范贩犯饭泛坊芳方肪房防妨仿访纺放菲非啡飞肥匪诽吠肺废沸费芬酚吩氛分纷坟焚汾粉奋份忿愤粪丰封枫蜂峰锋风疯烽逢冯缝讽奉凤佛否夫敷肤孵扶拂辐幅氟符伏俘服浮涪福袱弗甫抚辅俯釜斧脯腑府腐赴副覆赋复傅付阜父腹负富讣附妇缚咐噶嘎该改概钙盖溉干甘杆柑竿肝赶感秆敢赣冈刚钢缸肛纲岗港杠篙皋高膏羔糕搞镐稿告哥歌搁戈鸽胳疙割革葛格蛤阁隔铬个各给根跟耕更庚羹埂耿梗工攻功恭龚供躬公宫弓巩汞拱贡共钩勾沟苟狗垢构购够辜菇咕箍估沽孤姑鼓古蛊骨谷股故顾固雇刮瓜剐寡挂褂乖拐怪棺关官冠观管馆罐惯灌贯光广逛瑰规圭硅归龟闺轨鬼诡癸桂柜跪贵刽辊滚棍锅郭国果裹过哈骸孩海氦亥害骇酣憨邯韩含涵寒函喊罕翰撼捍旱憾悍焊汗汉夯杭航壕嚎豪毫郝好耗号浩呵喝荷菏核禾和何合盒貉阂河涸赫褐鹤贺嘿黑痕很狠恨哼亨横衡恒轰哄烘虹鸿洪宏弘红喉侯猴吼厚候后呼乎忽瑚壶葫胡蝴狐糊湖弧虎唬护互沪户花哗华猾滑画划化话槐徊怀淮坏欢环桓还缓换患唤痪豢焕涣宦幻荒慌黄磺蝗簧皇凰惶煌晃幌恍谎灰挥辉徽恢蛔回毁悔慧卉惠晦贿秽会烩汇讳诲绘荤昏婚魂浑混豁活伙火获或惑霍货祸击圾基机畸稽积箕肌饥迹激讥鸡姬绩缉吉极棘辑籍集及急疾汲即嫉级挤几脊己蓟技冀季伎祭剂悸济寄寂计记既忌际妓继纪嘉枷夹佳家加荚颊贾甲钾假稼价架驾嫁歼监坚尖笺间煎兼肩艰奸缄茧检柬碱碱拣捡简俭剪减荐槛鉴践贱见键箭件健舰剑饯渐溅涧建僵姜将浆江疆蒋桨奖讲匠酱降蕉椒礁焦胶交郊浇骄娇嚼搅铰矫侥脚狡角饺缴绞剿教酵轿较叫窖揭接皆秸街阶截劫节桔杰捷睫竭洁结解姐戒藉芥界借介疥诫届巾筋斤金今津襟紧锦仅谨进靳晋禁近烬浸尽劲荆兢茎睛晶鲸京惊精粳经井警景颈静境敬镜径痉靖竟竞净炯窘揪究纠玖韭久灸九酒厩救旧臼舅咎就疚鞠拘狙疽居驹菊局咀矩举沮聚拒据巨具距踞锯俱句惧炬剧捐鹃娟倦眷卷绢撅攫抉掘倔爵觉决诀绝均菌钧军君峻俊竣浚郡骏喀咖卡咯开揩楷凯慨刊堪勘坎砍看康慷糠扛抗亢炕考拷烤靠坷苛柯棵磕颗科壳咳可渴克刻客课肯啃垦恳坑吭空恐孔控抠口扣寇枯哭窟苦酷库裤夸垮挎跨胯块筷侩快宽款匡筐狂框矿眶旷况亏盔岿窥葵奎魁傀馈愧溃坤昆捆困括扩廓阔垃拉喇蜡腊辣啦莱来赖蓝婪栏拦篮阑兰澜谰揽览懒缆烂滥琅榔狼廊郎朗浪捞劳牢老佬姥酪烙涝勒乐雷镭蕾磊累儡垒擂肋类泪棱楞冷厘梨犁黎篱狸离漓理李里鲤礼莉荔吏栗丽厉励砾历利僳例俐痢立粒沥隶力璃哩俩联莲连镰廉怜涟帘敛脸链恋炼练粮凉梁粱良两辆量晾亮谅撩聊僚疗燎寥辽潦了撂镣廖料列裂烈劣猎琳林磷霖临邻鳞淋凛赁吝拎玲菱零龄铃伶羚凌灵陵岭领另令溜琉榴硫馏留刘瘤流柳六龙聋咙笼窿隆垄拢陇楼娄搂篓漏陋芦卢颅庐炉掳卤虏鲁麓碌露路赂鹿潞禄录陆戮驴吕铝侣旅履屡缕虑氯律率滤绿峦挛孪滦卵乱掠略抡轮伦仑沦纶论萝螺罗逻锣箩骡裸落洛骆络妈麻玛码蚂马骂嘛吗埋买麦卖迈脉瞒馒蛮满蔓曼慢漫谩芒茫盲氓忙莽猫茅锚毛矛铆卯茂冒帽貌贸么玫枚梅酶霉煤没眉媒镁每美昧寐妹媚门闷们萌蒙檬盟锰猛梦孟眯醚靡糜迷谜弥米秘觅泌蜜密幂棉眠绵冕免勉娩缅面苗描瞄藐秒渺庙妙蔑灭民抿皿敏悯闽明螟鸣铭名命谬摸摹蘑模膜磨摩魔抹末莫墨默沫漠寞陌谋牟某拇牡亩姆母墓暮幕募慕木目睦牧穆拿哪呐钠那娜纳氖乃奶耐奈南男难囊挠脑恼闹淖呢馁内嫩能妮霓倪泥尼拟你匿腻逆溺蔫拈年碾撵捻念娘酿鸟尿捏聂孽啮镊镍涅您柠狞凝宁拧泞牛扭钮纽脓浓农弄奴努怒女暖虐疟挪懦糯诺哦欧鸥殴藕呕偶沤啪趴爬帕怕琶拍排牌徘湃派攀潘盘磐盼畔判叛乓庞旁耪胖抛咆刨炮袍跑泡呸胚培裴赔陪配佩沛喷盆砰抨烹澎彭蓬棚硼篷膨朋鹏捧碰坯砒霹批披劈琵毗啤脾疲皮匹痞僻屁譬篇偏片骗飘漂瓢票撇瞥拼频贫品聘乒坪苹萍平凭瓶评屏坡泼颇婆破魄迫粕剖扑铺仆莆葡菩蒲埔朴圃普浦谱曝瀑期欺栖戚妻七凄漆柒沏其棋奇歧畦崎脐齐旗祈祁骑起岂乞企启契砌器气迄弃汽泣讫掐洽牵扦钎铅千迁签仟谦乾黔钱钳前潜遣浅谴堑嵌欠歉枪呛腔羌墙蔷强抢橇锹敲悄桥瞧乔侨巧鞘撬翘峭俏窍切茄且怯窃钦侵亲秦琴勤芹擒禽寝沁青轻氢倾卿清擎晴氰情顷请庆琼穷秋丘邱球求囚酋泅趋区蛆曲躯屈驱渠取娶龋趣去圈颧权醛泉全痊拳犬券劝缺炔瘸却鹊榷确雀裙群然燃冉染瓤壤攘嚷让饶扰绕惹热壬仁人忍韧任认刃妊纫扔仍日戎茸蓉荣融熔溶容绒冗揉柔肉茹蠕儒孺如辱乳汝入褥软阮蕊瑞锐闰润若弱撒洒萨腮鳃塞赛三叁伞散桑嗓丧搔骚扫嫂瑟色涩森僧莎砂杀刹沙纱傻啥煞筛晒珊苫杉山删煽衫闪陕擅赡膳善汕扇缮墒伤商赏晌上尚裳梢捎稍烧芍勺韶少哨邵绍奢赊蛇舌舍赦摄射慑涉社设砷申呻伸身深娠绅神沈审婶甚肾慎渗声生甥牲升绳省盛剩胜圣师失狮施湿诗尸虱十石拾时什食蚀实识史矢使屎驶始式示士世柿事拭誓逝势是嗜噬适仕侍释饰氏市恃室视试收手首守寿授售受瘦兽蔬枢梳殊抒输叔舒淑疏书赎孰熟薯暑曙署蜀黍鼠属术述树束戍竖墅庶数漱恕刷耍摔衰甩帅栓拴霜双爽谁水睡税吮瞬顺舜说硕朔烁斯撕嘶思私司丝死肆寺嗣四伺似饲巳松耸怂颂送宋讼诵搜艘擞嗽苏酥俗素速粟僳塑溯宿诉肃酸蒜算虽隋随绥髓碎岁穗遂隧祟孙损笋蓑梭唆缩琐索锁所塌他它她塔獭挞蹋踏胎苔抬台泰酞太态汰坍摊贪瘫滩坛檀痰潭谭谈坦毯袒碳探叹炭汤塘搪堂棠膛唐糖倘躺淌趟烫掏涛滔绦萄桃逃淘陶讨套特藤腾疼誊梯剔踢锑提题蹄啼体替嚏惕涕剃屉天添填田甜恬舔腆挑条迢眺跳贴铁帖厅听烃汀廷停亭庭艇通桐酮瞳同铜彤童桶捅筒统痛偷投头透凸秃突图徒途涂屠土吐兔湍团推颓腿蜕褪退吞屯臀拖托脱鸵陀驮驼椭妥拓唾挖哇蛙洼娃瓦袜歪外豌弯湾玩顽丸烷完碗挽晚皖惋宛婉万腕汪王亡枉网往旺望忘妄威巍微危圩韦违桅围唯惟为潍维苇萎委伟伪尾纬未蔚味畏胃喂魏位渭谓尉慰卫瘟温蚊文闻纹吻稳紊问嗡翁瓮挝蜗涡窝我斡卧握沃巫呜钨乌污诬屋无芜梧吾吴毋武五捂午舞伍侮坞戊雾晤物勿务悟误昔熙析西硒矽晰嘻吸锡牺稀息希悉膝夕惜熄烯溪汐犀檄袭席习媳喜铣洗系隙戏细瞎虾匣霞辖暇峡侠狭下厦夏吓掀锨先仙鲜纤咸贤衔舷闲涎弦嫌显险现献县腺馅羡宪陷限线相厢镶香箱襄湘乡翔祥详想响享项巷橡像向象萧硝霄削哮嚣销消宵淆晓小孝校肖啸笑效楔些歇蝎鞋协挟携邪斜胁谐写械卸蟹懈泄泻谢屑薪芯锌欣辛新忻心信衅星腥猩惺兴刑型形邢行醒幸杏性姓兄凶胸匈汹雄熊休修羞朽嗅锈秀袖绣墟戌需虚嘘须徐许蓄酗叙旭序畜恤絮婿绪续轩喧宣悬旋玄选癣眩绚靴薛学穴雪血勋熏循旬询寻驯巡殉汛训讯逊迅压押鸦鸭呀丫芽牙蚜崖衙涯雅哑亚讶焉咽阉烟淹盐严研蜒岩延言颜阎炎沿奄掩眼衍演艳堰燕厌砚雁唁彦焰宴谚验殃央鸯秧杨扬佯疡羊洋阳氧仰痒养样漾邀腰妖瑶摇尧遥窑谣姚咬舀药要耀椰噎耶爷野冶也页掖业叶曳腋夜液一壹医揖铱依伊衣颐夷遗移仪胰疑沂宜姨彝椅蚁倚已乙矣以艺抑易邑屹亿役臆逸肄疫亦裔意毅忆义益溢诣议谊译异翼翌绎茵荫因殷音阴姻吟银淫寅饮尹引隐印英樱婴鹰应缨莹萤营荧蝇迎赢盈影颖硬映哟拥佣臃痈庸雍踊蛹咏泳涌永恿勇用幽优悠忧尤由邮铀犹油游酉有友右佑釉诱又幼迂淤于盂榆虞愚舆余俞逾鱼愉渝渔隅予娱雨与屿禹宇语羽玉域芋郁吁遇喻峪御愈欲狱育誉浴寓裕预豫驭鸳渊冤元垣袁原援辕园员圆猿源缘远苑愿怨院曰约越跃钥岳粤月悦阅耘云郧匀陨允运蕴酝晕韵孕匝砸杂栽哉灾宰载再在咱攒暂赞赃脏葬遭糟凿藻枣早澡蚤躁噪造皂灶燥责择则泽贼怎增憎曾赠扎喳渣札轧铡闸眨栅榨咋乍炸诈摘斋宅窄债寨瞻毡詹粘沾盏斩辗崭展蘸栈占战站湛绽樟章彰漳张掌涨杖丈帐账仗胀瘴障招
昭找沼赵照罩兆肇召遮折哲蛰辙者锗蔗这浙珍斟真甄砧臻贞针侦枕疹诊震振镇阵蒸挣睁征狰争怔整拯正政帧症郑证芝枝支吱蜘知肢脂汁之织职直植殖执值侄址指止趾只旨纸志挚掷至致置帜峙制智秩稚质炙痔滞治窒中盅忠钟衷终种肿重仲众舟周州洲诌粥轴肘帚咒皱宙昼骤珠株蛛朱猪诸诛逐竹烛煮拄瞩嘱主著柱助蛀贮铸筑住注祝驻抓爪拽专砖转撰赚篆桩庄装妆撞壮状椎锥追赘坠缀谆准捉拙卓桌琢茁酌啄着灼浊兹咨资姿滋淄孜紫仔籽滓子自渍字鬃棕踪宗综总纵邹走奏揍租足卒族祖诅阻组钻纂嘴醉最罪尊遵昨左佐柞做作坐座1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 442 | """ Seed and GPU setting """ 443 | random.seed(opt.manualSeed) 444 | np.random.seed(opt.manualSeed) 445 | torch.manual_seed(opt.manualSeed) 446 | torch.cuda.manual_seed(opt.manualSeed) 447 | 448 | cudnn.benchmark = True 449 | cudnn.deterministic = True 450 | opt.num_gpu = torch.cuda.device_count() 451 | print('device count', opt.num_gpu) 452 | opt.num_gpu = 1 453 | 454 | if opt.num_gpu > 1: 455 | print('------ Use multi-GPU setting ------') 456 | print('if you stuck too long time with multi-GPU setting, try to set --workers 0') 457 | # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1 458 | opt.workers = opt.workers * opt.num_gpu 459 | opt.batch_size = opt.batch_size * opt.num_gpu 460 | 461 | """ previous version 462 | print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size) 463 | opt.batch_size = opt.batch_size * opt.num_gpu 464 | print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.') 465 | If you dont care about it, just commnet out these line.) 466 | opt.num_iter = int(opt.num_iter / opt.num_gpu) 467 | """ 468 | 469 | train(opt) 470 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import six 5 | import math 6 | import lmdb 7 | import torch 8 | 9 | from natsort import natsorted 10 | import itertools 11 | from PIL import Image 12 | from copy import deepcopy 13 | import numpy as np 14 | from torch.utils.data import Dataset, ConcatDataset, Subset 15 | from torch._utils import _accumulate 16 | import torchvision.transforms as transforms 17 | 18 | 19 | class Batch_Balanced_Dataset(object): 20 | 21 | def __init__(self, opt): 22 | """ 23 | Modulate the data ratio in the batch. 24 | For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5", 25 | the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST. 
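        The per-dataset batch size below is max(round(opt.batch_size * batch_ratio), 1); for example, --batch_size 192 with --batch_ratio 0.5-0.5 draws 96 images from each of the two datasets for every combined batch.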
26 | 27 | opt.batch_ratio: 一个batch中包含的不同数据集的比例 28 | opt.total_data_usage_ratio: 对于每一个数据及,使用这个数据集的百分之多少,默认是1(100%) 29 | """ 30 | self.opt = opt 31 | log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a') 32 | dashed_line = '-' * 80 33 | print(dashed_line) 34 | log.write(dashed_line + '\n') 35 | print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}') 36 | log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n') 37 | assert len(opt.select_data) == len(opt.batch_ratio) 38 | 39 | # 为每个dataloader应用collate函数,直接输出一整个batch, 40 | _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 41 | self.data_loader_list = [] 42 | self.dataloader_iter_list = [] 43 | batch_size_list = [] 44 | Total_batch_size = 0 45 | for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio): 46 | _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1) 47 | print(dashed_line) 48 | log.write(dashed_line + '\n') 49 | _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d]) 50 | total_number_dataset = len(_dataset) # 当前数据集包含的图片数量 51 | log.write(_dataset_log) 52 | 53 | """ 54 | The total number of data can be modified with opt.total_data_usage_ratio. 55 | ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage. 56 | See 4.2 section in our paper. 57 | """ 58 | number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio)) # 使用的比例 59 | if opt.fix_dataset_num != -1: number_dataset = opt.fix_dataset_num 60 | dataset_split = [number_dataset, total_number_dataset - number_dataset] # List[int] e.g. [50, 50] 61 | indices = range(total_number_dataset) 62 | 63 | # accumulate函数: _accumulate([1,2,3,4,5]) --> 1 3 6 10 15 64 | # Subset就是根据indices取一个数据集的子集,indice根据opt.total_data_usage_ratio来取值 65 | _dataset, _ = [Subset(_dataset, indices[offset - length:offset]) 66 | for offset, length in zip(_accumulate(dataset_split), dataset_split)] 67 | selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n' 68 | selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}' 69 | print(selected_d_log) 70 | log.write(selected_d_log + '\n') 71 | batch_size_list.append(str(_batch_size)) 72 | Total_batch_size += _batch_size 73 | 74 | _data_loader = torch.utils.data.DataLoader( 75 | _dataset, batch_size=_batch_size, 76 | shuffle=True, 77 | num_workers=int(opt.workers), 78 | collate_fn=_AlignCollate, pin_memory=False, drop_last=True) 79 | self.data_loader_list.append(_data_loader) 80 | self.dataloader_iter_list.append(iter(_data_loader)) 81 | 82 | Total_batch_size_log = f'{dashed_line}\n' 83 | batch_size_sum = '+'.join(batch_size_list) 84 | Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n' 85 | Total_batch_size_log += f'{dashed_line}' 86 | opt.batch_size = Total_batch_size 87 | 88 | print(Total_batch_size_log) 89 | log.write(Total_batch_size_log + '\n') 90 | log.close() 91 | 92 | def get_batch(self, meta_target_index=-1, no_pseudo=False): # 如果指定了meta_target_index,则忽略第meta_target_index个数据集 93 | balanced_batch_images = [] 94 | balanced_batch_texts = [] 95 | 96 | for i, data_loader_iter in enumerate(self.dataloader_iter_list): 97 | if i == meta_target_index: continue 98 | # 
92 |     def get_batch(self, meta_target_index=-1, no_pseudo=False):  # if meta_target_index is given, skip the meta_target_index-th dataset
93 |         balanced_batch_images = []
94 |         balanced_batch_texts = []
95 | 
96 |         for i, data_loader_iter in enumerate(self.dataloader_iter_list):
97 |             if i == meta_target_index: continue
98 |             # if pseudo-label data should not be sampled and a pseudo-label dataset is currently present, skip it
99 |             if i == len(self.dataloader_iter_list) - 1 and no_pseudo and self.has_pseudo_label_dataset(): continue
100 |             try:
101 |                 image, text = next(data_loader_iter)
102 |                 balanced_batch_images.append(image)
103 |                 balanced_batch_texts += text
104 |             except StopIteration:  # if a dataset has run out of images, rebuild its iterator and keep training
105 |                 self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
106 |                 image, text = next(self.dataloader_iter_list[i])
107 |                 balanced_batch_images.append(image)
108 |                 balanced_batch_texts += text
109 |             except ValueError:
110 |                 pass
111 | 
112 |         balanced_batch_images = torch.cat(balanced_batch_images, 0)
113 | 
114 |         return balanced_batch_images, balanced_batch_texts
115 | 
116 |     def get_meta_test_batch(self, meta_target_index=-1):  # only the meta_target_index-th dataset is sampled
117 | 
118 |         if meta_target_index == self.opt.source_num:
119 |             assert len(self.data_loader_list) == self.opt.source_num + 1, 'There is no target dataset'
120 |         balanced_batch_images = []
121 |         balanced_batch_texts = []
122 | 
123 |         for i, data_loader_iter in enumerate(self.dataloader_iter_list):
124 |             if i == meta_target_index:
125 |                 try:
126 |                     image, text = next(data_loader_iter)
127 |                     balanced_batch_images.append(image)
128 |                     balanced_batch_texts += text
129 |                 except StopIteration:  # if a dataset has run out of images, rebuild its iterator and keep training
130 |                     self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
131 |                     image, text = next(self.dataloader_iter_list[i])
132 |                     balanced_batch_images.append(image)
133 |                     balanced_batch_texts += text
134 |                 except ValueError:
135 |                     pass
136 |         # print(balanced_batch_images[0].shape)
137 |         balanced_batch_images = torch.cat(balanced_batch_images, 0)
138 | 
139 |         return balanced_batch_images, balanced_batch_texts
140 | 
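    # Sketch of the meta-learning split these two methods provide (the index value is an
    # assumption for illustration):
    #   meta_train_imgs, meta_train_texts = train_dataset.get_batch(meta_target_index=2)            # every source except source 2
    #   meta_test_imgs, meta_test_texts = train_dataset.get_meta_test_batch(meta_target_index=2)    # only source 2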
141 |     def add_target_domain_dataset(self, dataset, opt):
142 |         _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
143 |         avg_batch_size = opt.batch_size // opt.source_num
144 |         batch_size = len(dataset) if len(dataset) <= avg_batch_size else avg_batch_size
145 |         self_training_loader = torch.utils.data.DataLoader(
146 |             dataset, batch_size=batch_size,
147 |             shuffle=True,  # 'True' to check training progress with validation function.
148 |             num_workers=int(opt.workers), pin_memory=False, collate_fn=_AlignCollate, drop_last=True)
149 |         if self.has_pseudo_label_dataset():
150 |             self.data_loader_list[opt.source_num] = self_training_loader
151 |             self.dataloader_iter_list[opt.source_num] = (iter(self_training_loader))
152 |         else:
153 |             self.data_loader_list.append(self_training_loader)
154 |             self.dataloader_iter_list.append(iter(self_training_loader))
155 | 
156 |     def add_pseudo_label_dataset(self, dataset, opt):
157 |         avg_batch_size = opt.batch_size // opt.source_num
158 |         batch_size = len(dataset) if len(dataset) <= avg_batch_size else avg_batch_size
159 |         self_training_loader = torch.utils.data.DataLoader(
160 |             dataset, batch_size=batch_size,
161 |             shuffle=True,  # 'True' to check training progress with validation function.
162 |             num_workers=int(opt.workers), pin_memory=False, collate_fn=self_training_collate)
163 |         if self.has_pseudo_label_dataset():
164 |             self.data_loader_list[opt.source_num] = self_training_loader
165 |             self.dataloader_iter_list[opt.source_num] = (iter(self_training_loader))
166 |         else:
167 |             self.data_loader_list.append(self_training_loader)
168 |             self.dataloader_iter_list.append(iter(self_training_loader))
169 | 
170 |     def add_residual_pseudo_label_dataset(self, dataset, opt):
171 |         avg_batch_size = opt.batch_size // opt.source_num
172 |         batch_size = len(dataset) if len(dataset) <= avg_batch_size else avg_batch_size
173 |         self_training_loader = torch.utils.data.DataLoader(
174 |             dataset, batch_size=batch_size,
175 |             shuffle=True,  # 'True' to check training progress with validation function.
176 |             num_workers=int(opt.workers), pin_memory=False, collate_fn=self_training_collate)
177 |         if self.has_residual_pseudo_label_dataset():
178 |             self.data_loader_list[opt.source_num + 1] = self_training_loader
179 |             self.dataloader_iter_list[opt.source_num + 1] = (iter(self_training_loader))
180 |         else:
181 |             self.data_loader_list.append(self_training_loader)
182 |             self.dataloader_iter_list.append(iter(self_training_loader))
183 | 
184 |     def has_pseudo_label_dataset(self):
185 |         return True if len(self.data_loader_list) > self.opt.source_num else False
186 | 
187 |     def has_residual_pseudo_label_dataset(self):
188 |         return True if len(self.data_loader_list) > self.opt.source_num + 1 else False
189 | 
190 | class Batch_Balanced_Sampler(object):
191 |     def __init__(self, dataset_len, batch_size):
192 |         dataset_len.insert(0, 0)
193 |         self.dataset_len = dataset_len
194 |         self.start_index = list(itertools.accumulate(self.dataset_len))[:-1]
195 |         self.batch_size = batch_size  # batch size for each sub-dataset
196 |         self.counter = 0
197 | 
198 |     def __len__(self):
199 |         return sum(self.dataset_len)  # total number of samples across all sub-datasets
200 | 
201 |     def __iter__(self):
202 |         data_index = []
203 |         while True:
204 |             for i in range(len(self.start_index)):
205 |                 data_index.extend([self.start_index[i] + (self.counter * self.batch_size + j) % self.dataset_len[i + 1] for j in range(self.batch_size)])
206 |             yield data_index
207 |             data_index = []
208 |             self.counter += 1
209 | 
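# Sketch of the indices Batch_Balanced_Sampler yields (the numbers are illustrative): with
# dataset_len = [100, 200] and batch_size = 2, start_index becomes [0, 100]; the first yielded
# batch is [0, 1, 100, 101], the second [2, 3, 102, 103], and each per-source block wraps around
# modulo its own dataset length, so the concatenated dataset is sampled in a balanced
# round-robin fashion.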
210 | 
211 | class Batch_Balanced_Dataset0(object):
212 | 
213 |     def __init__(self, opt):
214 |         """
215 |         Modulate the data ratio in the batch.
216 |         For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5",
217 |         the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST.
218 | 
219 |         opt.batch_ratio: the proportion of each dataset within a single batch
220 |         opt.total_data_usage_ratio: for each dataset, the fraction of that dataset to use; the default is 1 (100%)
221 |         """
222 |         self.opt = opt
223 |         log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a')
224 |         dashed_line = '-' * 80
225 |         print(dashed_line)
226 |         log.write(dashed_line + '\n')
227 |         print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}')
228 |         log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n')
229 |         assert len(opt.select_data) == len(opt.batch_ratio)
230 | 
231 |         # apply the collate function to every dataloader so that it yields a whole batch directly
232 |         _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
233 |         self.data_loader_list = []
234 |         self.dataloader_iter_list = []
235 |         self.batch_size_list = []
236 |         Total_batch_size = 0
237 | 
238 |         self.dataset_list = []
239 |         self.dataset_len_list = []
240 | 
241 |         self.pseudo_dataloader = None
242 |         self.pseudo_batch_size = -1
243 | 
244 |         for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio):
245 |             _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1)
246 |             print(dashed_line)
247 |             log.write(dashed_line + '\n')
248 |             _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d])
249 |             total_number_dataset = len(_dataset)  # number of images in the current dataset
250 | 
251 | 
252 |             log.write(_dataset_log)
253 | 
254 |             """
255 |             The total number of data can be modified with opt.total_data_usage_ratio.
256 |             ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage.
257 |             See 4.2 section in our paper.
258 |             """
259 |             number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio))  # fraction of the dataset actually used
260 |             if opt.fix_dataset_num != -1: number_dataset = opt.fix_dataset_num
261 |             dataset_split = [number_dataset, total_number_dataset - number_dataset]  # List[int] e.g. [50, 50]
262 |             indices = range(total_number_dataset)
263 | 
264 |             # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
265 |             # Subset takes a subset of the dataset according to indices; the indices are chosen based on opt.total_data_usage_ratio
266 |             _dataset, _ = [Subset(_dataset, indices[offset - length:offset])
267 |                            for offset, length in zip(_accumulate(dataset_split), dataset_split)]
268 |             selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n'
269 |             selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}'
270 |             print(selected_d_log)
271 |             log.write(selected_d_log + '\n')
272 |             self.batch_size_list.append(str(_batch_size))
273 |             Total_batch_size += _batch_size
274 | 
275 |             self.dataset_list.append(_dataset)
276 |             self.dataset_len_list.append(number_dataset)
277 | 
278 | 
279 | 
280 |         concatenated_dataset = ConcatDataset(self.dataset_list)
281 |         assert len(concatenated_dataset) == sum(self.dataset_len_list)
282 | 
283 |         batch_sampler = Batch_Balanced_Sampler(self.dataset_len_list, _batch_size)
284 |         self.data_loader = iter(torch.utils.data.DataLoader(
285 |             concatenated_dataset,
286 |             batch_sampler=batch_sampler,
287 |             num_workers=int(opt.workers),
288 |             collate_fn=_AlignCollate, pin_memory=False))
289 | 
290 |         Total_batch_size_log = f'{dashed_line}\n'
291 |         batch_size_sum = '+'.join(self.batch_size_list)
292 |         self.batch_size_list = list(map(int, self.batch_size_list))
293 |         Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n'
294 |         Total_batch_size_log += f'{dashed_line}'
295 |         opt.batch_size = Total_batch_size
296 | 
297 |         print(Total_batch_size_log)
298 |         log.write(Total_batch_size_log + '\n')
299 |         log.close()
300 | 
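    # Sketch of the split performed above (the numbers are illustrative): with
    # total_number_dataset = 100 and total_data_usage_ratio = 0.5, dataset_split = [50, 50],
    # _accumulate(dataset_split) -> 50, 100, and Subset(_dataset, indices[0:50]) keeps the first
    # 50 samples while the rest are discarded. Unlike Batch_Balanced_Dataset above, this variant
    # concatenates every source into one ConcatDataset and lets Batch_Balanced_Sampler emit a
    # combined index list, which get_batch() below slices back into per-source ranges.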
301 |     def get_batch(self, meta_target_index=-1, no_pseudo=False):  # if meta_target_index is given, skip the meta_target_index-th dataset
302 | 
303 |         imgs, texts = next(self.data_loader)
304 |         # if no index is specified, or the pseudo-label dataset itself is specified, return everything
305 |         if meta_target_index == -1 or meta_target_index >= len(self.batch_size_list): return imgs, texts
306 |         start_index_list = list(itertools.accumulate(self.batch_size_list))
307 |         start_index_list.insert(0, 0)
308 | 
309 |         ret_imgs, ret_texts = [], []
310 |         for i in range(len(self.batch_size_list)):
311 |             if i == meta_target_index: continue
312 |             ret_imgs.extend(imgs[start_index_list[i] : start_index_list[i] + self.batch_size_list[i]])
313 |             ret_texts.extend(texts[start_index_list[i] : start_index_list[i] + self.batch_size_list[i]])
314 |         ret_imgs = torch.stack(ret_imgs, 0)
315 | 
316 |         # assert self.has_pseudo_label_dataset() == True, 'Pseudo label dataset can\'t be empty'
317 |         if self.has_pseudo_label_dataset():
318 |             try:
319 |                 pseudo_imgs, pseudo_texts = next(self.pseudo_dataloader_iter)
320 |             except StopIteration:
321 |                 self.pseudo_dataloader_iter = iter(self.pseudo_dataloader)
322 |                 pseudo_imgs, pseudo_texts = next(self.pseudo_dataloader_iter)
323 |             ret_imgs = torch.cat([ret_imgs, pseudo_imgs], 0)
324 |             ret_texts += pseudo_texts
325 | 
326 |         return ret_imgs, ret_texts
327 | 
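    # Sketch of the slicing above (the numbers are illustrative): with
    # batch_size_list = [64, 64, 64], start_index_list becomes [0, 64, 128, 192]; calling
    # get_batch(meta_target_index=1) keeps imgs[0:64] and imgs[128:192] and, if a pseudo-label
    # loader has been added, concatenates one pseudo-label batch onto the result.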
328 |     def get_meta_test_batch(self, meta_target_index=-1):  # only the meta_target_index-th dataset is sampled
329 | 
330 |         assert meta_target_index != -1, 'Meta target index should be specified'
331 |         if meta_target_index >= len(self.batch_size_list) and self.has_pseudo_label_dataset():
332 |             try:
333 |                 img, text = next(self.pseudo_dataloader_iter)
334 |             except StopIteration:
335 |                 self.pseudo_dataloader_iter = iter(self.pseudo_dataloader)
336 |                 img, text = next(self.pseudo_dataloader_iter)
337 | 
338 |             return img, text
339 | 
340 |         imgs, texts = next(self.data_loader)
341 |         start_index_list = list(itertools.accumulate(self.batch_size_list))
342 |         start_index_list.insert(0, 0)
343 |         ret_img, ret_text = None, None
344 |         for i in range(len(self.batch_size_list)):
345 |             if i == meta_target_index:
346 |                 ret_img = imgs[start_index_list[i]:start_index_list[i] + self.batch_size_list[i]]
347 |                 ret_text = texts[start_index_list[i]:start_index_list[i] + self.batch_size_list[i]]
348 | 
349 |         return ret_img, ret_text
350 | 
351 |     def add_pseudo_label_dataset(self, dataset, opt):
352 |         avg_batch_size = opt.batch_size // opt.source_num
353 |         batch_size = len(dataset) if len(dataset) <= avg_batch_size else avg_batch_size
354 |         self.pseudo_batch_size = batch_size
355 |         self.pseudo_dataloader = torch.utils.data.DataLoader(
356 |             dataset, batch_size=batch_size,
357 |             shuffle=True,  # 'True' to check training progress with validation function.
358 |             num_workers=int(opt.workers), pin_memory=False, collate_fn=self_training_collate)
359 |         self.pseudo_dataloader_iter = iter(self.pseudo_dataloader)
360 | 
361 | 
362 |     def has_pseudo_label_dataset(self):
363 |         return True if self.pseudo_dataloader else False
364 | 
365 | 
366 | def hierarchical_dataset(root, opt, select_data='/', pseudo=False):
367 |     """ select_data='/' contains all sub-directory of root directory """
368 |     dataset_list = []
369 |     dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}'
370 |     print(dataset_log)
371 |     dataset_log += '\n'
372 |     for dirpath, dirnames, filenames in os.walk(root+'/', followlinks=True):
373 |         print(dirpath, dirnames, filenames)
374 |         if not dirnames:  # when dirnames is empty, i.e. the current dirpath contains only (lmdb) files, process it
375 |             select_flag = False
376 |             for selected_d in select_data:  # select_data holds strings, e.g. 'MJ', 'ST'
377 |                 if selected_d in dirpath:  # if dirpath contains select_data, this directory is a target directory, so set select_flag to True
378 |                     select_flag = True
379 |                     break
380 | 
381 |             if select_flag:
382 |                 dataset = LmdbDataset(dirpath, opt, pseudo=pseudo)
383 |                 sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}'
384 |                 print(sub_dataset_log)
385 |                 dataset_log += f'{sub_dataset_log}\n'
386 |                 dataset_list.append(dataset)
387 | 
388 |     # concatenate all datasets; taking MJ as an example, dataset_list contains MJ_train, MJ_valid and MJ_test
389 |     concatenated_dataset = ConcatDataset(dataset_list)
390 | 
391 |     return concatenated_dataset, dataset_log
392 | 
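# Sketch of how hierarchical_dataset() selects LMDB folders (the paths are illustrative):
# given root = 'data/train' containing leaf folders 'data/train/syn/MJ' and 'data/train/real/ST',
# hierarchical_dataset(root, opt, select_data=['MJ']) walks the tree, keeps only leaf directories
# whose path contains 'MJ', wraps each one in an LmdbDataset, and returns their ConcatDataset
# together with a log string.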
393 | 
394 | class LmdbDataset(Dataset):
395 | 
396 |     def __init__(self, root, opt, pseudo=False):
397 | 
398 |         self.root = root
399 |         self.opt = opt
400 |         self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
401 |         if not self.env:
402 |             print('cannot create lmdb from %s' % (root))
403 |             sys.exit(0)
404 | 
405 |         with self.env.begin(write=False) as txn:
406 |             nSamples = int(txn.get('num-samples'.encode()))
407 |             self.nSamples = nSamples
408 | 
409 |             if self.opt.data_filtering_off:
410 |                 # for fast check or benchmark evaluation with no filtering
411 |                 self.filtered_index_list = [index + 1 for index in range(self.nSamples)]
412 |             else:
413 |                 """ Filtering part
414 |                 If you want to evaluate IC15-2077 & CUTE datasets which have special character labels,
415 |                 use --data_filtering_off and only evaluate on alphabets and digits.
416 |                 see https://github.com/clovaai/deep-text-recognition-benchmark/blob/6593928855fb7abb999a99f428b3e4477d4ae356/dataset.py#L190-L192
417 | 
418 |                 And if you want to evaluate them with the model trained with --sensitive option,
419 |                 use --sensitive and --data_filtering_off,
420 |                 see https://github.com/clovaai/deep-text-recognition-benchmark/blob/dff844874dbe9e0ec8c5a52a7bd08c7f20afe704/test.py#L137-L144
421 |                 """
422 |                 self.filtered_index_list = []
423 |                 for index in range(self.nSamples):
424 |                     if self.opt.pseudo_dataset_num != -1 and pseudo and index > self.opt.pseudo_dataset_num:
425 |                         break
426 |                     index += 1  # lmdb starts with 1
427 |                     label_key = 'label-%09d'.encode() % index
428 |                     label = txn.get(label_key).decode('utf-8')
429 |                     # print(label)
430 | 
431 |                     if len(label) > self.opt.batch_max_length:
432 |                         # print(f'The length of the label is longer than max_length: length
433 |                         # {len(label)}, {label} in dataset {self.root}')
434 |                         continue
435 | 
436 |                     # By default, images containing characters which are not in opt.character are filtered.
437 |                     # You can add [UNK] token to `opt.character` in utils.py instead of this filtering.
438 |                     out_of_char = f'[^{self.opt.character}]'
439 |                     # if re.search(out_of_char, label.lower()):  # changed for the license-plate scenario: plates contain only uppercase letters, and since opt.character holds no lowercase letters, calling lower() would filter out every plate
440 |                     if re.search(out_of_char, label):
441 |                         continue
442 | 
443 |                     self.filtered_index_list.append(index)
444 | 
445 |                 self.nSamples = len(self.filtered_index_list)
446 | 
447 |     def __len__(self):
448 |         return self.nSamples
449 | 
450 |     def __getitem__(self, index):
451 |         assert index <= len(self), 'index range error'
452 |         index = self.filtered_index_list[index]
453 | 
454 |         with self.env.begin(write=False) as txn:
455 |             label_key = 'label-%09d'.encode() % index
456 |             label = txn.get(label_key).decode('utf-8')
457 |             img_key = 'image-%09d'.encode() % index
458 |             imgbuf = txn.get(img_key)
459 | 
460 |         buf = six.BytesIO()
461 |         buf.write(imgbuf)
462 |         buf.seek(0)
463 |         try:
464 |             if self.opt.rgb:
465 |                 img = Image.open(buf).convert('RGB')  # for color image
466 |             else:
467 |                 img = Image.open(buf).convert('L')
468 | 
469 |         except IOError:
470 |             print(f'Corrupted image for {index}')
471 |             # make dummy image and dummy label for corrupted image.
472 |             if self.opt.rgb:
473 |                 img = Image.new('RGB', (self.opt.imgW, self.opt.imgH))
474 |             else:
475 |                 img = Image.new('L', (self.opt.imgW, self.opt.imgH))
476 |             label = '[dummy_label]'
477 | 
478 |         # if not self.opt.sensitive:
479 |         #     label = label.lower()
480 | 
481 |         # We only train and evaluate on alphanumerics (or pre-defined character set in train.py)
482 |         out_of_char = f'[^{self.opt.character}]'
483 |         label = re.sub(out_of_char, '', label)
484 | 
485 |         return (img, label)
486 | 
487 | 
488 | class RawDataset(Dataset):
489 | 
490 |     def __init__(self, root, opt):
491 |         self.opt = opt
492 |         self.image_path_list = []
493 |         for dirpath, dirnames, filenames in os.walk(root):
494 |             for name in filenames:
495 |                 _, ext = os.path.splitext(name)
496 |                 ext = ext.lower()
497 |                 if ext == '.jpg' or ext == '.jpeg' or ext == '.png':
498 |                     self.image_path_list.append(os.path.join(dirpath, name))
499 | 
500 |         self.image_path_list = natsorted(self.image_path_list)
501 |         self.nSamples = len(self.image_path_list)
502 | 
503 |     def __len__(self):
504 |         return self.nSamples
505 | 
506 |     def __getitem__(self, index):
507 | 
508 |         try:
509 |             if self.opt.rgb:
510 |                 img = Image.open(self.image_path_list[index]).convert('RGB')  # for color image
511 |             else:
512 |                 img = Image.open(self.image_path_list[index]).convert('L')
513 | 
514 |         except IOError:
515 |             print(f'Corrupted image for {index}')
516 |             # make dummy image and dummy label for corrupted image.
517 |             if self.opt.rgb:
518 |                 img = Image.new('RGB', (self.opt.imgW, self.opt.imgH))
519 |             else:
520 |                 img = Image.new('L', (self.opt.imgW, self.opt.imgH))
521 | 
522 |         return (img, self.image_path_list[index])
523 | 
524 | 
525 | class ResizeNormalize(object):
526 | 
527 |     def __init__(self, size, interpolation=Image.BICUBIC):
528 |         self.size = size
529 |         self.interpolation = interpolation
530 |         self.toTensor = transforms.ToTensor()
531 | 
532 |     def __call__(self, img):
533 |         img = img.resize(self.size, self.interpolation)
534 |         img = self.toTensor(img)
535 |         img.sub_(0.5).div_(0.5)
536 |         return img
537 | 
538 | 
539 | class NormalizePAD(object):
540 | 
541 |     def __init__(self, max_size, PAD_type='right'):
542 |         self.toTensor = transforms.ToTensor()
543 |         self.max_size = max_size
544 |         self.max_width_half = math.floor(max_size[2] / 2)
545 |         self.PAD_type = PAD_type
546 | 
547 |     def __call__(self, img):
548 |         img = self.toTensor(img)
549 |         img.sub_(0.5).div_(0.5)
550 |         c, h, w = img.size()
551 |         Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
552 |         Pad_img[:, :, :w] = img  # right pad
553 |         if self.max_size[2] != w:  # add border Pad
554 |             Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)
555 | 
556 |         return Pad_img
557 | 
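# Sketch of what the two transforms above produce (the sizes are illustrative):
# ResizeNormalize((100, 32)) resizes a PIL image to 100x32 and maps pixel values into [-1, 1];
# NormalizePAD((1, 32, 100)) keeps the original width instead, normalises to [-1, 1], and
# right-pads up to width 100 by repeating the last image column.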
558 | 
559 | class AlignCollate(object):
560 | 
561 |     def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False):
562 |         self.imgH = imgH
563 |         self.imgW = imgW
564 |         self.keep_ratio_with_pad = keep_ratio_with_pad
565 | 
566 |     def __call__(self, batch):
567 |         batch = filter(lambda x: x is not None, batch)
568 |         images, labels = zip(*batch)
569 | 
570 |         if self.keep_ratio_with_pad:  # same concept with 'Rosetta' paper
571 |             resized_max_w = self.imgW
572 |             input_channel = 3 if images[0].mode == 'RGB' else 1
573 |             transform = NormalizePAD((input_channel, self.imgH, resized_max_w))
574 | 
575 |             resized_images = []
576 |             for image in images:
577 |                 w, h = image.size
578 |                 ratio = w / float(h)
579 |                 if math.ceil(self.imgH * ratio) > self.imgW:
580 |                     resized_w = self.imgW
581 |                 else:
582 |                     resized_w = math.ceil(self.imgH * ratio)
583 | 
584 |                 resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
585 |                 resized_images.append(transform(resized_image))
586 |                 # resized_image.save('./image_test/%d_test.jpg' % w)
587 | 
588 |             image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)
589 | 
590 |         else:
591 |             transform = ResizeNormalize((self.imgW, self.imgH))
592 |             image_tensors = [transform(image) for image in images]
593 |             image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)
594 | 
595 |         return image_tensors, labels
596 | 
597 | 
598 | def self_training_collate(batch):
599 |     imgs, labels = [], []
600 |     for img, label in batch:
601 |         imgs.append(img)
602 |         labels.append(label)
603 | 
604 |     return torch.stack(imgs), labels
605 | 
606 | class SelfTrainingDataset(Dataset):
607 |     def __init__(self, imgs, labels):
608 |         self.imgs = imgs
609 |         self.labels = labels
610 | 
611 |     def __getitem__(self, index):
612 |         return self.imgs[index], self.labels[index]
613 | 
614 |     def __len__(self):
615 |         assert len(self.imgs) == len(self.labels)
616 |         return len(self.imgs)
617 | 
618 | 
619 | 
620 | def tensor2im(image_tensor, imtype=np.uint8):
621 |     image_numpy = image_tensor.cpu().float().numpy()
622 |     if image_numpy.shape[0] == 1:
623 |         image_numpy = np.tile(image_numpy, (3, 1, 1))
624 |     image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
625 |     return image_numpy.astype(imtype)
626 | 
627 | 
628 | def save_image(image_numpy, image_path):
629 |     image_pil = Image.fromarray(image_numpy)
630 |     image_pil.save(image_path)
631 | --------------------------------------------------------------------------------
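A minimal, self-contained usage sketch of the self-training helpers defined above (illustrative only; the dummy tensors and labels are assumptions, and the import assumes this file is importable as dataset.py):

import torch
from torch.utils.data import DataLoader
from dataset import SelfTrainingDataset, self_training_collate

# four dummy grayscale crops standing in for pseudo-labelled images produced during self-training
imgs = [torch.zeros(1, 32, 100) for _ in range(4)]
labels = ['abc', 'de', 'fgh', 'ij']
loader = DataLoader(SelfTrainingDataset(imgs, labels), batch_size=2, collate_fn=self_training_collate)
batch_imgs, batch_labels = next(iter(loader))  # batch_imgs: [2, 1, 32, 100], batch_labels: list of 2 strings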