├── img ├── attention_vis_1.gif ├── attention_vis_2.gif ├── attention_vis_3.gif ├── attention_vis_4.gif └── model_framework.jpg ├── losses ├── coral.py ├── focal_loss.py ├── tpn_task.py ├── mmd.py └── ef_focal_loss.py ├── modules ├── sequence_modeling.py ├── domain_adapt.py ├── initialization.py ├── model_utils.py ├── prediction.py ├── transformation.py ├── radam.py ├── densenet.py └── feature_extraction.py ├── create_lmdb_dataset.py ├── README.md ├── seqda_model.py ├── utils.py ├── test.py ├── dataset.py ├── train_da_coral.py ├── train_da_local.py └── train_da_global_local_selected.py /img/attention_vis_1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AprilYapingZhang/Seq2SeqAdapt/HEAD/img/attention_vis_1.gif -------------------------------------------------------------------------------- /img/attention_vis_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AprilYapingZhang/Seq2SeqAdapt/HEAD/img/attention_vis_2.gif -------------------------------------------------------------------------------- /img/attention_vis_3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AprilYapingZhang/Seq2SeqAdapt/HEAD/img/attention_vis_3.gif -------------------------------------------------------------------------------- /img/attention_vis_4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AprilYapingZhang/Seq2SeqAdapt/HEAD/img/attention_vis_4.gif -------------------------------------------------------------------------------- /img/model_framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AprilYapingZhang/Seq2SeqAdapt/HEAD/img/model_framework.jpg -------------------------------------------------------------------------------- /losses/coral.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def CORAL(source, target): 4 | d1 = source.data.shape[1] 5 | d2 = target.data.shape[1] 6 | 7 | # source covariance 8 | xm = torch.mean(source, 0, keepdim=True) - source 9 | xc = xm.t() @ xm 10 | 11 | # target covariance 12 | xmt = torch.mean(target, 0, keepdim=True) - target 13 | xct = xmt.t() @ xmt 14 | 15 | # frobenius norm between source and target 16 | loss = torch.mean(torch.mul((xc - xct), (xc - xct))) 17 | loss = loss/(4*d1*d2) 18 | 19 | return loss -------------------------------------------------------------------------------- /modules/sequence_modeling.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BidirectionalLSTM(nn.Module): 5 | 6 | def __init__(self, input_size, hidden_size, output_size): 7 | super(BidirectionalLSTM, self).__init__() 8 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) 9 | self.linear = nn.Linear(hidden_size * 2, output_size) 10 | 11 | def forward(self, input): 12 | """ 13 | input : visual feature [batch_size x T x input_size] 14 | output : contextual feature [batch_size x T x output_size] 15 | """ 16 | self.rnn.flatten_parameters() 17 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size) 18 | output = self.linear(recurrent) # batch_size x T x output_size 19 | return output 20 | 
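A minimal usage sketch tying the two modules above together: `CORAL` in `losses/coral.py` expects two 2-D feature matrices, so one simple option is to mean-pool the `BidirectionalLSTM` outputs over the time dimension before aligning source- and target-domain statistics. The import paths, batch shapes, and the pooling choice are illustrative assumptions; the actual wiring used for training lives in `train_da_coral.py`.

```
# Illustrative only: shapes, pooling, and import paths are assumptions.
import torch
from losses.coral import CORAL
from modules.sequence_modeling import BidirectionalLSTM

rnn = BidirectionalLSTM(input_size=512, hidden_size=256, output_size=256)

src_visual = torch.randn(8, 26, 512)     # source batch  [batch_size x T x input_size]
tgt_visual = torch.randn(8, 26, 512)     # target batch, same layout

src_feat = rnn(src_visual).mean(dim=1)   # pool over T -> [batch_size x 256]
tgt_feat = rnn(tgt_visual).mean(dim=1)

domain_loss = CORAL(src_feat, tgt_feat)  # scalar second-order statistics alignment
```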
-------------------------------------------------------------------------------- /losses/focal_loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: 3 | https://github.com/clcarwin/focal_loss_pytorch/blob/master/focalloss.py 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.autograd import Variable 10 | 11 | class FocalLoss(nn.Module): 12 | def __init__(self, gamma=0, alpha=None, size_average=True): 13 | super(FocalLoss, self).__init__() 14 | self.gamma = gamma 15 | self.alpha = alpha 16 | if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha]) 17 | if isinstance(alpha,list): self.alpha = torch.Tensor(alpha) 18 | self.size_average = size_average 19 | 20 | def forward(self, input, target): 21 | if input.dim()>2: 22 | input = input.view(input.size(0),input.size(1),-1) # N,C,H,W => N,C,H*W 23 | input = input.transpose(1,2) # N,C,H*W => N,H*W,C 24 | input = input.contiguous().view(-1,input.size(2)) # N,H*W,C => N*H*W,C 25 | target = target.view(-1,1) 26 | 27 | logpt = F.log_softmax(input) 28 | logpt = logpt.gather(1,target) 29 | logpt = logpt.view(-1) 30 | pt = Variable(logpt.data.exp()) 31 | 32 | if self.alpha is not None: 33 | if self.alpha.type()!=input.data.type(): 34 | self.alpha = self.alpha.type_as(input.data) 35 | at = self.alpha.gather(0,target.data.view(-1)) 36 | logpt = logpt * Variable(at) 37 | 38 | loss = -1 * (1-pt)**self.gamma * logpt 39 | if self.size_average: return loss.mean() 40 | else: return loss.sum() -------------------------------------------------------------------------------- /modules/domain_adapt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Function 5 | import numpy as np 6 | 7 | 8 | def flatten(x): 9 | N = list(x.size())[0] 10 | # print('dim 0', N, 1024*19*37) 11 | return x.view(N, -1) 12 | 13 | 14 | def grad_reverse(x, beta): 15 | return GradReverse(beta)(x) 16 | 17 | 18 | class GradReverse(Function): 19 | def __init__(self, beta): 20 | super(GradReverse, self).__init__() 21 | self.beta = beta 22 | 23 | def set_beta(self, beta): 24 | self.beta = beta 25 | 26 | def forward(self, x): 27 | return x.view_as(x) 28 | 29 | def backward(self, grad_output): 30 | return (grad_output * (-1 * self.beta)) 31 | 32 | 33 | # pool_feat dim: N x 2048, where N may be 300. 
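# GradReverse above is an identity map in the forward pass and multiplies the
# incoming gradient by -beta in the backward pass, so features routed through it
# are trained to confuse the domain classifier defined below.
# Usage sketch (shapes illustrative, matching the comment above):
#   D = d_cls_inst(beta=1.0, fc_size=2048)
#   domain_logits = D(flatten(pool_feat))   # pool_feat: N x 2048 -> logits: N x 2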
34 | 35 | class d_cls_inst(nn.Module): 36 | def __init__(self, beta=1, fc_size=2048): 37 | super(d_cls_inst, self).__init__() 38 | self.fc_1_inst = nn.Linear(fc_size, 100) 39 | self.fc_2_inst = nn.Linear(100, 2) 40 | self.relu = nn.ReLU(inplace=True) 41 | self.beta = beta 42 | # self.softmax = nn.Softmax() 43 | # self.logsoftmax = nn.LogSoftmax() 44 | self.bn = nn.BatchNorm1d(2) 45 | 46 | def forward(self, x): 47 | x = grad_reverse(x, self.beta) 48 | x = self.relu(self.fc_1_inst(x)) 49 | x = self.relu(self.bn(self.fc_2_inst(x))) 50 | # y = self.softmax(x) 51 | # x = self.logsoftmax(x) 52 | # return x, y 53 | return x 54 | 55 | def set_beta(self, beta): 56 | self.beta = beta 57 | -------------------------------------------------------------------------------- /modules/initialization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | def weights_init_cpm(m): 6 | classname = m.__class__.__name__ 7 | # print(classname) 8 | if classname.find('Conv') != -1: 9 | m.weight.data.normal_(0, 0.01) 10 | if m.bias is not None: m.bias.data.zero_() 11 | elif classname.find('BatchNorm2d') != -1: 12 | m.weight.data.fill_(1) 13 | m.bias.data.zero_() 14 | 15 | def weights_init_normal(m): 16 | classname = m.__class__.__name__ 17 | # print(classname) 18 | if classname.find('Conv') != -1: 19 | init.uniform(m.weight.data, 0.0, 0.02) 20 | elif classname.find('Linear') != -1: 21 | init.uniform(m.weight.data, 0.0, 0.02) 22 | elif classname.find('BatchNorm2d') != -1: 23 | init.uniform(m.weight.data, 1.0, 0.02) 24 | init.constant(m.bias.data, 0.0) 25 | 26 | 27 | def weights_init_xavier(m): 28 | classname = m.__class__.__name__ 29 | # print(classname) 30 | if classname.find('Conv') != -1: 31 | init.xavier_normal(m.weight.data, gain=1) 32 | elif classname.find('Linear') != -1: 33 | init.xavier_normal(m.weight.data, gain=1) 34 | elif classname.find('BatchNorm2d') != -1: 35 | init.uniform(m.weight.data, 1.0, 0.02) 36 | init.constant(m.bias.data, 0.0) 37 | 38 | 39 | def weights_init_kaiming(m): 40 | classname = m.__class__.__name__ 41 | # print(classname) 42 | if classname.find('Conv') != -1: 43 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 44 | elif classname.find('Linear') != -1: 45 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 46 | elif classname.find('BatchNorm2d') != -1: 47 | init.uniform(m.weight.data, 1.0, 0.02) 48 | init.constant(m.bias.data, 0.0) 49 | 50 | 51 | def weights_init_orthogonal(m): 52 | classname = m.__class__.__name__ 53 | print(classname) 54 | if classname.find('Conv') != -1: 55 | init.orthogonal(m.weight.data, gain=1) 56 | elif classname.find('Linear') != -1: 57 | init.orthogonal(m.weight.data, gain=1) 58 | elif classname.find('BatchNorm2d') != -1: 59 | init.uniform(m.weight.data, 1.0, 0.02) 60 | init.constant(m.bias.data, 0.0) -------------------------------------------------------------------------------- /losses/tpn_task.py: -------------------------------------------------------------------------------- 1 | # pytorch implementation for Transferrable Prototypical Networks for Unsupervised Domain Adaptation 2 | # Sample-level discrepancy loss in Section 3.4 Task-specific Domain Adaptation 3 | # https://arxiv.org/pdf/1904.11227.pdf 4 | 5 | import torch 6 | import torch.functional as F 7 | from torch import nn 8 | 9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 10 | 11 | 12 | class TpnTaskLoss(nn.Module): 13 | def __init__(self): 14 | 
super(TpnTaskLoss, self).__init__() 15 | 16 | def forward(self, src_feat, trg_feat, src_label, trg_label): 17 | labels = list(src_label.data.cpu().numpy()) 18 | labels = list(set(labels)) 19 | 20 | dim = src_feat.size(1) 21 | center_num = len(labels) 22 | 23 | u_s = torch.zeros(center_num, dim).to(device) 24 | u_t = torch.zeros(center_num, dim).to(device) 25 | u_st = torch.zeros(center_num, dim).to(device) 26 | 27 | for i, l in enumerate(labels): 28 | s_feat = src_feat[src_label == l] 29 | t_feat = trg_feat[trg_label == l] 30 | 31 | u_s[i, :] = s_feat.mean(dim=0) 32 | u_t[i, :] = t_feat.mean(dim=0) 33 | u_st[i, :] = (s_feat.sum(dim=0) + t_feat.sum(dim=0)) / (s_feat.size(0) + t_feat.size(0)) 34 | 35 | feats = torch.cat((src_feat, trg_feat), dim=0) 36 | p_s = torch.matmul(feats, u_s.t()) 37 | p_t = torch.matmul(feats, u_t.t()) 38 | p_st = torch.matmul(feats, u_st.t()) 39 | 40 | loss_st = (F.kl_div(F.log_softmax(p_s, dim=-1), F.log_softmax(p_t, dim=-1), 41 | reduction='mean') + 42 | F.kl_div(F.log_softmax(p_t, dim=-1), F.log_softmax(p_s, dim=-1), 43 | reduction='mean')) / 2 44 | loss_sst = (F.kl_div(F.log_softmax(p_s, dim=-1), F.log_softmax(p_st, dim=-1), 45 | reduction='mean') + 46 | F.kl_div(F.log_softmax(p_st, dim=-1), F.log_softmax(p_s, dim=-1), 47 | reduction='mean')) / 2 48 | loss_tst = (F.kl_div(F.log_softmax(p_t, dim=-1), F.log_softmax(p_st, dim=-1), 49 | reduction='mean') + 50 | F.kl_div(F.log_softmax(p_st, dim=-1), F.log_softmax(p_t, dim=-1), 51 | reduction='mean')) / 2 52 | tpn_task = (loss_st + loss_sst + loss_tst) / 3 53 | return tpn_task, ('04. tpn_task loss: ', tpn_task.data.cpu().numpy()) 54 | -------------------------------------------------------------------------------- /create_lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | """ a modified version of CRNN torch repository https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """ 2 | 3 | import fire 4 | import os 5 | import lmdb 6 | import cv2 7 | 8 | import numpy as np 9 | 10 | 11 | def checkImageIsValid(imageBin): 12 | if imageBin is None: 13 | return False 14 | imageBuf = np.frombuffer(imageBin, dtype=np.uint8) 15 | img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) 16 | imgH, imgW = img.shape[0], img.shape[1] 17 | if imgH * imgW == 0: 18 | return False 19 | return True 20 | 21 | 22 | def writeCache(env, cache): 23 | with env.begin(write=True) as txn: 24 | for k, v in cache.items(): 25 | txn.put(k, v) 26 | 27 | 28 | def createDataset(inputPath, gtFile, outputPath, checkValid=True): 29 | """ 30 | Create LMDB dataset for training and evaluation. 
31 | ARGS: 32 | inputPath : input folder path where starts imagePath 33 | outputPath : LMDB output path 34 | gtFile : list of image path and label 35 | checkValid : if true, check the validity of every image 36 | """ 37 | os.makedirs(outputPath, exist_ok=True) 38 | env = lmdb.open(outputPath, map_size=1099511627776) 39 | cache = {} 40 | cnt = 1 41 | 42 | with open(gtFile, 'r', encoding='utf-8') as data: 43 | datalist = data.readlines() 44 | 45 | nSamples = len(datalist) 46 | for i in range(nSamples): 47 | line_info = datalist[i].strip('\n').split('\t') 48 | if len(line_info) == 2: 49 | imagePath, label = line_info 50 | elif len(line_info) > 2: 51 | imagePath, label = line_info[:2] 52 | imagePath = os.path.join(inputPath, imagePath) 53 | 54 | # # only use alphanumeric data 55 | # if re.search('[^a-zA-Z0-9]', label): 56 | # continue 57 | 58 | if not os.path.exists(imagePath): 59 | print('%s does not exist' % imagePath) 60 | continue 61 | with open(imagePath, 'rb') as f: 62 | imageBin = f.read() 63 | if checkValid: 64 | try: 65 | if not checkImageIsValid(imageBin): 66 | print('%s is not a valid image' % imagePath) 67 | continue 68 | except: 69 | print('error occured', i) 70 | with open(outputPath + '/error_image_log.txt', 'a') as log: 71 | log.write('%s-th image data occured error\n' % str(i)) 72 | continue 73 | 74 | imageKey = 'image-%09d'.encode() % cnt 75 | labelKey = 'label-%09d'.encode() % cnt 76 | cache[imageKey] = imageBin 77 | cache[labelKey] = label.encode() 78 | 79 | if cnt % 1000 == 0: 80 | writeCache(env, cache) 81 | cache = {} 82 | print('Written %d / %d' % (cnt, nSamples)) 83 | cnt += 1 84 | nSamples = cnt-1 85 | cache['num-samples'.encode()] = str(nSamples).encode() 86 | writeCache(env, cache) 87 | print('Created dataset with %d samples' % nSamples) 88 | 89 | 90 | if __name__ == '__main__': 91 | fire.Fire(createDataset) 92 | -------------------------------------------------------------------------------- /modules/model_utils.py: -------------------------------------------------------------------------------- 1 | from scipy.ndimage.interpolation import zoom 2 | from collections import OrderedDict 3 | import utils 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import copy, numbers, numpy as np 8 | 9 | def np2variable(x, is_cuda=True, requires_grad=True, dtype=torch.FloatTensor): 10 | if isinstance(x, np.ndarray): 11 | v = torch.autograd.Variable(torch.from_numpy(x).type(dtype), requires_grad=requires_grad) 12 | elif isinstance(x, torch.FloatTensor): 13 | v = torch.autograd.Variable(x.type(dtype), requires_grad=requires_grad) 14 | else: 15 | raise Exception('Do not know this type : {}'.format( type(x) )) 16 | 17 | if is_cuda: return v.cuda() 18 | else: return v 19 | 20 | def variable2np(x): 21 | if x.is_cuda: 22 | x = x.cpu() 23 | if isinstance(x, torch.autograd.Variable): 24 | return x.data.numpy() 25 | else: 26 | return x.numpy() 27 | 28 | def get_parameters(model, bias): 29 | for m in model.modules(): 30 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): 31 | if bias: 32 | yield m.bias 33 | else: 34 | yield m.weight 35 | elif isinstance(m, nn.BatchNorm2d): 36 | if bias: 37 | yield m.bias 38 | else: 39 | yield m.weight 40 | 41 | def load_weight_from_dict(model, weight_state_dict, param_pair=None, remove_prefix=True): 42 | if remove_prefix: weight_state_dict = remove_module_dict(weight_state_dict) 43 | all_parameter = model.state_dict() 44 | all_weights = [] 45 | finetuned_layer, random_initial_layer = [], [] 46 | for key, value in 
all_parameter.items(): 47 | if param_pair is not None and key in param_pair: 48 | all_weights.append((key, weight_state_dict[ param_pair[key] ])) 49 | elif key in weight_state_dict: 50 | all_weights.append((key, weight_state_dict[key])) 51 | finetuned_layer.append(key) 52 | else: 53 | all_weights.append((key, value)) 54 | random_initial_layer.append(key) 55 | print ('==>[load_model] finetuned layers : {}'.format(finetuned_layer)) 56 | print ('==>[load_model] keeped layers : {}'.format(random_initial_layer)) 57 | all_weights = OrderedDict(all_weights) 58 | model.load_state_dict(all_weights) 59 | 60 | def remove_module_dict(state_dict): 61 | new_state_dict = OrderedDict() 62 | for k, v in state_dict.items(): 63 | name = k[7:] # remove `module.` 64 | new_state_dict[name] = v 65 | return new_state_dict 66 | 67 | def roi_pooling(input, rois, size=(7,7)): 68 | assert rois.dim() == 2 and rois.size(1) == 5, 'rois shape is wrong : {}'.format(rois.size()) 69 | output = [] 70 | num_rois = rois.size(0) 71 | size = np.array(size) 72 | spatial_size = np.array([input.size(3), input.size(2)]) 73 | for i in range(num_rois): 74 | roi = variable2np(rois[i]) 75 | im_idx = int(roi[0]) 76 | theta = utils.crop2affine(spatial_size, roi[1:]) 77 | theta = np2variable(theta, input.is_cuda).unsqueeze(0) 78 | grid_size = torch.Size([1, 3, int(size[1]), int(size[0])]) 79 | grid = F.affine_grid(theta, grid_size) 80 | roi_feature = F.grid_sample(input.narrow(0, im_idx, 1), grid) 81 | output.append( roi_feature ) 82 | return torch.cat(output, 0) 83 | 84 | def print_network(net, net_str, log): 85 | num_params = 0 86 | for param in net.parameters(): 87 | num_params += param.numel() 88 | utils.print_log(net, log) 89 | utils.print_log('Total number of parameters for {} is {}'.format(net_str, num_params), log) 90 | 91 | def count_network_param(net): 92 | num_params = 0 93 | for param in net.parameters(): 94 | num_params += param.numel() 95 | return num_params -------------------------------------------------------------------------------- /losses/mmd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from functools import partial 3 | from torch.autograd import Variable 4 | 5 | 6 | # Consider linear time MMD with a linear kernel: 7 | # K(f(x), f(y)) = f(x)^Tf(y) 8 | # h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i) 9 | # = [f(x_i) - f(y_i)]^T[f(x_j) - f(y_j)] 10 | # 11 | # f_of_X: batch_size * k 12 | # f_of_Y: batch_size * k 13 | def mmd_linear(f_of_X, f_of_Y): 14 | delta = f_of_X - f_of_Y 15 | loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1))) 16 | return loss 17 | 18 | 19 | def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 20 | n_samples = int(source.size()[0]) + int(target.size()[0]) 21 | total = torch.cat([source, target], dim=0) 22 | total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 23 | total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1))) 24 | L2_distance = ((total0 - total1) ** 2).sum(2) 25 | if fix_sigma: 26 | bandwidth = fix_sigma 27 | else: 28 | bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples) 29 | bandwidth /= kernel_mul ** (kernel_num // 2) 30 | bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)] 31 | kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list] 32 | return sum(kernel_val) # /len(kernel_val) 33 | 34 | 35 | def 
mmd_rbf_accelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 36 | batch_size = int(source.size()[0]) 37 | kernels = guassian_kernel(source, target, 38 | kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) 39 | loss = 0 40 | for i in range(batch_size): 41 | s1, s2 = i, (i + 1) % batch_size 42 | t1, t2 = s1 + batch_size, s2 + batch_size 43 | loss += kernels[s1, s2] + kernels[t1, t2] 44 | loss -= kernels[s1, t2] + kernels[s2, t1] 45 | return loss / float(batch_size) 46 | 47 | 48 | def mmd_rbf_noaccelerate(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None): 49 | batch_size = int(source.size()[0]) 50 | kernels = guassian_kernel(source, target, 51 | kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma) 52 | XX = kernels[:batch_size, :batch_size] 53 | YY = kernels[batch_size:, batch_size:] 54 | XY = kernels[:batch_size, batch_size:] 55 | YX = kernels[batch_size:, :batch_size] 56 | loss = torch.mean(XX + YY - XY - YX) 57 | return loss 58 | 59 | 60 | def pairwise_distance(x, y): 61 | if not len(x.shape) == len(y.shape) == 2: 62 | raise ValueError('Both inputs should be matrices.') 63 | 64 | if x.shape[1] != y.shape[1]: 65 | raise ValueError('The number of features should be the same.') 66 | 67 | x = x.view(x.shape[0], x.shape[1], 1) 68 | y = torch.transpose(y, 0, 1) 69 | output = torch.sum((x - y) ** 2, 1) 70 | output = torch.transpose(output, 0, 1) 71 | return output 72 | 73 | 74 | def gaussian_kernel_matrix(x, y, sigmas): 75 | sigmas = sigmas.view(sigmas.shape[0], 1) 76 | beta = 1. / (2. * sigmas) 77 | dist = pairwise_distance(x, y).contiguous() 78 | dist_ = dist.view(1, -1) 79 | s = torch.matmul(beta, dist_) 80 | return torch.sum(torch.exp(-s), 0).view_as(dist) 81 | 82 | 83 | def maximum_mean_discrepancy(x, y, kernel=gaussian_kernel_matrix): 84 | cost = torch.mean(kernel(x, x)) 85 | cost += torch.mean(kernel(y, y)) 86 | cost -= 2 * torch.mean(kernel(x, y)) 87 | return cost 88 | 89 | 90 | def mmd_loss(source_features, target_features): 91 | sigmas = [ 92 | 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100, 93 | 1e3, 1e4, 1e5, 1e6 94 | ] 95 | gaussian_kernel = partial( 96 | gaussian_kernel_matrix, sigmas=Variable(torch.cuda.FloatTensor(sigmas)) 97 | ) 98 | loss_value = maximum_mean_discrepancy(source_features, target_features, kernel=gaussian_kernel) 99 | loss_value = loss_value 100 | return loss_value 101 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Sequence-to-sequence Domain adaptation 2 | 3 | ## Overview 4 | We propose a novel Adversarial Sequence-to-sequence Domain Adaptation Network dubbed ASSDA for robust text image recognition, 5 | which could adaptively transfer coarse global-level and fine-grained character-level knowledge. 6 | 7 | 8 | 9 | 10 | ### Install 11 | 12 | 1. This code is test in the environment with ```cuda==10.1, python==3.6.8```. 13 | 14 | 2. Install Requirements 15 | 16 | ``` 17 | pip3 install torch==1.2.0 pillow==6.2.1 torchvision==0.4.0 lmdb nltk natsort 18 | ``` 19 | 20 | ### Dataset 21 | 22 | - The prepared synthetic and real scene dataset can be downloaded from [here](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt), which are created by NAVER Corp. 
23 | 24 | - Synthetic scene text : [MJSynth (MJ)](http://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText (ST)](http://www.robots.ox.ac.uk/~vgg/data/scenetext/) \ 25 | - Real scene text : the union of the training sets [IC13](http://rrc.cvc.uab.es/?ch=2), [IC15](http://rrc.cvc.uab.es/?ch=4), [IIIT](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html), and [SVT](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset).\ 26 | - Benchmark evaluation scene text datasets : consist of [IIIT](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html), [SVT](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset), [IC03](http://www.iapr-tc11.org/mediawiki/index.php/ICDAR_2003_Robust_Reading_Competitions), [IC13](http://rrc.cvc.uab.es/?ch=2)[3], [IC15](http://rrc.cvc.uab.es/?ch=4), 27 | [SVTP](http://openaccess.thecvf.com/content_iccv_2013/papers/Phan_Recognizing_Text_with_2013_ICCV_paper.pdf), and [CUTE](http://cs-chan.com/downloads_CUTE80_dataset.html). 28 | - The prepared handwritten text dataset can be downloaded from [here](https://www.dropbox.com/sh/4a9vrtnshozu929/AAAZucKLtEAUDuOufIRDVPOTa?dl=0) 29 | - Handwritten text: [IAM](http://www.fki.inf.unibe.ch/databases/iam-handwriting-database) 30 | 31 | 32 | ### Training and evaluation 33 | 34 | - For a toy example, you can download the pretrained model from [here](https://drive.google.com/drive/folders/15WPsuPJDCzhp2SvYZLRj8mAlT3zmoAMW) 35 | 36 | - Add model files to test into `data/` 37 | 38 | - Training model 39 | 40 | ``` 41 | CUDA_VISIBLE_DEVICES=1 python train_da_global_local_selected.py --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn \ 42 | --src_train_data ./data/data_lmdb_release/training/ \ 43 | --tar_train_data ./data/IAM/test --tar_select_data IAM --tar_batch_ratio 1 --valid_data ../data/IAM/test/ \ 44 | --continue_model ./data/TPS-ResNet-BiLSTM-Attn.pth \ 45 | --batch_size 128 --lr 1 \ 46 | --experiment_name _adv_global_local_synth2iam_pc_0.1 --pc 0.1 47 | ``` 48 | 49 | - Test model 50 | 51 | - Test the baseline model 52 | ``` 53 | CUDA_VISIBLE_DEVICES=0 python test.py --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn \ 54 | --eval_data ./data/IAM/test \ 55 | --saved_model ./data/TPS-ResNet-BiLSTM-Attn.pth 56 | ``` 57 | 58 | - Test the adaptation model 59 | 60 | ``` 61 | CUDA_VISIBLE_DEVICES=0 python test.py --Transformation TPS --FeatureExtraction ResNet --SequenceModeling BiLSTM --Prediction Attn \ 62 | --eval_data ./data/IAM/test \ 63 | --saved_model saved_models/TPS-ResNet-BiLSTM-Attn-Seed1111_adv_global_local_selected/best_accuracy.pth 64 | ``` 65 | 66 | 67 | ## Citation 68 | If you use this code for a paper please cite: 69 | 70 | ``` 71 | @inproceedings{zhang2019sequence, 72 | title={Sequence-to-sequence domain adaptation network for robust text image recognition}, 73 | author={Zhang, Yaping and Nie, Shuai and Liu, Wenju and Xu, Xing and Zhang, Dongxiang and Shen, Heng Tao}, 74 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 75 | pages={2740--2749}, 76 | year={2019} 77 | } 78 | 79 | @article{zhang2021robust, 80 | title={Robust Text Image Recognition via Adversarial Sequence-to-Sequence Domain Adaptation}, 81 | author={Zhang, Yaping and Nie, Shuai and Liang, Shan and Liu, Wenju}, 82 | journal={IEEE Transactions on Image Processing}, 83 | volume={30}, 84 | pages={3922--3933}, 85 | year={2021}, 86 | publisher={IEEE} 87 | } 88 | ``` 89 
| 90 | 91 | ## Acknowledgement 92 | 93 | This implementation has been based on this repository [deep-text-recognition-benchmark](https://github.com/clovaai/deep-text-recognition-benchmark) 94 | 95 | -------------------------------------------------------------------------------- /seqda_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2019-present NAVER Corp. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import torch.nn as nn 18 | 19 | from modules.feature_extraction import VGG_FeatureExtractor, RCNN_FeatureExtractor, \ 20 | ResNet_FeatureExtractor, DenseNet_FeatureExtractor 21 | from modules.prediction import Attention 22 | from modules.sequence_modeling import BidirectionalLSTM 23 | from modules.transformation import TPS_SpatialTransformerNetwork 24 | 25 | 26 | class Model(nn.Module): 27 | 28 | def __init__(self, opt): 29 | super(Model, self).__init__() 30 | self.opt = opt 31 | self.stages = {'Trans': opt.Transformation, 'Feat': opt.FeatureExtraction, 32 | 'Seq': opt.SequenceModeling, 'Pred': opt.Prediction} 33 | 34 | """ Transformation """ 35 | if opt.Transformation == 'TPS': 36 | self.Transformation = TPS_SpatialTransformerNetwork( 37 | F=opt.num_fiducial, I_size=(opt.imgH, opt.imgW), I_r_size=(opt.imgH, opt.imgW), 38 | I_channel_num=opt.input_channel) 39 | else: 40 | print('No Transformation module specified') 41 | 42 | """ FeatureExtraction """ 43 | if opt.FeatureExtraction == 'VGG': 44 | self.FeatureExtraction = VGG_FeatureExtractor(opt.input_channel, opt.output_channel) 45 | elif opt.FeatureExtraction == 'RCNN': 46 | self.FeatureExtraction = RCNN_FeatureExtractor(opt.input_channel, opt.output_channel) 47 | elif opt.FeatureExtraction == 'ResNet': 48 | self.FeatureExtraction = ResNet_FeatureExtractor(opt.input_channel, opt.output_channel) 49 | elif opt.FeatureExtraction == 'DenseNet': 50 | self.FeatureExtraction = DenseNet_FeatureExtractor(opt.input_channel, 51 | opt.output_channel) 52 | else: 53 | raise Exception('No FeatureExtraction module specified') 54 | self.FeatureExtraction_output = opt.output_channel # int(imgH/16-1) * 512 55 | self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1)) # Transform final (imgH/16-1) -> 1 56 | 57 | """ Sequence modeling""" 58 | if opt.SequenceModeling == 'BiLSTM': 59 | self.SequenceModeling = nn.Sequential( 60 | BidirectionalLSTM(self.FeatureExtraction_output, opt.hidden_size, opt.hidden_size), 61 | BidirectionalLSTM(opt.hidden_size, opt.hidden_size, opt.hidden_size)) 62 | self.SequenceModeling_output = opt.hidden_size 63 | else: 64 | print('No SequenceModeling module specified') 65 | self.SequenceModeling_output = self.FeatureExtraction_output 66 | 67 | """ Prediction """ 68 | if opt.Prediction == 'Attn': 69 | self.Prediction = Attention(self.SequenceModeling_output, opt.hidden_size, 70 | opt.num_class) 71 | else: 72 | raise Exception('Prediction is neither CTC or Attn') 73 | 74 | def forward(self, input, text, is_train=True): 75 | """ 
Transformation stage """ 76 | key_points = None 77 | if not self.stages['Trans'] == "None": 78 | input, key_points = self.Transformation(input) 79 | 80 | """ Feature extraction stage """ 81 | visual_feature = self.FeatureExtraction(input) 82 | visual_feature = self.AdaptiveAvgPool( 83 | visual_feature.permute(0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h] 84 | visual_feature = visual_feature.squeeze(3) 85 | # b,w,c 86 | 87 | """ Sequence modeling stage """ 88 | if self.stages['Seq'] == 'BiLSTM': 89 | contextual_feature = self.SequenceModeling(visual_feature) 90 | else: 91 | contextual_feature = visual_feature # for convenience. this is NOT contextually modeled by BiLSTM 92 | 93 | """ Prediction stage """ 94 | prediction = self.Prediction(contextual_feature.contiguous(), text, is_train, 95 | batch_max_length=self.opt.batch_max_length) 96 | 97 | 98 | return prediction, visual_feature, self.Prediction.context_history 99 | 100 | 101 | -------------------------------------------------------------------------------- /modules/prediction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 5 | 6 | 7 | class Attention(nn.Module): 8 | 9 | def __init__(self, input_size, hidden_size, num_classes): 10 | super(Attention, self).__init__() 11 | self.attention_cell = AttentionCell(input_size, hidden_size, num_classes) 12 | self.hidden_size = hidden_size 13 | self.num_classes = num_classes 14 | self.generator = nn.Linear(hidden_size, num_classes) 15 | 16 | def _char_to_onehot(self, input_char, onehot_dim=38): 17 | input_char = input_char.unsqueeze(1) 18 | batch_size = input_char.size(0) 19 | one_hot = torch.FloatTensor(batch_size, onehot_dim).zero_().to(device) 20 | one_hot = one_hot.scatter_(1, input_char, 1) 21 | return one_hot 22 | 23 | def forward(self, batch_H, text, is_train=True, batch_max_length=25): 24 | """ 25 | input: 26 | batch_H : contextual_feature H = hidden state of encoder. [batch_size x num_steps x num_classes] 27 | text : the text-index of each image. [batch_size x (max_length+1)]. +1 for [GO] token. text[:, 0] = [GO]. 28 | output: probability distribution at each step [batch_size x num_steps x num_classes] 29 | """ 30 | batch_size = batch_H.size(0) 31 | num_steps = batch_max_length + 1 # +1 for [s] at end of sentence. 32 | 33 | output_hiddens = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device) 34 | hidden = (torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device), 35 | torch.FloatTensor(batch_size, self.hidden_size).fill_(0).to(device)) 36 | 37 | self.context_history = torch.FloatTensor(batch_size, num_steps, self.hidden_size).fill_(0).to(device) 38 | self.alpha_history = [] 39 | if is_train: 40 | for i in range(num_steps): 41 | # one-hot vectors for a i-th char. 
in a batch 42 | char_onehots = self._char_to_onehot(text[:, i], onehot_dim=self.num_classes) 43 | # hidden : decoder's hidden s_{t-1}, batch_H : encoder's hidden H, char_onehots : one-hot(y_{t-1}) 44 | # batch_H [batch_size,times, feaiture_dims) 45 | # alpha [batch_size,times,1] 46 | # cur_time: context [batch_size,feature_dims] 47 | hidden, alpha, context = self.attention_cell(hidden, batch_H, char_onehots) 48 | 49 | output_hiddens[:, i, :] = hidden[0] # LSTM hidden index (0: hidden, 1: Cell) 50 | self.alpha_history.append(alpha) 51 | self.context_history[:, i, :] =context 52 | 53 | probs = self.generator(output_hiddens) 54 | 55 | else: 56 | targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token 57 | probs = torch.FloatTensor(batch_size, num_steps, self.num_classes).fill_(0).to(device) 58 | 59 | for i in range(num_steps): 60 | char_onehots = self._char_to_onehot(targets, onehot_dim=self.num_classes) 61 | hidden, alpha, context = self.attention_cell(hidden, batch_H, char_onehots) 62 | probs_step = self.generator(hidden[0]) 63 | probs[:, i, :] = probs_step 64 | _, next_input = probs_step.max(1) 65 | 66 | 67 | targets = next_input 68 | self.alpha_history.append(alpha) 69 | self.context_history[:, i, :] = context 70 | self.alpha_history = torch.cat(self.alpha_history, -1) 71 | self.alpha_history.permute(0,2,1) # batch_size x num_steps x num_classes 72 | return probs # batch_size x num_steps x num_classes 73 | 74 | 75 | class AttentionCell(nn.Module): 76 | 77 | def __init__(self, input_size, hidden_size, num_embeddings): 78 | super(AttentionCell, self).__init__() 79 | self.i2h = nn.Linear(input_size, hidden_size, bias=False) 80 | self.h2h = nn.Linear(hidden_size, hidden_size) # either i2i or h2h should have bias 81 | self.score = nn.Linear(hidden_size, 1, bias=False) 82 | self.rnn = nn.LSTMCell(input_size + num_embeddings, hidden_size) 83 | self.hidden_size = hidden_size 84 | 85 | def forward(self, prev_hidden, batch_H, char_onehots): 86 | # [batch_size x num_encoder_step x num_channel] -> [batch_size x num_encoder_step x hidden_size] 87 | 88 | batch_H_proj = self.i2h(batch_H) 89 | prev_hidden_proj = self.h2h(prev_hidden[0]).unsqueeze(1) 90 | # e= v^T tanh(Ws*s_{t-1} + Wh*H) : H batch_size,times,feature_dims, 91 | e = self.score(torch.tanh(batch_H_proj + prev_hidden_proj)) # batch_size x num_encoder_step * 1 92 | alpha = F.softmax(e, dim=1) 93 | context = torch.bmm(alpha.permute(0, 2, 1), batch_H).squeeze(1) # batch_size x num_channel 94 | concat_context = torch.cat([context, char_onehots], 1) # batch_size x (num_channel + num_embedding) 95 | cur_hidden = self.rnn(concat_context, prev_hidden) 96 | return cur_hidden, alpha,context 97 | -------------------------------------------------------------------------------- /modules/transformation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 6 | 7 | 8 | class TPS_SpatialTransformerNetwork(nn.Module): 9 | """ Rectification Network of RARE, namely TPS based STN """ 10 | 11 | def __init__(self, F, I_size, I_r_size, I_channel_num=1): 12 | """ Based on RARE TPS 13 | input: 14 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] 15 | I_size : (height, width) of the input image I 16 | I_r_size : (height, width) of the rectified image I_r 17 | I_channel_num : the number of channels of the input image I 18 | 
output: 19 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] 20 | """ 21 | super(TPS_SpatialTransformerNetwork, self).__init__() 22 | self.F = F 23 | self.I_size = I_size 24 | self.I_r_size = I_r_size # = (I_r_height, I_r_width) 25 | self.I_channel_num = I_channel_num 26 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) 27 | self.GridGenerator = GridGenerator(self.F, self.I_r_size) 28 | 29 | def forward(self, batch_I): 30 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 31 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 32 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) 33 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') 34 | 35 | return batch_I_r,batch_C_prime 36 | 37 | 38 | class LocalizationNetwork(nn.Module): 39 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ 40 | 41 | def __init__(self, F, I_channel_num): 42 | super(LocalizationNetwork, self).__init__() 43 | self.F = F 44 | self.I_channel_num = I_channel_num 45 | self.conv = nn.Sequential( 46 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, 47 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True), 48 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 49 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), 50 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 51 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), 52 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 53 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), 54 | nn.AdaptiveAvgPool2d(1) # batch_size x 512 55 | ) 56 | 57 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) 58 | self.localization_fc2 = nn.Linear(256, self.F * 2) 59 | 60 | # Init fc2 in LocalizationNetwork 61 | self.localization_fc2.weight.data.fill_(0) 62 | """ see RARE paper Fig. 
6 (a) """ 63 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 64 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) 65 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) 66 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 67 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 68 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 69 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) 70 | 71 | def forward(self, batch_I): 72 | """ 73 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] 74 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] 75 | """ 76 | batch_size = batch_I.size(0) 77 | features = self.conv(batch_I).view(batch_size, -1) 78 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) 79 | return batch_C_prime 80 | 81 | 82 | class GridGenerator(nn.Module): 83 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """ 84 | 85 | def __init__(self, F, I_r_size): 86 | """ Generate P_hat and inv_delta_C for later """ 87 | super(GridGenerator, self).__init__() 88 | self.eps = 1e-6 89 | self.I_r_height, self.I_r_width = I_r_size 90 | self.F = F 91 | self.C = self._build_C(self.F) # F x 2 92 | self.P = self._build_P(self.I_r_width, self.I_r_height) 93 | 94 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 95 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 96 | 97 | def _build_C(self, F): 98 | """ Return coordinates of fiducial points in I_r; C """ 99 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) 100 | ctrl_pts_y_top = -1 * np.ones(int(F / 2)) 101 | ctrl_pts_y_bottom = np.ones(int(F / 2)) 102 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 103 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 104 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 105 | return C # F x 2 106 | 107 | def _build_inv_delta_C(self, F, C): 108 | """ Return inv_delta_C which is needed to calculate T """ 109 | hat_C = np.zeros((F, F), dtype=float) # F x F 110 | for i in range(0, F): 111 | for j in range(i, F): 112 | r = np.linalg.norm(C[i] - C[j]) 113 | hat_C[i, j] = r 114 | hat_C[j, i] = r 115 | np.fill_diagonal(hat_C, 1) 116 | hat_C = (hat_C ** 2) * np.log(hat_C) 117 | # print(C.shape, hat_C.shape) 118 | delta_C = np.concatenate( # F+3 x F+3 119 | [ 120 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 121 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 122 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 123 | ], 124 | axis=0 125 | ) 126 | inv_delta_C = np.linalg.inv(delta_C) 127 | return inv_delta_C # F+3 x F+3 128 | 129 | def _build_P(self, I_r_width, I_r_height): 130 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width 131 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height 132 | P = np.stack( # self.I_r_width x self.I_r_height x 2 133 | np.meshgrid(I_r_grid_x, I_r_grid_y), 134 | axis=2 135 | ) 136 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 137 | 138 | def _build_P_hat(self, F, C, P): 139 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height) 140 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 
2 -> n x F x 2 141 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 142 | P_diff = P_tile - C_tile # n x F x 2 143 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F 144 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F 145 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) 146 | return P_hat # n x F+3 147 | 148 | def build_P_prime(self, batch_C_prime): 149 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """ 150 | batch_size = batch_C_prime.size(0) 151 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) 152 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) 153 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( 154 | batch_size, 3, 2).float().to(device)), dim=1) # batch_size x F+3 x 2 155 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 156 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 157 | return batch_P_prime # batch_size x n x 2 158 | -------------------------------------------------------------------------------- /modules/radam.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.optim.optimizer import Optimizer, required 4 | 5 | 6 | class RAdam(Optimizer): 7 | 8 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 9 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 10 | self.buffer = [[None, None, None] for ind in range(10)] 11 | super(RAdam, self).__init__(params, defaults) 12 | 13 | def __setstate__(self, state): 14 | super(RAdam, self).__setstate__(state) 15 | 16 | def step(self, closure=None): 17 | 18 | loss = None 19 | if closure is not None: 20 | loss = closure() 21 | 22 | for group in self.param_groups: 23 | 24 | for p in group['params']: 25 | if p.grad is None: 26 | continue 27 | grad = p.grad.data.float() 28 | if grad.is_sparse: 29 | raise RuntimeError('RAdam does not support sparse gradients') 30 | 31 | p_data_fp32 = p.data.float() 32 | 33 | state = self.state[p] 34 | 35 | if len(state) == 0: 36 | state['step'] = 0 37 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 38 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 39 | else: 40 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 41 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 42 | 43 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 44 | beta1, beta2 = group['betas'] 45 | 46 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 47 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 48 | 49 | state['step'] += 1 50 | buffered = self.buffer[int(state['step'] % 10)] 51 | if state['step'] == buffered[0]: 52 | N_sma, step_size = buffered[1], buffered[2] 53 | else: 54 | buffered[0] = state['step'] 55 | beta2_t = beta2 ** state['step'] 56 | N_sma_max = 2 / (1 - beta2) - 1 57 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 58 | buffered[1] = N_sma 59 | 60 | # more conservative since it's an approximated value 61 | if N_sma >= 5: 62 | step_size = group['lr'] * math.sqrt( 63 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * ( 64 | N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / ( 65 | 1 - beta1 ** state['step']) 66 | else: 67 | step_size = group['lr'] / (1 - beta1 ** state['step']) 68 | buffered[2] = step_size 69 | 70 | if group['weight_decay'] != 0: 71 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 72 | 73 | # more conservative since it's an 
approximated value 74 | if N_sma >= 5: 75 | denom = exp_avg_sq.sqrt().add_(group['eps']) 76 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 77 | else: 78 | p_data_fp32.add_(-step_size, exp_avg) 79 | 80 | p.data.copy_(p_data_fp32) 81 | 82 | return loss 83 | 84 | 85 | class PlainRAdam(Optimizer): 86 | 87 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 88 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 89 | 90 | super(PlainRAdam, self).__init__(params, defaults) 91 | 92 | def __setstate__(self, state): 93 | super(PlainRAdam, self).__setstate__(state) 94 | 95 | def step(self, closure=None): 96 | 97 | loss = None 98 | if closure is not None: 99 | loss = closure() 100 | 101 | for group in self.param_groups: 102 | 103 | for p in group['params']: 104 | if p.grad is None: 105 | continue 106 | grad = p.grad.data.float() 107 | if grad.is_sparse: 108 | raise RuntimeError('RAdam does not support sparse gradients') 109 | 110 | p_data_fp32 = p.data.float() 111 | 112 | state = self.state[p] 113 | 114 | if len(state) == 0: 115 | state['step'] = 0 116 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 117 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 118 | else: 119 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 120 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 121 | 122 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 123 | beta1, beta2 = group['betas'] 124 | 125 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 126 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 127 | 128 | state['step'] += 1 129 | beta2_t = beta2 ** state['step'] 130 | N_sma_max = 2 / (1 - beta2) - 1 131 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 132 | 133 | if group['weight_decay'] != 0: 134 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 135 | 136 | # more conservative since it's an approximated value 137 | if N_sma >= 5: 138 | step_size = group['lr'] * math.sqrt( 139 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * ( 140 | N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / ( 141 | 1 - beta1 ** state['step']) 142 | denom = exp_avg_sq.sqrt().add_(group['eps']) 143 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 144 | else: 145 | step_size = group['lr'] / (1 - beta1 ** state['step']) 146 | p_data_fp32.add_(-step_size, exp_avg) 147 | 148 | p.data.copy_(p_data_fp32) 149 | 150 | return loss 151 | 152 | 153 | class AdamW(Optimizer): 154 | 155 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0): 156 | defaults = dict(lr=lr, betas=betas, eps=eps, 157 | weight_decay=weight_decay, warmup=warmup) 158 | super(AdamW, self).__init__(params, defaults) 159 | 160 | def __setstate__(self, state): 161 | super(AdamW, self).__setstate__(state) 162 | 163 | def step(self, closure=None): 164 | loss = None 165 | if closure is not None: 166 | loss = closure() 167 | 168 | for group in self.param_groups: 169 | 170 | for p in group['params']: 171 | if p.grad is None: 172 | continue 173 | grad = p.grad.data.float() 174 | if grad.is_sparse: 175 | raise RuntimeError( 176 | 'Adam does not support sparse gradients, please consider SparseAdam instead') 177 | 178 | p_data_fp32 = p.data.float() 179 | 180 | state = self.state[p] 181 | 182 | if len(state) == 0: 183 | state['step'] = 0 184 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 185 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 186 | else: 187 | state['exp_avg'] = 
state['exp_avg'].type_as(p_data_fp32) 188 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 189 | 190 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 191 | beta1, beta2 = group['betas'] 192 | 193 | state['step'] += 1 194 | 195 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 196 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 197 | 198 | denom = exp_avg_sq.sqrt().add_(group['eps']) 199 | bias_correction1 = 1 - beta1 ** state['step'] 200 | bias_correction2 = 1 - beta2 ** state['step'] 201 | 202 | if group['warmup'] > state['step']: 203 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 204 | else: 205 | scheduled_lr = group['lr'] 206 | 207 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 208 | 209 | if group['weight_decay'] != 0: 210 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 211 | 212 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 213 | 214 | p.data.copy_(p_data_fp32) 215 | 216 | return loss -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import torch 4 | from torch.nn import Parameter 5 | import string 6 | import numpy as np 7 | from nltk.metrics.distance import edit_distance 8 | from scipy import misc 9 | 10 | try: 11 | from torch.hub import load_state_dict_from_url 12 | except ImportError: 13 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 14 | 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | 17 | 18 | def to_numpy(tensor): 19 | if torch.is_tensor(tensor): 20 | return tensor.cpu().numpy() 21 | elif type(tensor).__module__ != 'numpy': 22 | raise ValueError("Cannot convert {} to numpy array" 23 | .format(type(tensor))) 24 | return tensor 25 | 26 | 27 | def load_test_data(image_path, size=256): 28 | img = misc.imread(image_path, mode='RGB') 29 | img = misc.imresize(img, [size, size]) 30 | img = np.expand_dims(img, axis=0) 31 | img = preprocessing(img) 32 | 33 | return img 34 | 35 | 36 | def preprocessing(x): 37 | x = x / 127.5 - 1 # -1 ~ 1 38 | return x 39 | 40 | 41 | def save_images(images, size, image_path): 42 | return imsave(inverse_transform(images), size, image_path) 43 | 44 | 45 | def inverse_transform(images): 46 | return (images + 1.) 
/ 2 47 | 48 | 49 | def imsave(images, size, path): 50 | return misc.imsave(path, merge(images, size)) 51 | 52 | 53 | def merge(images, size): 54 | h, w = images.shape[1], images.shape[2] 55 | img = np.zeros((h * size[0], w * size[1], 3)) 56 | for idx, image in enumerate(images): 57 | i = idx % size[1] 58 | j = idx // size[1] 59 | img[h * j:h * (j + 1), w * i:w * (i + 1), :] = image 60 | 61 | return img 62 | 63 | 64 | def check_folder(log_dir): 65 | if not os.path.exists(log_dir): 66 | os.makedirs(log_dir) 67 | return log_dir 68 | 69 | 70 | def str2bool(x): 71 | return x.lower() in ('true') 72 | 73 | 74 | def cam(x, size=256): 75 | x = x - np.min(x) 76 | cam_img = x / np.max(x) 77 | cam_img = np.uint8(255 * cam_img) 78 | cam_img = cv2.resize(cam_img, (size, size)) 79 | cam_img = cv2.applyColorMap(cam_img, cv2.COLORMAP_JET) 80 | return cam_img / 255.0 81 | 82 | 83 | def imagenet_norm(x): 84 | mean = [0.485, 0.456, 0.406] 85 | std = [0.299, 0.224, 0.225] 86 | mean = torch.FloatTensor(mean).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device) 87 | std = torch.FloatTensor(std).unsqueeze(0).unsqueeze(2).unsqueeze(3).to(x.device) 88 | return (x - mean) / std 89 | 90 | 91 | def denorm(x): 92 | return x * 0.5 + 0.5 93 | 94 | 95 | def tensor2numpy(x): 96 | return x.detach().cpu().numpy().transpose(1, 2, 0) 97 | 98 | 99 | def RGB2BGR(x): 100 | return cv2.cvtColor(x, cv2.COLOR_RGB2BGR) 101 | 102 | 103 | def load_char_dict(char_dict_file): 104 | with open(char_dict_file) as f: 105 | lines = f.readlines() 106 | char_dict = [line.rstrip('\n') for line in lines] 107 | return char_dict 108 | 109 | 110 | def normalize_text(text): 111 | text = ''.join(filter(lambda x: x in (string.digits + string.ascii_letters), text)) 112 | return text 113 | 114 | 115 | def edit_distance_loss(s1, s2): 116 | """ 计算编辑距离, 作为一种 loss function 117 | """ 118 | if len(s1) > len(s2): 119 | s1, s2 = s2, s1 120 | if len(s2) == 0: # 空串跳出 121 | return 0 122 | distances = range(len(s1) + 1) 123 | for i, c2 in enumerate(s2): 124 | distances_ = [i + 1] 125 | for j, c1 in enumerate(s1): 126 | if c1 == c2: 127 | distances_.append(distances[j]) 128 | else: 129 | distances_.append( 130 | 1 + min((distances[j], distances[j + 1], distances_[-1]))) 131 | distances = distances_ 132 | loss = 1 - distances[-1] * 1.0 / len(s2) 133 | return loss 134 | 135 | 136 | def compute_loss(preds_str, labels, opt, case_sensitive=False, filtering_punctuation=True): 137 | # calculate accuracy. 138 | n_correct = 0 139 | norm_ED = 0 140 | for pred, gt in zip(preds_str, labels): 141 | if 'Attn' in opt.Prediction: 142 | pred = pred[:pred.find('[s]')] # prune after "end of sentence" token ([s]) 143 | gt = gt[:gt.find('[s]')] 144 | 145 | if not case_sensitive: 146 | pred = normalize_text(pred).lower() 147 | gt = normalize_text(gt).lower() 148 | # pred = pred.lower() 149 | # gt = gt.lower() 150 | 151 | if pred == gt: 152 | n_correct += 1 153 | if len(gt) == 0: 154 | norm_ED += 1 155 | else: 156 | norm_ED += edit_distance_loss(pred, gt) 157 | return n_correct, norm_ED 158 | 159 | 160 | class AttnLabelConverter(object): 161 | """ Convert between text-label and text-index """ 162 | 163 | def __init__(self, character): 164 | # character (str): set of the possible characters. 165 | # [GO] for the start token of the attention decoder. [s] for end-of-sentence token. 
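# With the token ordering below, self.dict['[GO]'] == 0 and self.dict['[s]'] == 1;
# e.g. encode(["ab"]) yields a row [0, dict['a'], dict['b'], 1, 0, ...]: [GO] first,
# then the label, then [s], with the remainder padded by the [GO] index.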
166 | list_token = ['[GO]', '[s]'] # ['[s]','[UNK]','[PAD]','[GO]'] 167 | list_character = list(character) 168 | self.character = list_token + list_character 169 | 170 | self.dict = {} 171 | for i, char in enumerate(self.character): 172 | # print(i, char) 173 | self.dict[char] = i 174 | 175 | def encode(self, text, batch_max_length=25): 176 | """ convert text-label into text-index. 177 | input: 178 | text: text labels of each image. [batch_size] 179 | batch_max_length: max length of text label in the batch. 25 by default 180 | 181 | output: 182 | text : the input of attention decoder. [batch_size x (max_length+2)] +1 for [GO] token and +1 for [s] token. 183 | text[:, 0] is [GO] token and text is padded with [GO] token after [s] token. 184 | length : the length of output of attention decoder, which count [s] token also. [3, 7, ....] [batch_size] 185 | """ 186 | length = [len(s) + 1 for s in text] # +1 for [s] at end of sentence. 187 | # batch_max_length = max(length) # this is not allowed for multi-gpu setting 188 | batch_max_length += 1 189 | # additional +1 for [GO] at first step. batch_text is padded with [GO] token after [s] token. 190 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0) 191 | for i, t in enumerate(text): 192 | text = list(t) 193 | text.append('[s]') 194 | text = [self.dict[char] for char in text] 195 | batch_text[i][1:1 + len(text)] = torch.LongTensor(text) # batch_text[:, 0] = [GO] token 196 | return (batch_text.to(device), torch.IntTensor(length).to(device)) 197 | 198 | def decode(self, text_index, length): 199 | """ convert text-index into text-label. """ 200 | texts = [] 201 | for index, l in enumerate(length): 202 | text = ''.join([self.character[i] for i in text_index[index, :]]) 203 | texts.append(text) 204 | return texts 205 | 206 | 207 | class Averager(object): 208 | """Compute average for torch.Tensor, used for loss average.""" 209 | 210 | def __init__(self): 211 | self.reset() 212 | 213 | def add(self, v): 214 | count = v.data.numel() 215 | v = v.data.sum() 216 | self.n_count += count 217 | self.sum += v 218 | 219 | def reset(self): 220 | self.n_count = 0 221 | self.sum = 0 222 | 223 | def val(self): 224 | res = 0 225 | if self.n_count != 0: 226 | res = self.sum / float(self.n_count) 227 | return res 228 | 229 | 230 | def copy_state_dict(state_dict, model, strip=None): 231 | tgt_state = model.state_dict() 232 | copied_names = set() 233 | for name, param in state_dict.items(): 234 | if strip is not None and name.startswith(strip): 235 | name = name[len(strip):] 236 | if name not in tgt_state: 237 | continue 238 | if isinstance(param, Parameter): 239 | param = param.data 240 | if param.size() != tgt_state[name].size(): 241 | print('mismatch:', name, param.size(), tgt_state[name].size()) 242 | continue 243 | tgt_state[name].copy_(param) 244 | copied_names.add(name) 245 | 246 | missing = set(tgt_state.keys()) - copied_names 247 | if len(missing) > 0: 248 | print("missing keys in state_dict:", missing) 249 | 250 | return model 251 | 252 | 253 | def adjust_learning_rate(optimizer, decay=0.1): 254 | """Sets the learning rate to the initial LR decayed by 0.5 every 20 epochs""" 255 | for param_group in optimizer.param_groups: 256 | param_group['lr'] = decay * param_group['lr'] 257 | 258 | 259 | def clip_gradient(model, clip_norm): 260 | """Computes a gradient clipping coefficient based on gradient norm.""" 261 | totalnorm = 0 262 | for p in model.parameters(): 263 | if p.requires_grad: 264 | modulenorm = p.grad.data.norm() 265 | totalnorm += 
modulenorm ** 2 266 | totalnorm = np.sqrt(totalnorm) 267 | 268 | norm = clip_norm / max(totalnorm, clip_norm) 269 | for p in model.parameters(): 270 | if p.requires_grad: 271 | p.grad.mul_(norm) 272 | -------------------------------------------------------------------------------- /losses/ef_focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd import Variable 6 | 7 | import pdb 8 | 9 | 10 | class EFocalLoss(nn.Module): 11 | """ 12 | This criterion is a implemenation of Focal Loss, which is proposed in 13 | Focal Loss for Dense Object Detection. 14 | 15 | Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class]) 16 | 17 | The losses are averaged across observations for each minibatch. 18 | Args: 19 | alpha(1D Tensor, Variable) : the scalar factor for this criterion 20 | gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 21 | putting more focus on hard, misclassified examples 22 | size_average(bool): size_average(bool): By default, the losses are averaged over observations for each minibatch. 23 | However, if the field size_average is set to False, the losses are 24 | instead summed for each minibatch. 25 | """ 26 | 27 | def __init__(self, class_num, alpha=None, gamma=2, size_average=True): 28 | super(EFocalLoss, self).__init__() 29 | if alpha is None: 30 | self.alpha = Variable(torch.ones(class_num, 1) * 1.0) 31 | else: 32 | if isinstance(alpha, Variable): 33 | self.alpha = alpha 34 | else: 35 | self.alpha = Variable(alpha) 36 | self.gamma = gamma 37 | self.class_num = class_num 38 | self.size_average = size_average 39 | 40 | def forward(self, inputs, targets): 41 | N = inputs.size(0) 42 | # print(N) 43 | C = inputs.size(1) 44 | # inputs = F.sigmoid(inputs) 45 | P = F.softmax(inputs) 46 | class_mask = inputs.data.new(N, C).fill_(0) 47 | class_mask = Variable(class_mask) 48 | ids = targets.view(-1, 1) 49 | class_mask.scatter_(1, ids.data, 1.) 50 | # print(class_mask) 51 | 52 | if inputs.is_cuda and not self.alpha.is_cuda: 53 | self.alpha = self.alpha.cuda() 54 | alpha = self.alpha[ids.data.view(-1)] 55 | 56 | probs = (P * class_mask).sum(1).view(-1, 1) 57 | log_p = probs.log() 58 | # print('probs size= {}'.format(probs.size())) 59 | # print(probs) 60 | batch_loss = -alpha * torch.exp(-self.gamma * probs) * log_p 61 | # print('-----bacth_loss------') 62 | # print(batch_loss) 63 | 64 | if self.size_average: 65 | loss = batch_loss.mean() 66 | else: 67 | loss = batch_loss.sum() 68 | return loss 69 | 70 | 71 | class FocalLoss(nn.Module): 72 | r""" 73 | This criterion is a implemenation of Focal Loss, which is proposed in 74 | Focal Loss for Dense Object Detection. 75 | 76 | Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class]) 77 | 78 | The losses are averaged across observations for each minibatch. 79 | Args: 80 | alpha(1D Tensor, Variable) : the scalar factor for this criterion 81 | gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 82 | putting more focus on hard, misclassified examples 83 | size_average(bool): size_average(bool): By default, the losses are averaged over observations for each minibatch. 84 | However, if the field size_average is set to False, the losses are 85 | instead summed for each minibatch. 
86 | """ 87 | 88 | def __init__(self, class_num, alpha=None, gamma=2, size_average=True, sigmoid=False, 89 | reduce=True): 90 | super(FocalLoss, self).__init__() 91 | if alpha is None: 92 | self.alpha = Variable(torch.ones(class_num, 1) * 1.0) 93 | else: 94 | if isinstance(alpha, Variable): 95 | self.alpha = alpha 96 | else: 97 | self.alpha = Variable(alpha) 98 | self.gamma = gamma 99 | self.class_num = class_num 100 | self.size_average = size_average 101 | self.sigmoid = sigmoid 102 | self.reduce = reduce 103 | 104 | def forward(self, inputs, targets): 105 | N = inputs.size(0) 106 | # print(N) 107 | C = inputs.size(1) 108 | if self.sigmoid: 109 | P = F.sigmoid(inputs) 110 | # F.softmax(inputs) 111 | if targets == 0: 112 | probs = 1 - P # (P * class_mask).sum(1).view(-1, 1) 113 | log_p = probs.log() 114 | batch_loss = - (torch.pow((1 - probs), self.gamma)) * log_p 115 | if targets == 1: 116 | probs = P # (P * class_mask).sum(1).view(-1, 1) 117 | log_p = probs.log() 118 | batch_loss = - (torch.pow((1 - probs), self.gamma)) * log_p 119 | else: 120 | # inputs = F.sigmoid(inputs) 121 | P = F.softmax(inputs) 122 | 123 | class_mask = inputs.data.new(N, C).fill_(0) 124 | class_mask = Variable(class_mask) 125 | ids = targets.view(-1, 1) 126 | class_mask.scatter_(1, ids.data, 1.) 127 | # print(class_mask) 128 | 129 | if inputs.is_cuda and not self.alpha.is_cuda: 130 | self.alpha = self.alpha.cuda() 131 | alpha = self.alpha[ids.data.view(-1)] 132 | 133 | probs = (P * class_mask).sum(1).view(-1, 1) 134 | 135 | log_p = probs.log() 136 | # print('probs size= {}'.format(probs.size())) 137 | # print(probs) 138 | 139 | batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p 140 | # print('-----bacth_loss------') 141 | # print(batch_loss) 142 | 143 | if not self.reduce: 144 | return batch_loss 145 | if self.size_average: 146 | loss = batch_loss.mean() 147 | else: 148 | loss = batch_loss.sum() 149 | return loss 150 | 151 | 152 | class FocalPseudo(nn.Module): 153 | r""" 154 | This criterion is a implemenation of Focal Loss, which is proposed in 155 | Focal Loss for Dense Object Detection. 156 | 157 | Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class]) 158 | 159 | The losses are averaged across observations for each minibatch. 160 | Args: 161 | alpha(1D Tensor, Variable) : the scalar factor for this criterion 162 | gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 163 | putting more focus on hard, misclassified examples 164 | size_average(bool): size_average(bool): By default, the losses are averaged over observations for each minibatch. 165 | However, if the field size_average is set to False, the losses are 166 | instead summed for each minibatch. 
167 | """ 168 | 169 | def __init__(self, class_num, alpha=None, gamma=2, size_average=True, threshold=0.8): 170 | super(FocalPseudo, self).__init__() 171 | if alpha is None: 172 | self.alpha = Variable(torch.ones(class_num, 1) * 1.0) 173 | else: 174 | if isinstance(alpha, Variable): 175 | self.alpha = alpha 176 | else: 177 | self.alpha = Variable(alpha) 178 | self.gamma = gamma 179 | self.class_num = class_num 180 | self.size_average = size_average 181 | self.threshold = threshold 182 | 183 | def forward(self, inputs): 184 | N = inputs.size(0) 185 | C = inputs.size(1) 186 | inputs = inputs[0, :, :] 187 | # print(inputs) 188 | # pdb.set_trace() 189 | inputs, ind = torch.max(inputs, 1) 190 | ones = torch.ones(inputs.size()).cuda() 191 | value = torch.where(inputs > self.threshold, inputs, ones) 192 | # 193 | # pdb.set_trace() 194 | # ind 195 | # print(value) 196 | try: 197 | ind = value.ne(1) 198 | indexes = torch.nonzero(ind) 199 | # value2 = inputs[indexes] 200 | inputs = inputs[indexes] 201 | log_p = inputs.log() 202 | # print('probs size= {}'.format(probs.size())) 203 | # print(probs) 204 | if not self.gamma == 0: 205 | batch_loss = - (torch.pow((1 - inputs), self.gamma)) * log_p 206 | else: 207 | batch_loss = - log_p 208 | except: 209 | # inputs = inputs#[indexes] 210 | log_p = value.log() 211 | # print('probs size= {}'.format(probs.size())) 212 | # print(probs) 213 | if not self.gamma == 0: 214 | batch_loss = - (torch.pow((1 - inputs), self.gamma)) * log_p 215 | else: 216 | batch_loss = - log_p 217 | # print('-----bacth_loss------') 218 | # print(batch_loss) 219 | # batch_loss = batch_loss #* weight 220 | if self.size_average: 221 | try: 222 | loss = batch_loss.mean() # + 0.1*balance 223 | except: 224 | pdb.set_trace() 225 | else: 226 | loss = batch_loss.sum() 227 | return loss 228 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import string 4 | import time 5 | 6 | import numpy as np 7 | import torch 8 | import torch.backends.cudnn as cudnn 9 | import torch.utils.data 10 | 11 | from dataset import hierarchical_dataset, AlignCollate 12 | from seqda_model import Model 13 | from utils import AttnLabelConverter, Averager 14 | from utils import load_char_dict, compute_loss 15 | 16 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 17 | 18 | 19 | def benchmark_all_eval(model, criterion, converter, opt, calculate_infer_time=False): 20 | """ evaluation with 10 benchmark evaluation datasets """ 21 | # The evaluation datasets, dataset order is same with Table 1 in our paper. 22 | eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', 'IC13_857', 23 | 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80'] 24 | 25 | if calculate_infer_time: 26 | evaluation_batch_size = 1 # batch_size should be 1 to calculate the GPU inference time per image. 
27 | else: 28 | evaluation_batch_size = opt.batch_size 29 | 30 | list_accuracy = [] 31 | total_forward_time = 0 32 | total_evaluation_data_number = 0 33 | total_correct_number = 0 34 | print('-' * 80) 35 | for eval_data in eval_data_list: 36 | eval_data_path = os.path.join(opt.eval_data, eval_data) 37 | AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, 38 | keep_ratio_with_pad=opt.PAD) 39 | eval_data = hierarchical_dataset(root=eval_data_path, opt=opt) 40 | evaluation_loader = torch.utils.data.DataLoader( 41 | eval_data, batch_size=evaluation_batch_size, 42 | shuffle=False, 43 | num_workers=int(opt.workers), 44 | collate_fn=AlignCollate_evaluation, pin_memory=True) 45 | 46 | _, accuracy_by_best_model, norm_ED_by_best_model, _, _, infer_time, length_of_data = validation( 47 | model, criterion, evaluation_loader, converter, opt) 48 | list_accuracy.append(f'{accuracy_by_best_model:0.3f}') 49 | total_forward_time += infer_time 50 | total_evaluation_data_number += len(eval_data) 51 | total_correct_number += accuracy_by_best_model * length_of_data 52 | print('Acc %0.3f\t normalized_ED %0.3f' % (accuracy_by_best_model, norm_ED_by_best_model)) 53 | print('-' * 80) 54 | 55 | averaged_forward_time = total_forward_time / total_evaluation_data_number * 1000 56 | total_accuracy = total_correct_number / total_evaluation_data_number 57 | params_num = sum([np.prod(p.size()) for p in model.parameters()]) 58 | 59 | evaluation_log = 'accuracy: ' 60 | for name, accuracy in zip(eval_data_list, list_accuracy): 61 | evaluation_log += f'{name}: {accuracy}\t' 62 | evaluation_log += f'total_accuracy: {total_accuracy:0.3f}\t' 63 | evaluation_log += f'averaged_infer_time: {averaged_forward_time:0.3f}\t# parameters: {params_num/1e6:0.3f}' 64 | print(evaluation_log) 65 | with open(f'./result/{opt.experiment_name}/log_all_evaluation.txt', 'a') as log: 66 | log.write(evaluation_log + '\n') 67 | 68 | return None 69 | 70 | 71 | def validation(model, criterion, evaluation_loader, converter, opt): 72 | """ validation or evaluation """ 73 | n_correct = 0 74 | norm_ED = 0 75 | length_of_data = 0 76 | infer_time = 0 77 | valid_loss_avg = Averager() 78 | 79 | for i, (image_tensors, labels) in enumerate(evaluation_loader): 80 | batch_size = image_tensors.size(0) 81 | length_of_data = length_of_data + batch_size 82 | image = image_tensors.to(device) 83 | # For max length prediction 84 | length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) 85 | text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) 86 | 87 | text_for_loss, length_for_loss = converter.encode(labels, 88 | batch_max_length=opt.batch_max_length) 89 | 90 | start_time = time.time() 91 | 92 | preds, global_feature, local_feature, attention_weights, transformed_imgs, control_points = model( 93 | image, text_for_pred, is_train=False) 94 | 95 | forward_time = time.time() - start_time 96 | 97 | preds = preds[:, :text_for_loss.shape[1] - 1, :] 98 | target = text_for_loss[:, 1:] # without [GO] Symbol 99 | cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), 100 | target.contiguous().view(-1)) 101 | 102 | # select max probabilty (greedy decoding) then decode index to character 103 | preds_score, preds_index = preds.max(2) 104 | preds_str = converter.decode(preds_index, length_for_pred) 105 | labels = converter.decode(text_for_loss[:, 1:], length_for_loss) 106 | 107 | infer_time += forward_time 108 | valid_loss_avg.add(cost) 109 | 110 | # calculate accuracy. 
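# compute_loss (utils.py) returns the count of exact matches and an accumulated
# per-sample score of 1 - edit_distance / max(len(pred), len(gt)); despite the
# norm_ED name this is a similarity, i.e. closer to character accuracy, hence the
# batch_char_acc name below. Both counters are turned into percentages after the loop.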
111 | batch_n_correct, batch_char_acc = compute_loss(preds_str, labels, opt) 112 | n_correct += batch_n_correct 113 | norm_ED += batch_char_acc 114 | 115 | accuracy = n_correct / float(length_of_data) * 100 116 | norm_ED = norm_ED / float(length_of_data) * 100 117 | 118 | return valid_loss_avg.val(), accuracy, norm_ED, preds_str, labels, infer_time, length_of_data 119 | 120 | 121 | def load(model, saved_model): 122 | params = torch.load(saved_model) 123 | 124 | if 'model' not in params: 125 | model.load_state_dict(params) 126 | else: 127 | model.load_state_dict(params['model']) 128 | 129 | 130 | def test(opt): 131 | """ model configuration """ 132 | converter = AttnLabelConverter(opt.character) 133 | opt.num_class = len(converter.character) 134 | 135 | if opt.rgb: 136 | opt.input_channel = 3 137 | model = Model(opt) 138 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, 139 | opt.output_channel, 140 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, 141 | opt.FeatureExtraction, 142 | opt.SequenceModeling, opt.Prediction) 143 | model = torch.nn.DataParallel(model).to(device) 144 | 145 | # load model 146 | print('loading pretrained model from %s' % opt.saved_model) 147 | # model.load_state_dict(torch.load(opt.saved_model)) 148 | load(model, opt.saved_model) 149 | opt.experiment_name = '_'.join(opt.saved_model.split('/')[1:]) 150 | # print(model) 151 | 152 | """ keep evaluation model and result logs """ 153 | os.makedirs(f'./result/{opt.experiment_name}', exist_ok=True) 154 | os.system(f'cp {opt.saved_model} ./result/{opt.experiment_name}/') 155 | 156 | """ setup loss """ 157 | criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to( 158 | device) # ignore [GO] token = ignore index 0 159 | 160 | """ evaluation """ 161 | model.eval() 162 | with torch.no_grad(): 163 | if opt.benchmark_all_eval: # evaluation with 10 benchmark evaluation datasets 164 | benchmark_all_eval(model, criterion, converter, opt) 165 | else: 166 | AlignCollate_evaluation = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, 167 | keep_ratio_with_pad=opt.PAD) 168 | eval_data = hierarchical_dataset(root=opt.eval_data, opt=opt) 169 | evaluation_loader = torch.utils.data.DataLoader( 170 | eval_data, batch_size=opt.batch_size, 171 | shuffle=False, 172 | num_workers=int(opt.workers), 173 | collate_fn=AlignCollate_evaluation, pin_memory=True) 174 | _, accuracy_by_best_model, char_acc_by_best_model, _, _, _, _ = validation( 175 | model, criterion, evaluation_loader, converter, opt) 176 | 177 | print(accuracy_by_best_model) 178 | print(char_acc_by_best_model) 179 | with open('./result/{0}/log_evaluation.txt'.format(opt.experiment_name), 'a') as log: 180 | log.write(str(accuracy_by_best_model) + '\n') 181 | log.write(str(char_acc_by_best_model) + '\n') 182 | 183 | 184 | if __name__ == '__main__': 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument('--eval_data', required=True, help='path to evaluation dataset') 187 | parser.add_argument('--benchmark_all_eval', action='store_true', 188 | help='evaluate 10 benchmark evaluation datasets') 189 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 190 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 191 | parser.add_argument('--saved_model', required=True, help="path to saved_model to evaluation") 192 | parser.add_argument('--visualize', action='store_true', help='use rgb input') 193 | """ Data processing """ 194 | 
parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 195 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 196 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 197 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 198 | parser.add_argument('--char_dict', type=str, default=None, 199 | help="path to char dict dataset/iam/char_dict.txt") 200 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', 201 | help='character label') 202 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 203 | parser.add_argument('--ignore_special_char', action='store_true', 204 | help='for evaluation mode, ignore special char') 205 | parser.add_argument('--ignore_case_sensitive', action='store_true', 206 | help='for evaluation mode, ignore sensitive character') 207 | parser.add_argument('--PAD', action='store_true', 208 | help='whether to keep ratio then pad for image resize') 209 | parser.add_argument('--data_filtering_off', action='store_true', 210 | help='for data_filtering_off mode') 211 | """ Model Architecture """ 212 | parser.add_argument('--Transformation', type=str, required=True, 213 | help='Transformation stage. None|TPS') 214 | parser.add_argument('--FeatureExtraction', type=str, required=True, 215 | help='FeatureExtraction stage. VGG|RCNN|ResNet') 216 | parser.add_argument('--SequenceModeling', type=str, required=True, 217 | help='SequenceModeling stage. None|BiLSTM') 218 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn') 219 | parser.add_argument('--num_fiducial', type=int, default=20, 220 | help='number of fiducial points of TPS-STN') 221 | parser.add_argument('--input_channel', type=int, default=1, 222 | help='the number of input channel of Feature extractor') 223 | parser.add_argument('--output_channel', type=int, default=512, 224 | help='the number of output channel of Feature extractor') 225 | parser.add_argument('--hidden_size', type=int, default=256, 226 | help='the size of the LSTM hidden state') 227 | 228 | opt = parser.parse_args() 229 | 230 | """ vocab / character number configuration """ 231 | if opt.sensitive: 232 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 
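# Example invocation (illustrative; the dataset and checkpoint paths are placeholders
# for your own setup, while the architecture flags follow the choices listed in the
# help strings above):
#   CUDA_VISIBLE_DEVICES=0 python test.py \
#       --eval_data data_lmdb_release/evaluation --benchmark_all_eval \
#       --Transformation TPS --FeatureExtraction ResNet \
#       --SequenceModeling BiLSTM --Prediction Attn \
#       --saved_model saved_models/best_accuracy.pth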
233 | if opt.char_dict is not None: 234 | opt.character = load_char_dict(opt.char_dict)[3:-2] # 去除Attention 和 CTC引入的一些特殊符号 235 | cudnn.benchmark = True 236 | cudnn.deterministic = True 237 | opt.num_gpu = torch.cuda.device_count() 238 | 239 | test(opt) 240 | -------------------------------------------------------------------------------- /modules/densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.checkpoint as cp 6 | from collections import OrderedDict 7 | from utils import load_state_dict_from_url 8 | 9 | import numpy as np 10 | 11 | 12 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 13 | 14 | model_urls = { 15 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 16 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 17 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 18 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 19 | } 20 | 21 | 22 | def _bn_function_factory(norm, relu, conv): 23 | def bn_function(*inputs): 24 | concated_features = torch.cat(inputs, 1) 25 | bottleneck_output = conv(relu(norm(concated_features))) 26 | return bottleneck_output 27 | 28 | return bn_function 29 | 30 | 31 | class _DenseLayer(nn.Sequential): 32 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False): 33 | super(_DenseLayer, self).__init__() 34 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 35 | self.add_module('relu1', nn.ReLU(inplace=True)), 36 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 37 | growth_rate, kernel_size=1, stride=1, 38 | bias=False)), 39 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 40 | self.add_module('relu2', nn.ReLU(inplace=True)), 41 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 42 | kernel_size=3, stride=1, padding=1, 43 | bias=False)), 44 | self.drop_rate = drop_rate 45 | self.memory_efficient = memory_efficient 46 | 47 | def forward(self, *prev_features): 48 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 49 | if self.memory_efficient and any(prev_feature.requires_grad for prev_feature in prev_features): 50 | bottleneck_output = cp.checkpoint(bn_function, *prev_features) 51 | else: 52 | bottleneck_output = bn_function(*prev_features) 53 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 54 | if self.drop_rate > 0: 55 | new_features = F.dropout(new_features, p=self.drop_rate, 56 | training=self.training) 57 | return new_features 58 | 59 | 60 | class _DenseBlock(nn.Module): 61 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False): 62 | super(_DenseBlock, self).__init__() 63 | for i in range(num_layers): 64 | layer = _DenseLayer( 65 | num_input_features + i * growth_rate, 66 | growth_rate=growth_rate, 67 | bn_size=bn_size, 68 | drop_rate=drop_rate, 69 | memory_efficient=memory_efficient, 70 | ) 71 | self.add_module('denselayer%d' % (i + 1), layer) 72 | 73 | def forward(self, init_features): 74 | features = [init_features] 75 | for name, layer in self.named_children(): 76 | new_features = layer(*features) 77 | features.append(new_features) 78 | return torch.cat(features, 1) 79 | 80 | 81 | class _Transition(nn.Sequential): 82 | def __init__(self, 
num_input_features, num_output_features): 83 | super(_Transition, self).__init__() 84 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 85 | self.add_module('relu', nn.ReLU(inplace=True)) 86 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 87 | kernel_size=1, stride=1, bias=False)) 88 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 89 | 90 | 91 | class DenseNet(nn.Module): 92 | r"""Densenet-BC model class, based on 93 | `"Densely Connected Convolutional Networks" `_ 94 | Args: 95 | growth_rate (int) - how many filters to add each layer (`k` in paper) 96 | block_config (list of 4 ints) - how many layers in each pooling block 97 | num_init_features (int) - the number of filters to learn in the first convolution layer 98 | bn_size (int) - multiplicative factor for number of bottle neck layers 99 | (i.e. bn_size * k features in the bottleneck layer) 100 | drop_rate (float) - dropout rate after each dense layer 101 | num_classes (int) - number of classification classes 102 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 103 | but slower. Default: *False*. See `"paper" `_ 104 | """ 105 | 106 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 107 | num_init_features=64, bn_size=4, drop_rate=0, 108 | num_classes=1000, 109 | memory_efficient=False, 110 | channel=3): 111 | 112 | super(DenseNet, self).__init__() 113 | 114 | # First convolution 115 | self.features = nn.Sequential(OrderedDict([ 116 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, 117 | padding=3, bias=False)), 118 | ('norm0', nn.BatchNorm2d(num_init_features)), 119 | ('relu0', nn.ReLU(inplace=True)), 120 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 121 | ])) 122 | 123 | # Each denseblock 124 | num_features = num_init_features 125 | for i, num_layers in enumerate(block_config): 126 | block = _DenseBlock( 127 | num_layers=num_layers, 128 | num_input_features=num_features, 129 | bn_size=bn_size, 130 | growth_rate=growth_rate, 131 | drop_rate=drop_rate, 132 | memory_efficient=memory_efficient 133 | ) 134 | self.features.add_module('denseblock%d' % (i + 1), block) 135 | num_features = num_features + num_layers * growth_rate 136 | if i != len(block_config) - 1: 137 | trans = _Transition(num_input_features=num_features, 138 | num_output_features=num_features // 2) 139 | self.features.add_module('transition%d' % (i + 1), trans) 140 | num_features = num_features // 2 141 | 142 | # Final batch norm 143 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 144 | 145 | # Linear layer 146 | self.classifier = nn.Linear(num_features, num_classes) 147 | 148 | # Official init from torch repo. 149 | for m in self.modules(): 150 | if isinstance(m, nn.Conv2d): 151 | nn.init.kaiming_normal_(m.weight) 152 | elif isinstance(m, nn.BatchNorm2d): 153 | nn.init.constant_(m.weight, 1) 154 | nn.init.constant_(m.bias, 0) 155 | elif isinstance(m, nn.Linear): 156 | nn.init.constant_(m.bias, 0) 157 | 158 | def forward(self, x): 159 | self.feature_maps = self.features(x) 160 | out = F.relu(self.feature_maps, inplace=True) 161 | out = F.adaptive_avg_pool2d(out, (1, 1)) 162 | out = torch.flatten(out, 1) 163 | out = self.classifier(out) 164 | return out 165 | 166 | 167 | def _load_state_dict(model, model_url, progress): 168 | # '.'s are no longer allowed in module names, but previous _DenseLayer 169 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 
170 | # They are also in the checkpoints in model_urls. This pattern is used 171 | # to find such keys. 172 | pattern = re.compile( 173 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 174 | 175 | state_dict = load_state_dict_from_url(model_url, progress=progress) 176 | model_dict = model.state_dict() 177 | for key in list(state_dict.keys()): 178 | res = pattern.match(key) 179 | if res: 180 | new_key = res.group(1) + res.group(2) 181 | state_dict[new_key] = state_dict[key] 182 | del state_dict[key] 183 | pretrained_dict = {} 184 | for k,v in state_dict.items(): 185 | if k in model_dict and v.size()==model_dict[k].size(): 186 | pretrained_dict[k] = v 187 | print(k,v.size()) 188 | 189 | model_dict.update(pretrained_dict) 190 | # print(model_dict['classifier.weighst'].size()) 191 | model.load_state_dict(model_dict) 192 | 193 | 194 | def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, 195 | **kwargs): 196 | model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) 197 | if pretrained: 198 | _load_state_dict(model, model_urls[arch], progress) 199 | return model 200 | 201 | 202 | def densenet121(pretrained=False, progress=True, **kwargs): 203 | r"""Densenet-121 model from 204 | `"Densely Connected Convolutional Networks" `_ 205 | Args: 206 | pretrained (bool): If True, returns a model pre-trained on ImageNet 207 | progress (bool): If True, displays a progress bar of the download to stderr 208 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 209 | but slower. Default: *False*. See `"paper" `_ 210 | """ 211 | return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, 212 | **kwargs) 213 | 214 | 215 | def densenet161(pretrained=False, progress=True, **kwargs): 216 | r"""Densenet-161 model from 217 | `"Densely Connected Convolutional Networks" `_ 218 | Args: 219 | pretrained (bool): If True, returns a model pre-trained on ImageNet 220 | progress (bool): If True, displays a progress bar of the download to stderr 221 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 222 | but slower. Default: *False*. See `"paper" `_ 223 | """ 224 | return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, 225 | **kwargs) 226 | 227 | 228 | def densenet169(pretrained=False, progress=True, **kwargs): 229 | r"""Densenet-169 model from 230 | `"Densely Connected Convolutional Networks" `_ 231 | Args: 232 | pretrained (bool): If True, returns a model pre-trained on ImageNet 233 | progress (bool): If True, displays a progress bar of the download to stderr 234 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 235 | but slower. Default: *False*. See `"paper" `_ 236 | """ 237 | return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, 238 | **kwargs) 239 | 240 | 241 | def densenet201(pretrained=False, progress=True, **kwargs): 242 | r"""Densenet-201 model from 243 | `"Densely Connected Convolutional Networks" `_ 244 | Args: 245 | pretrained (bool): If True, returns a model pre-trained on ImageNet 246 | progress (bool): If True, displays a progress bar of the download to stderr 247 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 248 | but slower. Default: *False*. 
See `"paper" `_ 249 | """ 250 | return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, 251 | **kwargs) 252 | 253 | if __name__ == "__main__": 254 | model = densenet121(pretrained=True, num_classes=94, channel=3) 255 | model.eval() 256 | x = torch.Tensor(np.ones([1,3,64,64])) 257 | with torch.no_grad(): 258 | out = model(x) 259 | print(out) 260 | print(out.size(),model.feature_maps.size()) 261 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import six 5 | import math 6 | import lmdb 7 | import torch 8 | 9 | from natsort import natsorted 10 | from PIL import Image 11 | import numpy as np 12 | from torch.utils.data import Dataset, ConcatDataset, Subset 13 | from torch._utils import _accumulate 14 | import torchvision.transforms as transforms 15 | 16 | 17 | class Batch_Balanced_Dataset(object): 18 | 19 | def __init__(self, opt, train_data, select_data, batch_ratio,is_shuffle=True): 20 | """ 21 | Modulate the data ratio in the batch. 22 | For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5", 23 | the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST. 24 | """ 25 | print('-' * 80) 26 | print( 27 | f'dataset_root: {train_data}\nopt.select_data: { select_data}\nopt.batch_ratio: {batch_ratio}') 28 | assert len(select_data) == len(batch_ratio) 29 | 30 | _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 31 | self.data_loader_list = [] 32 | self.dataloader_iter_list = [] 33 | batch_size_list = [] 34 | Total_batch_size = 0 35 | self.total_data_size = 0 36 | for selected_d, batch_ratio_d in zip(select_data, batch_ratio): 37 | _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1) 38 | print('-' * 80) 39 | _dataset = hierarchical_dataset(root=train_data, opt=opt, select_data=[selected_d]) 40 | total_number_dataset = len(_dataset) 41 | 42 | """ 43 | The total number of data can be modified with opt.total_data_usage_ratio. 44 | ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage. 45 | See 4.2 section in our paper. 
46 | """ 47 | number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio)) 48 | dataset_split = [number_dataset, total_number_dataset - number_dataset] 49 | indices = range(total_number_dataset) 50 | _dataset, _ = [Subset(_dataset, indices[offset - length:offset]) 51 | for offset, length in zip(_accumulate(dataset_split), dataset_split)] 52 | 53 | self.total_data_size += len(_dataset) 54 | print( 55 | f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}') 56 | print( 57 | f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}') 58 | batch_size_list.append(str(_batch_size)) 59 | Total_batch_size += _batch_size 60 | 61 | _data_loader = torch.utils.data.DataLoader( 62 | _dataset, batch_size=_batch_size, 63 | shuffle=is_shuffle, 64 | num_workers=int(opt.workers), 65 | collate_fn=_AlignCollate, pin_memory=True) 66 | self.data_loader_list.append(_data_loader) 67 | self.dataloader_iter_list.append(iter(_data_loader)) 68 | print('-' * 80) 69 | print('Total_batch_size: ', '+'.join(batch_size_list), '=', str(Total_batch_size)) 70 | opt.batch_size = Total_batch_size 71 | print('-' * 80) 72 | 73 | def get_batch(self): 74 | balanced_batch_images = [] 75 | balanced_batch_texts = [] 76 | 77 | for i, data_loader_iter in enumerate(self.dataloader_iter_list): 78 | try: 79 | image, text = data_loader_iter.next() 80 | balanced_batch_images.append(image) 81 | balanced_batch_texts += text 82 | except StopIteration: 83 | self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) 84 | image, text = self.dataloader_iter_list[i].next() 85 | balanced_batch_images.append(image) 86 | balanced_batch_texts += text 87 | except ValueError: 88 | pass 89 | 90 | balanced_batch_images = torch.cat(balanced_batch_images, 0) 91 | 92 | return balanced_batch_images, balanced_batch_texts 93 | 94 | 95 | def hierarchical_dataset(root, opt, select_data='/'): 96 | """ select_data='/' contains all sub-directory of root directory """ 97 | dataset_list = [] 98 | print(f'dataset_root: {root}\t dataset: {select_data[0]}') 99 | for dirpath, dirnames, filenames in os.walk(root): 100 | if not dirnames: 101 | select_flag = False 102 | for selected_d in select_data: 103 | if selected_d in dirpath: 104 | select_flag = True 105 | break 106 | 107 | if select_flag: 108 | dataset = LmdbDataset(dirpath, opt) 109 | print( 110 | f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}') 111 | dataset_list.append(dataset) 112 | 113 | concatenated_dataset = ConcatDataset(dataset_list) 114 | 115 | return concatenated_dataset 116 | 117 | 118 | class LmdbDataset(Dataset): 119 | 120 | def __init__(self, root, opt): 121 | 122 | self.root = root 123 | self.opt = opt 124 | self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, 125 | meminit=False) 126 | if not self.env: 127 | print('cannot create lmdb from %s' % (root)) 128 | sys.exit(0) 129 | 130 | with self.env.begin(write=False) as txn: 131 | nSamples = int(txn.get('num-samples'.encode())) 132 | self.nSamples = nSamples 133 | 134 | if self.opt.data_filtering_off: 135 | # for fast check with no filtering 136 | self.filtered_index_list = [index + 1 for index in range(self.nSamples)] 137 | else: 138 | # Filtering 139 | self.filtered_index_list = [] 140 | for index in range(self.nSamples): 141 | index += 1 # lmdb starts with 1 142 | label_key = 'label-%09d'.encode() % index 143 | label = 
txn.get(label_key).decode('utf-8') 144 | 145 | if len(label) > self.opt.batch_max_length: 146 | # print(f'The length of the label is longer than max_length: length 147 | # {len(label)}, {label} in dataset {self.root}') 148 | continue 149 | 150 | # By default, images containing characters which are not in opt.character are filtered. 151 | # You can add [UNK] token to `opt.character` in utils.py instead of this filtering. 152 | out_of_char = f'[^{self.opt.character}]' 153 | if re.search(out_of_char, label.lower()): 154 | continue 155 | 156 | self.filtered_index_list.append(index) 157 | 158 | self.nSamples = len(self.filtered_index_list) 159 | 160 | def __len__(self): 161 | return self.nSamples 162 | 163 | def __getitem__(self, index): 164 | assert index <= len(self), 'index range error' 165 | index = self.filtered_index_list[index] 166 | 167 | with self.env.begin(write=False) as txn: 168 | label_key = 'label-%09d'.encode() % index 169 | label = txn.get(label_key).decode('utf-8') 170 | img_key = 'image-%09d'.encode() % index 171 | imgbuf = txn.get(img_key) 172 | 173 | buf = six.BytesIO() 174 | buf.write(imgbuf) 175 | buf.seek(0) 176 | try: 177 | if self.opt.rgb: 178 | img = Image.open(buf).convert('RGB') # for color image 179 | else: 180 | img = Image.open(buf).convert('L') 181 | 182 | except IOError: 183 | print(f'Corrupted image for {index}') 184 | # make dummy image and dummy label for corrupted image. 185 | if self.opt.rgb: 186 | img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) 187 | else: 188 | img = Image.new('L', (self.opt.imgW, self.opt.imgH)) 189 | label = '[dummy_label]' 190 | 191 | if not self.opt.sensitive: 192 | label = label.lower() 193 | 194 | # We only train and evaluate on alphanumerics (or pre-defined character set in train.py) 195 | out_of_char = f'[^{self.opt.character}]' 196 | label = re.sub(out_of_char, '', label) 197 | 198 | return (img, label) 199 | 200 | 201 | class RawDataset(Dataset): 202 | 203 | def __init__(self, root, opt): 204 | self.opt = opt 205 | self.image_path_list = [] 206 | for dirpath, dirnames, filenames in os.walk(root): 207 | for name in filenames: 208 | _, ext = os.path.splitext(name) 209 | ext = ext.lower() 210 | if ext == '.jpg' or ext == '.jpeg' or ext == '.png': 211 | self.image_path_list.append(os.path.join(dirpath, name)) 212 | 213 | self.image_path_list = natsorted(self.image_path_list) 214 | self.nSamples = len(self.image_path_list) 215 | 216 | def __len__(self): 217 | return self.nSamples 218 | 219 | def __getitem__(self, index): 220 | 221 | try: 222 | if self.opt.rgb: 223 | img = Image.open(self.image_path_list[index]).convert('RGB') # for color image 224 | else: 225 | img = Image.open(self.image_path_list[index]).convert('L') 226 | 227 | except IOError: 228 | print(f'Corrupted image for {index}') 229 | # make dummy image and dummy label for corrupted image. 
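# Note: unlike LmdbDataset above, RawDataset carries no labels; on a read error it only
# substitutes a blank image so that batch collation does not fail, and the image path is
# still returned as the second element of the tuple.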
230 | if self.opt.rgb: 231 | img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) 232 | else: 233 | img = Image.new('L', (self.opt.imgW, self.opt.imgH)) 234 | 235 | return (img, self.image_path_list[index]) 236 | 237 | 238 | class ResizeNormalize(object): 239 | 240 | def __init__(self, size, interpolation=Image.BICUBIC): 241 | self.size = size 242 | self.interpolation = interpolation 243 | self.toTensor = transforms.ToTensor() 244 | 245 | def __call__(self, img): 246 | img = img.resize(self.size, self.interpolation) 247 | img = self.toTensor(img) 248 | img.sub_(0.5).div_(0.5) 249 | return img 250 | 251 | 252 | class NormalizePAD(object): 253 | 254 | def __init__(self, max_size, PAD_type='right'): 255 | self.toTensor = transforms.ToTensor() 256 | self.max_size = max_size 257 | self.max_width_half = math.floor(max_size[2] / 2) 258 | self.PAD_type = PAD_type 259 | 260 | def __call__(self, img): 261 | img = self.toTensor(img) 262 | img.sub_(0.5).div_(0.5) 263 | c, h, w = img.size() 264 | Pad_img = torch.FloatTensor(*self.max_size).fill_(0) 265 | Pad_img[:, :, :w] = img # right pad 266 | if self.max_size[2] != w: # add border Pad 267 | Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w) 268 | 269 | return Pad_img 270 | 271 | 272 | class AlignCollate(object): 273 | 274 | def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=False): 275 | self.imgH = imgH 276 | self.imgW = imgW 277 | self.keep_ratio_with_pad = keep_ratio_with_pad 278 | 279 | def __call__(self, batch): 280 | batch = filter(lambda x: x is not None, batch) 281 | images, labels = zip(*batch) 282 | 283 | if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper 284 | resized_max_w = self.imgW 285 | transform = NormalizePAD((1, self.imgH, resized_max_w)) 286 | 287 | resized_images = [] 288 | for image in images: 289 | w, h = image.size 290 | ratio = w / float(h) 291 | if math.ceil(self.imgH * ratio) > self.imgW: 292 | resized_w = self.imgW 293 | else: 294 | resized_w = math.ceil(self.imgH * ratio) 295 | 296 | resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC) 297 | resized_images.append(transform(resized_image)) 298 | # resized_image.save('./image_test/%d_test.jpg' % w) 299 | 300 | image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0) 301 | 302 | else: 303 | transform = ResizeNormalize((self.imgW, self.imgH)) 304 | image_tensors = [transform(image) for image in images] 305 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) 306 | 307 | return image_tensors, labels 308 | 309 | 310 | def tensor2im(image_tensor, imtype=np.uint8): 311 | image_numpy = image_tensor.cpu().float().numpy() 312 | if image_numpy.shape[0] == 1: 313 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 314 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 315 | return image_numpy.astype(imtype) 316 | 317 | 318 | def save_image(image_numpy, image_path): 319 | image_pil = Image.fromarray(image_numpy) 320 | image_pil.save(image_path) 321 | -------------------------------------------------------------------------------- /modules/feature_extraction.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import torch.utils.checkpoint as cp 6 | import numpy as np 7 | 8 | def _bn_function_factory(norm, relu, conv): 9 | def bn_function(*inputs): 10 | concated_features = torch.cat(inputs, 1) 11 | 
bottleneck_output = conv(relu(norm(concated_features))) 12 | return bottleneck_output 13 | 14 | return bn_function 15 | 16 | 17 | class VGG_FeatureExtractor(nn.Module): 18 | """ FeatureExtractor of CRNN (https://arxiv.org/pdf/1507.05717.pdf) """ 19 | 20 | def __init__(self, input_channel, output_channel=512): 21 | super(VGG_FeatureExtractor, self).__init__() 22 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 23 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 24 | self.ConvNet = nn.Sequential( 25 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 26 | nn.MaxPool2d(2, 2), # 64x16x50 27 | nn.Conv2d(self.output_channel[0], self.output_channel[1], 3, 1, 1), nn.ReLU(True), 28 | nn.MaxPool2d(2, 2), # 128x8x25 29 | nn.Conv2d(self.output_channel[1], self.output_channel[2], 3, 1, 1), nn.ReLU(True), 30 | # 256x8x25 31 | nn.Conv2d(self.output_channel[2], self.output_channel[2], 3, 1, 1), nn.ReLU(True), 32 | nn.MaxPool2d((2, 1), (2, 1)), # 256x4x25 33 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 3, 1, 1, bias=False), 34 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), # 512x4x25 35 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 3, 1, 1, bias=False), 36 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True), 37 | nn.MaxPool2d((2, 1), (2, 1)), # 512x2x25 38 | nn.Conv2d(self.output_channel[3], self.output_channel[3], 2, 1, 0), 39 | nn.ReLU(True)) # 512x1x24 40 | 41 | def forward(self, input): 42 | return self.ConvNet(input) 43 | 44 | 45 | class RCNN_FeatureExtractor(nn.Module): 46 | """ FeatureExtractor of GRCNN (https://papers.nips.cc/paper/6637-gated-recurrent-convolution-neural-network-for-ocr.pdf) """ 47 | 48 | def __init__(self, input_channel, output_channel=512): 49 | super(RCNN_FeatureExtractor, self).__init__() 50 | self.output_channel = [int(output_channel / 8), int(output_channel / 4), 51 | int(output_channel / 2), output_channel] # [64, 128, 256, 512] 52 | self.ConvNet = nn.Sequential( 53 | nn.Conv2d(input_channel, self.output_channel[0], 3, 1, 1), nn.ReLU(True), 54 | nn.MaxPool2d(2, 2), # 64 x 16 x 50 55 | GRCL(self.output_channel[0], self.output_channel[0], num_iteration=5, kernel_size=3, 56 | pad=1), 57 | nn.MaxPool2d(2, 2), # 64 x 8 x 25 58 | GRCL(self.output_channel[0], self.output_channel[1], num_iteration=5, kernel_size=3, 59 | pad=1), 60 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 128 x 4 x 26 61 | GRCL(self.output_channel[1], self.output_channel[2], num_iteration=5, kernel_size=3, 62 | pad=1), 63 | nn.MaxPool2d(2, (2, 1), (0, 1)), # 256 x 2 x 27 64 | nn.Conv2d(self.output_channel[2], self.output_channel[3], 2, 1, 0, bias=False), 65 | nn.BatchNorm2d(self.output_channel[3]), nn.ReLU(True)) # 512 x 1 x 26 66 | 67 | def forward(self, input): 68 | return self.ConvNet(input) 69 | 70 | 71 | class ResNet_FeatureExtractor(nn.Module): 72 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 73 | 74 | def __init__(self, input_channel, output_channel=512): 75 | super(ResNet_FeatureExtractor, self).__init__() 76 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlock, [1, 2, 5, 3]) 77 | 78 | def forward(self, input): 79 | return self.ConvNet(input) 80 | 81 | 82 | class DenseNet_FeatureExtractor(nn.Module): 83 | """ FeatureExtractor of FAN (http://openaccess.thecvf.com/content_ICCV_2017/papers/Cheng_Focusing_Attention_Towards_ICCV_2017_paper.pdf) """ 84 | 85 | def __init__(self, input_channel, 
output_channel=512): 86 | super(DenseNet_FeatureExtractor, self).__init__() 87 | self.ConvNet = DenseNet(input_channel, output_channel, 88 | growth_rate=24, 89 | block_config=(6, 12, 16), 90 | num_init_features=64, 91 | bn_size=4, 92 | drop_rate=0, 93 | memory_efficient=False) 94 | 95 | def forward(self, input): 96 | return self.ConvNet(input) 97 | 98 | 99 | # For Gated RCNN 100 | class GRCL(nn.Module): 101 | 102 | def __init__(self, input_channel, output_channel, num_iteration, kernel_size, pad): 103 | super(GRCL, self).__init__() 104 | self.wgf_u = nn.Conv2d(input_channel, output_channel, 1, 1, 0, bias=False) 105 | self.wgr_x = nn.Conv2d(output_channel, output_channel, 1, 1, 0, bias=False) 106 | self.wf_u = nn.Conv2d(input_channel, output_channel, kernel_size, 1, pad, bias=False) 107 | self.wr_x = nn.Conv2d(output_channel, output_channel, kernel_size, 1, pad, bias=False) 108 | 109 | self.BN_x_init = nn.BatchNorm2d(output_channel) 110 | 111 | self.num_iteration = num_iteration 112 | self.GRCL = [GRCL_unit(output_channel) for _ in range(num_iteration)] 113 | self.GRCL = nn.Sequential(*self.GRCL) 114 | 115 | def forward(self, input): 116 | """ The input of GRCL is consistant over time t, which is denoted by u(0) 117 | thus wgf_u / wf_u is also consistant over time t. 118 | """ 119 | wgf_u = self.wgf_u(input) 120 | wf_u = self.wf_u(input) 121 | x = F.relu(self.BN_x_init(wf_u)) 122 | 123 | for i in range(self.num_iteration): 124 | x = self.GRCL[i](wgf_u, self.wgr_x(x), wf_u, self.wr_x(x)) 125 | 126 | return x 127 | 128 | 129 | class GRCL_unit(nn.Module): 130 | 131 | def __init__(self, output_channel): 132 | super(GRCL_unit, self).__init__() 133 | self.BN_gfu = nn.BatchNorm2d(output_channel) 134 | self.BN_grx = nn.BatchNorm2d(output_channel) 135 | self.BN_fu = nn.BatchNorm2d(output_channel) 136 | self.BN_rx = nn.BatchNorm2d(output_channel) 137 | self.BN_Gx = nn.BatchNorm2d(output_channel) 138 | 139 | def forward(self, wgf_u, wgr_x, wf_u, wr_x): 140 | G_first_term = self.BN_gfu(wgf_u) 141 | G_second_term = self.BN_grx(wgr_x) 142 | G = F.sigmoid(G_first_term + G_second_term) 143 | 144 | x_first_term = self.BN_fu(wf_u) 145 | x_second_term = self.BN_Gx(self.BN_rx(wr_x) * G) 146 | x = F.relu(x_first_term + x_second_term) 147 | 148 | return x 149 | 150 | 151 | class BasicBlock(nn.Module): 152 | expansion = 1 153 | 154 | def __init__(self, inplanes, planes, stride=1, downsample=None): 155 | super(BasicBlock, self).__init__() 156 | self.conv1 = self._conv3x3(inplanes, planes) 157 | self.bn1 = nn.BatchNorm2d(planes) 158 | self.conv2 = self._conv3x3(planes, planes) 159 | self.bn2 = nn.BatchNorm2d(planes) 160 | self.relu = nn.ReLU(inplace=True) 161 | self.downsample = downsample 162 | self.stride = stride 163 | 164 | def _conv3x3(self, in_planes, out_planes, stride=1): 165 | "3x3 convolution with padding" 166 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 167 | padding=1, bias=False) 168 | 169 | def forward(self, x): 170 | residual = x 171 | 172 | out = self.conv1(x) 173 | out = self.bn1(out) 174 | out = self.relu(out) 175 | 176 | out = self.conv2(out) 177 | out = self.bn2(out) 178 | 179 | if self.downsample is not None: 180 | residual = self.downsample(x) 181 | out += residual 182 | out = self.relu(out) 183 | 184 | return out 185 | 186 | 187 | class ResNet(nn.Module): 188 | 189 | def __init__(self, input_channel, output_channel, block, layers): 190 | super(ResNet, self).__init__() 191 | 192 | self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), 193 | 
output_channel, output_channel] 194 | 195 | self.inplanes = int(output_channel / 8) 196 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16), 197 | kernel_size=3, stride=1, padding=1, bias=False) 198 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16)) 199 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes, 200 | kernel_size=3, stride=1, padding=1, bias=False) 201 | self.bn0_2 = nn.BatchNorm2d(self.inplanes) 202 | self.relu = nn.ReLU(inplace=True) 203 | 204 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 205 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) 206 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[ 207 | 0], kernel_size=3, stride=1, padding=1, bias=False) 208 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0]) 209 | 210 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) 211 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1) 212 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[ 213 | 1], kernel_size=3, stride=1, padding=1, bias=False) 214 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1]) 215 | 216 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1)) 217 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1) 218 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[ 219 | 2], kernel_size=3, stride=1, padding=1, bias=False) 220 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2]) 221 | 222 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1) 223 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 224 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False) 225 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3]) 226 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[ 227 | 3], kernel_size=2, stride=1, padding=0, bias=False) 228 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3]) 229 | 230 | def _make_layer(self, block, planes, blocks, stride=1): 231 | downsample = None 232 | if stride != 1 or self.inplanes != planes * block.expansion: 233 | downsample = nn.Sequential( 234 | nn.Conv2d(self.inplanes, planes * block.expansion, 235 | kernel_size=1, stride=stride, bias=False), 236 | nn.BatchNorm2d(planes * block.expansion), 237 | ) 238 | 239 | layers = [] 240 | layers.append(block(self.inplanes, planes, stride, downsample)) 241 | self.inplanes = planes * block.expansion 242 | for i in range(1, blocks): 243 | layers.append(block(self.inplanes, planes)) 244 | 245 | return nn.Sequential(*layers) 246 | 247 | def forward(self, x): 248 | x = self.conv0_1(x) 249 | x = self.bn0_1(x) 250 | x = self.relu(x) 251 | x = self.conv0_2(x) 252 | x = self.bn0_2(x) 253 | x = self.relu(x) 254 | 255 | x = self.maxpool1(x) 256 | x = self.layer1(x) 257 | x = self.conv1(x) 258 | x = self.bn1(x) 259 | x = self.relu(x) 260 | 261 | x = self.maxpool2(x) 262 | x = self.layer2(x) 263 | x = self.conv2(x) 264 | x = self.bn2(x) 265 | x = self.relu(x) 266 | 267 | x = self.maxpool3(x) 268 | x = self.layer3(x) 269 | x = self.conv3(x) 270 | x = self.bn3(x) 271 | x = self.relu(x) 272 | 273 | x = self.layer4(x) 274 | x = self.conv4_1(x) 275 | x = self.bn4_1(x) 276 | x = self.relu(x) 277 | x = self.conv4_2(x) 278 | x = self.bn4_2(x) 279 | x = self.relu(x) 280 | 281 | return x 282 | 283 | 284 | 
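# Usage sketch for the ResNet backbone above (illustrative; shapes assume the default
# 32x100 grayscale crops used elsewhere in this repo):
#   extractor = ResNet_FeatureExtractor(input_channel=1, output_channel=512)
#   feat = extractor(torch.zeros(1, 1, 32, 100))   # -> [1, 512, 1, 26]
# The height is collapsed to 1, so the width axis serves as the time-step sequence
# consumed by the downstream BiLSTM / attention decoder.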
##-----------------DenseNet------------------------# 285 | class _DenseLayer(nn.Sequential): 286 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False): 287 | super(_DenseLayer, self).__init__() 288 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 289 | self.add_module('relu1', nn.ReLU(inplace=True)), 290 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 291 | growth_rate, kernel_size=1, stride=1, 292 | bias=False)), 293 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 294 | self.add_module('relu2', nn.ReLU(inplace=True)), 295 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 296 | kernel_size=3, stride=1, padding=1, 297 | bias=False)), 298 | self.drop_rate = drop_rate 299 | self.memory_efficient = memory_efficient 300 | 301 | def forward(self, *prev_features): 302 | bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1) 303 | if self.memory_efficient and any( 304 | prev_feature.requires_grad for prev_feature in prev_features): 305 | bottleneck_output = cp.checkpoint(bn_function, *prev_features) 306 | else: 307 | bottleneck_output = bn_function(*prev_features) 308 | new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) 309 | if self.drop_rate > 0: 310 | new_features = F.dropout(new_features, p=self.drop_rate, 311 | training=self.training) 312 | return new_features 313 | 314 | 315 | class _DenseBlock(nn.Module): 316 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, 317 | memory_efficient=False): 318 | super(_DenseBlock, self).__init__() 319 | for i in range(num_layers): 320 | layer = _DenseLayer( 321 | num_input_features + i * growth_rate, 322 | growth_rate=growth_rate, 323 | bn_size=bn_size, 324 | drop_rate=drop_rate, 325 | memory_efficient=memory_efficient, 326 | ) 327 | self.add_module('denselayer%d' % (i + 1), layer) 328 | 329 | def forward(self, init_features): 330 | features = [init_features] 331 | for name, layer in self.named_children(): 332 | new_features = layer(*features) 333 | features.append(new_features) 334 | return torch.cat(features, 1) 335 | 336 | 337 | class _Transition(nn.Sequential): 338 | def __init__(self, num_input_features, num_output_features): 339 | super(_Transition, self).__init__() 340 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 341 | self.add_module('relu', nn.ReLU(inplace=True)) 342 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 343 | kernel_size=1, stride=1, bias=False)) 344 | # self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 345 | self.add_module('pool', nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))) 346 | 347 | 348 | class DenseNet(nn.Module): 349 | r"""Densenet-BC model class, based on 350 | `"Densely Connected Convolutional Networks" `_ 351 | Args: 352 | growth_rate (int) - how many filters to add each layer (`k` in paper) 353 | block_config (list of 4 ints) - how many layers in each pooling block 354 | num_init_features (int) - the number of filters to learn in the first convolution layer 355 | bn_size (int) - multiplicative factor for number of bottle neck layers 356 | (i.e. bn_size * k features in the bottleneck layer) 357 | drop_rate (float) - dropout rate after each dense layer 358 | num_classes (int) - number of classification classes 359 | memory_efficient (bool) - If True, uses checkpointing. Much more memory efficient, 360 | but slower. Default: *False*. 
See `"paper" `_ 361 | """ 362 | 363 | def __init__(self, input_channel, 364 | output_channel, 365 | growth_rate=32, block_config=(6, 12, 24, 16), 366 | num_init_features=64, bn_size=4, drop_rate=0, 367 | memory_efficient=False, 368 | ): 369 | 370 | super(DenseNet, self).__init__() 371 | 372 | # First convolution 373 | self.features = nn.Sequential(OrderedDict([ 374 | ('conv0', nn.Conv2d(input_channel, num_init_features, kernel_size=3, stride=2, 375 | padding=1, bias=False)), 376 | ('norm0', nn.BatchNorm2d(num_init_features)), 377 | ('relu0', nn.ReLU(inplace=True)), 378 | ('pool0', nn.MaxPool2d(kernel_size=2, stride=2, padding=1)), 379 | ])) 380 | 381 | # Each denseblock 382 | num_features = num_init_features 383 | for i, num_layers in enumerate(block_config): 384 | block = _DenseBlock( 385 | num_layers=num_layers, 386 | num_input_features=num_features, 387 | bn_size=bn_size, 388 | growth_rate=growth_rate, 389 | drop_rate=drop_rate, 390 | memory_efficient=memory_efficient 391 | ) 392 | self.features.add_module('denseblock%d' % (i + 1), block) 393 | num_features = num_features + num_layers * growth_rate 394 | if i != len(block_config) - 1: 395 | trans = _Transition(num_input_features=num_features, 396 | num_output_features=num_features // 2) 397 | self.features.add_module('transition%d' % (i + 1), trans) 398 | num_features = num_features // 2 399 | 400 | # Final batch norm 401 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 402 | 403 | # Official init from torch repo. 404 | for m in self.modules(): 405 | if isinstance(m, nn.Conv2d): 406 | nn.init.kaiming_normal_(m.weight) 407 | elif isinstance(m, nn.BatchNorm2d): 408 | nn.init.constant_(m.weight, 1) 409 | nn.init.constant_(m.bias, 0) 410 | elif isinstance(m, nn.Linear): 411 | nn.init.constant_(m.bias, 0) 412 | 413 | def forward(self, x): 414 | out = self.features(x) 415 | return out 416 | 417 | 418 | if __name__ == "__main__": 419 | # model = DenseNet_FeatureExtractor(input_channel=3,output_channel=64) 420 | model = ResNet_FeatureExtractor(input_channel=3,output_channel=64) 421 | model.eval() 422 | x = torch.Tensor(np.ones([1,3,64,64])) 423 | with torch.no_grad(): 424 | out = model(x) 425 | print(out.size()) -------------------------------------------------------------------------------- /train_da_coral.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import string 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn.init as init 11 | import torch.optim as optim 12 | import torch.utils.data 13 | 14 | from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset 15 | from losses.coral import CORAL 16 | from seqda_model import Model 17 | from test import validation 18 | from utils import AttnLabelConverter, Averager, load_char_dict 19 | 20 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | 22 | 23 | def coral_loss(source_context_history, source_prediction, 24 | target_context_history, target_prediction): 25 | feature_dim = source_context_history.size()[-1] 26 | 27 | source_feature = source_context_history.reshape(-1, feature_dim) 28 | target_feature = target_context_history.reshape(-1, feature_dim) 29 | 30 | # print(type(pred_class),pred_class) 31 | _, source_pred_class = source_prediction.max(-1) 32 | _, target_pred_class = target_prediction.max(-1) 33 | source_valid_char_index = (source_pred_class.reshape(-1, ) != 
1).nonzero().reshape(-1, ) 34 | source_valid_char_feature = source_feature.reshape(-1, feature_dim).index_select(0, 35 | source_valid_char_index) 36 | target_valid_char_index = (target_pred_class.reshape(-1, ) != 1).nonzero().reshape(-1, ) 37 | target_valid_char_feature = target_feature.reshape(-1, feature_dim).index_select(0, 38 | target_valid_char_index) 39 | 40 | similarity_loss = CORAL( 41 | source_valid_char_feature, 42 | target_valid_char_feature) 43 | return similarity_loss 44 | 45 | 46 | class trainer(object): 47 | def __init__(self, opt): 48 | 49 | opt.src_select_data = opt.src_select_data.split('-') 50 | opt.src_batch_ratio = opt.src_batch_ratio.split('-') 51 | opt.tar_select_data = opt.tar_select_data.split('-') 52 | opt.tar_batch_ratio = opt.tar_batch_ratio.split('-') 53 | 54 | """ vocab / character number configuration """ 55 | if opt.sensitive: 56 | # opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 57 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 58 | 59 | if opt.char_dict is not None: 60 | opt.character = load_char_dict(opt.char_dict)[3:-2] # 去除Attention 和 CTC引入的一些特殊符号 61 | 62 | """ model configuration """ 63 | self.converter = AttnLabelConverter(opt.character) 64 | opt.num_class = len(self.converter.character) 65 | 66 | if opt.rgb: 67 | opt.input_channel = 3 68 | self.opt = opt 69 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, 70 | opt.output_channel, 71 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, 72 | opt.FeatureExtraction, 73 | opt.SequenceModeling, opt.Prediction) 74 | self.save_opt_log(opt) 75 | 76 | self.build_model(opt) 77 | 78 | def dataloader(self, opt): 79 | src_train_data = opt.src_train_data 80 | src_select_data = opt.src_select_data 81 | src_batch_ratio = opt.src_batch_ratio 82 | src_train_dataset = Batch_Balanced_Dataset(opt, src_train_data, src_select_data, 83 | src_batch_ratio) 84 | 85 | tar_train_data = opt.tar_train_data 86 | tar_select_data = opt.tar_select_data 87 | tar_batch_ratio = opt.tar_batch_ratio 88 | tar_train_dataset = Batch_Balanced_Dataset(opt, tar_train_data, tar_select_data, 89 | tar_batch_ratio) 90 | 91 | AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 92 | 93 | valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt) 94 | valid_loader = torch.utils.data.DataLoader( 95 | valid_dataset, batch_size=opt.batch_size, 96 | shuffle=True, # 'True' to check training progress with validation function. 
97 | num_workers=int(opt.workers), 98 | collate_fn=AlignCollate_valid, pin_memory=True) 99 | return src_train_dataset, tar_train_dataset, valid_loader 100 | 101 | def _optimizer(self, opt): 102 | # filter that only require gradient decent 103 | filtered_parameters = [] 104 | params_num = [] 105 | for p in filter(lambda p: p.requires_grad, self.model.parameters()): 106 | filtered_parameters.append(p) 107 | params_num.append(np.prod(p.size())) 108 | print('Trainable params num : ', sum(params_num)) 109 | # setup optimizer 110 | if opt.adam: 111 | self.optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999)) 112 | else: 113 | self.optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, 114 | eps=opt.eps) 115 | 116 | print("Optimizer:") 117 | print(self.optimizer) 118 | 119 | def build_model(self, opt): 120 | print('-' * 80) 121 | 122 | """ Define Model """ 123 | self.model = Model(opt) 124 | 125 | self.weight_initializer() 126 | 127 | self.model = torch.nn.DataParallel(self.model).to(device) 128 | 129 | """ Define Loss """ 130 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) 131 | self.D_criterion = torch.nn.CrossEntropyLoss().to(device) 132 | 133 | """ Trainer """ 134 | self._optimizer(opt) 135 | 136 | def train(self, opt): 137 | # Add custom dataset add cfgs from da-faster-rcnn 138 | # Make sure you change the imdb_name in factory.py 139 | """ 140 | Dummy format: 141 | 142 | args.src_dataset == '$YOUR_DATASET_NAME' 143 | args.src_imdb_name = '$YOUR_DATASET_NAME_2007_trainval' 144 | args.src_imdbval_name = '$YOUR_DATASET_NAME_2007_test' 145 | args.set_cfgs = [...] 146 | """ 147 | 148 | # src, tar dataloaders 149 | src_dataset, tar_dataset, valid_loader = self.dataloader(opt) 150 | src_dataset_size = src_dataset.total_data_size 151 | tar_dataset_size = tar_dataset.total_data_size 152 | train_size = max([src_dataset_size, tar_dataset_size]) 153 | 154 | self.model.train() 155 | 156 | start_iter = 0 157 | 158 | if opt.continue_model != '': 159 | self.load(opt.continue_model) 160 | print(" [*] Load SUCCESS") 161 | # if opt.decay_flag and start_iter > (opt.num_iter // 2): 162 | # self.d_image_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) * ( 163 | # start_iter - opt.num_iter // 2) 164 | # self.d_inst_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) * ( 165 | # start_iter - opt.num_iter // 2) 166 | 167 | # loss averager 168 | cls_loss_avg = Averager() 169 | sim_loss_avg = Averager() 170 | loss_avg = Averager() 171 | 172 | # training loop 173 | print('training start !') 174 | start_time = time.time() 175 | best_accuracy = -1 176 | best_norm_ED = 1e+6 177 | 178 | for step in range(start_iter, opt.num_iter + 1): 179 | 180 | src_image, src_labels = src_dataset.get_batch() 181 | src_image = src_image.to(device) 182 | src_text, src_length = self.converter.encode(src_labels, 183 | batch_max_length=opt.batch_max_length) 184 | 185 | tar_image, tar_labels = tar_dataset.get_batch() 186 | tar_image = tar_image.to(device) 187 | tar_text, tar_length = self.converter.encode(tar_labels, 188 | batch_max_length=opt.batch_max_length) 189 | 190 | # Set gradient to zero... 
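# One adaptation step of this script, as implemented below: the source batch is decoded
# with teacher forcing and scored with cross-entropy against src_text[:, 1:] (the [GO]
# symbol is dropped from the target), the target batch is decoded with is_train=False so
# its labels never enter the loss, and the per-character decoder features of the two
# domains are aligned with the CORAL-based coral_loss defined above, mixed in with a
# fixed weight of 0.1:
#     loss = src_cls_loss.mean() + 0.1 * d_inst_loss.mean()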
191 | self.model.zero_grad() 192 | 193 | # Attention # align with Attention.forward 194 | src_preds, src_global_feature, src_local_feature = self.model(src_image, 195 | src_text[:, :-1]) 196 | target = src_text[:, 1:] # without [GO] Symbol 197 | src_cls_loss = self.criterion(src_preds.view(-1, src_preds.shape[-1]), 198 | target.contiguous().view(-1)) 199 | 200 | src_local_feature = src_local_feature.view(-1, src_local_feature.shape[-1]) 201 | # TODO 202 | tar_preds, tar_global_feature, tar_local_feature = self.model(tar_image, 203 | tar_text[:, :-1], 204 | is_train=False) 205 | 206 | tar_local_feature = tar_local_feature.view(-1, tar_local_feature.shape[-1]) 207 | 208 | d_inst_loss = coral_loss(src_local_feature, src_preds, 209 | tar_local_feature, tar_preds) 210 | # Add domain loss 211 | loss = src_cls_loss.mean() + 0.1 * d_inst_loss.mean() 212 | loss_avg.add(loss) 213 | cls_loss_avg.add(src_cls_loss) 214 | sim_loss_avg.add(d_inst_loss) 215 | 216 | # frcnn backward 217 | loss.backward() 218 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 219 | opt.grad_clip) # gradient clipping with 5 (Default) 220 | # frcnn optimizer update 221 | self.optimizer.step() 222 | 223 | # validation part 224 | if step % opt.valInterval == 0: 225 | 226 | elapsed_time = time.time() - start_time 227 | print( 228 | f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} CLS_Loss: {cls_loss_avg.val():0.5f} SIMI_Loss: {sim_loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}') 229 | # for log 230 | with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log: 231 | log.write( 232 | f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n') 233 | loss_avg.reset() 234 | cls_loss_avg.reset() 235 | sim_loss_avg.reset() 236 | 237 | self.model.eval() 238 | with torch.no_grad(): 239 | valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation( 240 | self.model, self.criterion, valid_loader, self.converter, opt) 241 | 242 | self.print_prediction_result(preds, labels, log) 243 | 244 | valid_log = f'[{step}/{opt.num_iter}] valid loss: {valid_loss:0.5f}' 245 | valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}' 246 | print(valid_log) 247 | log.write(valid_log + '\n') 248 | 249 | self.model.train() 250 | 251 | # keep best accuracy model 252 | if current_accuracy > best_accuracy: 253 | best_accuracy = current_accuracy 254 | save_name = f'./saved_models/{opt.experiment_name}/best_accuracy.pth' 255 | self.save(opt, save_name) 256 | if current_norm_ED < best_norm_ED: 257 | best_norm_ED = current_norm_ED 258 | save_name = f'./saved_models/{opt.experiment_name}/best_norm_ED.pth' 259 | self.save(opt, save_name) 260 | 261 | best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}' 262 | print(best_model_log) 263 | log.write(best_model_log + '\n') 264 | 265 | # save model per 1e+5 iter. 
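# Periodic checkpointing: 1e+5 is a float literal, so (step + 1) % 1e+5 yields a float
# (e.g. 100000 % 1e+5 == 0.0), which still compares equal to 0; the checkpoint written by
# save() below holds the 'model' and 'optimizer' state_dicts that load() expects.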
266 | if (step + 1) % 1e+5 == 0: 267 | save_name = f'./saved_models/{opt.experiment_name}/iter_{step+1}.pth' 268 | self.save(opt, save_name) 269 | 270 | def load(self, saved_model): 271 | params = torch.load(saved_model) 272 | 273 | if 'model' not in params: 274 | self.model.load_state_dict(params) 275 | else: 276 | self.model.load_state_dict(params['model']) 277 | 278 | def save(self, opt, save_name): 279 | params = {} 280 | params['model'] = self.model.state_dict() 281 | # for training 282 | params['optimizer'] = self.optimizer.state_dict() 283 | torch.save(params, save_name) 284 | print('Successfully save model: {}'.format(save_name)) 285 | 286 | def weight_initializer(self): 287 | # weight initialization 288 | for name, param in self.model.named_parameters(): 289 | if 'localization_fc2' in name: 290 | print(f'Skip {name} as it is already initialized') 291 | continue 292 | try: 293 | if 'bias' in name: 294 | init.constant_(param, 0.0) 295 | elif 'weight' in name: 296 | init.kaiming_normal_(param) 297 | except Exception as e: # for batchnorm. 298 | if 'weight' in name: 299 | param.data.fill_(1) 300 | continue 301 | 302 | def save_opt_log(self, opt): 303 | """ final options """ 304 | # print(opt) 305 | with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file: 306 | opt_log = '------------ Options -------------\n' 307 | args = vars(opt) 308 | for k, v in args.items(): 309 | opt_log += f'{str(k)}: {str(v)}\n' 310 | opt_log += '---------------------------------------\n' 311 | print(opt_log) 312 | opt_file.write(opt_log) 313 | 314 | def print_prediction_result(self, preds, labels, fp_log): 315 | """ 316 | fp-logwenjian 317 | :param preds: 318 | :param labels: 319 | :param fp_log: 320 | :return: 321 | """ 322 | for pred, gt in zip(preds[:5], labels[:5]): 323 | if 'Attn' in opt.Prediction: 324 | pred = pred[:pred.find('[s]')] 325 | gt = gt[:gt.find('[s]')] 326 | print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}') 327 | fp_log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n') 328 | 329 | 330 | if __name__ == '__main__': 331 | parser = argparse.ArgumentParser() 332 | parser.add_argument('--experiment_name', help='Where to store logs and models') 333 | parser.add_argument('--src_train_data', required=True, help='path to training dataset') 334 | parser.add_argument('--tar_train_data', required=True, help='path to training dataset') 335 | parser.add_argument('--valid_data', required=True, help='path to validation dataset') 336 | parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting') 337 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 338 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 339 | parser.add_argument('--num_iter', type=int, default=300000, 340 | help='number of iterations to train for') 341 | parser.add_argument('--valInterval', type=int, default=2000, 342 | help='Interval between each validation') 343 | parser.add_argument('--continue_model', default='', help="path to model to continue training") 344 | 345 | parser.add_argument('--adam', action='store_true', 346 | help='Whether to use adam (default is Adadelta)') 347 | parser.add_argument('--lr', type=float, default=1, 348 | help='learning rate, default=1.0 for Adadelta') 349 | parser.add_argument('--decay_flag', action='store_true', help='for learning rate decay') 350 | parser.add_argument('--use_tfboard', action='store_true', help='use_tfboard') 351 | parser.add_argument('--beta1', 
type=float, default=0.9, help='beta1 for adam. default=0.9') 352 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for adam. default=0.9') 353 | # parser.add_argument('--weight_decay', type=float, default=0.9, help='weight_decay for adam. default=0.9') 354 | parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], 355 | help='Decrease learning rate at these epochs.') 356 | parser.add_argument('--gamma', type=float, default=0.1, 357 | help='LR is multiplied by gamma on schedule.') 358 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 359 | help='momentum') 360 | parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, 361 | metavar='W', help='weight decay (default: 1e-4)') 362 | 363 | parser.add_argument('--rho', type=float, default=0.95, 364 | help='decay rate rho for Adadelta. default=0.95') 365 | parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8') 366 | parser.add_argument('--grad_clip', type=float, default=5, 367 | help='gradient clipping value. default=5') 368 | 369 | """ Data processing """ 370 | parser.add_argument('--src_select_data', type=str, default='MJ-ST', 371 | help='select training data (default is MJ-ST, which means MJ and ST used as training data)') 372 | parser.add_argument('--src_batch_ratio', type=str, default='0.5-0.5', 373 | help='assign ratio for each selected data in the batch') 374 | parser.add_argument('--tar_select_data', type=str, default='real_data', 375 | help='select training data (default is real_data, which means MJ and ST used as training data)') 376 | parser.add_argument('--tar_batch_ratio', type=str, default='1', 377 | help='assign ratio for each selected data in the batch') 378 | parser.add_argument('--total_data_usage_ratio', type=str, default='1.0', 379 | help='total data usage ratio, this ratio is multiplied to total number of data.') 380 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 381 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 382 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 383 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 384 | parser.add_argument('--char_dict', type=str, default=None, 385 | help="path to char dict: dataset/iam/char_dict.txt") 386 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', 387 | help='character label') 388 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 389 | parser.add_argument('--filtering_special_chars', action='store_true', 390 | help='for sensitive character mode') 391 | parser.add_argument('--PAD', action='store_true', 392 | help='whether to keep ratio then pad for image resize') 393 | parser.add_argument('--data_filtering_off', action='store_true', 394 | help='for data_filtering_off mode') 395 | """ Model Architecture """ 396 | parser.add_argument('--Transformation', type=str, required=True, 397 | help='Transformation stage. None|TPS') 398 | parser.add_argument('--FeatureExtraction', type=str, required=True, 399 | help='FeatureExtraction stage. VGG|RCNN|ResNet') 400 | parser.add_argument('--SequenceModeling', type=str, required=True, 401 | help='SequenceModeling stage. None|BiLSTM') 402 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. 
CTC|Attn') 403 | parser.add_argument('--num_fiducial', type=int, default=20, 404 | help='number of fiducial points of TPS-STN') 405 | parser.add_argument('--input_channel', type=int, default=1, 406 | help='the number of input channel of Feature extractor') 407 | parser.add_argument('--output_channel', type=int, default=512, 408 | help='the number of output channel of Feature extractor') 409 | parser.add_argument('--hidden_size', type=int, default=256, 410 | help='the size of the LSTM hidden state') 411 | 412 | opt = parser.parse_args() 413 | 414 | if not opt.experiment_name: 415 | opt.experiment_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 416 | opt.experiment_name += f'-Seed{opt.manualSeed}' 417 | else: 418 | experiment_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 419 | experiment_name += f'-Seed{opt.manualSeed}' 420 | opt.experiment_name = experiment_name + opt.experiment_name 421 | # print(opt.experiment_name) 422 | 423 | os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True) 424 | 425 | """ Seed and GPU setting """ 426 | # print("Random Seed: ", opt.manualSeed) 427 | random.seed(opt.manualSeed) 428 | np.random.seed(opt.manualSeed) 429 | torch.manual_seed(opt.manualSeed) 430 | torch.cuda.manual_seed(opt.manualSeed) 431 | 432 | cudnn.benchmark = True 433 | cudnn.deterministic = True 434 | opt.num_gpu = torch.cuda.device_count() 435 | if opt.num_gpu > 1: 436 | print('------ Use multi-GPU setting ------') 437 | print('if you stuck too long time with multi-GPU setting, try to set --workers 0') 438 | opt.workers = opt.workers * opt.num_gpu 439 | 440 | """ previous version 441 | print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size) 442 | opt.batch_size = opt.batch_size * opt.num_gpu 443 | print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.') 444 | If you dont care about it, just commnet out these line.) 
445 | opt.num_iter = int(opt.num_iter / opt.num_gpu) 446 | """ 447 | train = trainer(opt) 448 | train.train(opt) 449 | -------------------------------------------------------------------------------- /train_da_local.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import string 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn.init as init 11 | import torch.optim as optim 12 | import torch.utils.data 13 | 14 | from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset 15 | from modules.domain_adapt import d_cls_inst 16 | from modules.radam import AdamW, RAdam 17 | from seqda_model import Model 18 | from test import validation 19 | from utils import AttnLabelConverter, Averager, load_char_dict 20 | 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | 24 | class trainer(object): 25 | def __init__(self, opt): 26 | 27 | opt.src_select_data = opt.src_select_data.split('-') 28 | opt.src_batch_ratio = opt.src_batch_ratio.split('-') 29 | opt.tar_select_data = opt.tar_select_data.split('-') 30 | opt.tar_batch_ratio = opt.tar_batch_ratio.split('-') 31 | 32 | """ vocab / character number configuration """ 33 | if opt.sensitive: 34 | # opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 35 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 36 | 37 | if opt.char_dict is not None: 38 | opt.character = load_char_dict(opt.char_dict)[3:-2] # 去除Attention 和 CTC引入的一些特殊符号 39 | 40 | """ model configuration """ 41 | self.converter = AttnLabelConverter(opt.character) 42 | opt.num_class = len(self.converter.character) 43 | 44 | if opt.rgb: 45 | opt.input_channel = 3 46 | self.opt = opt 47 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, 48 | opt.output_channel, 49 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, 50 | opt.FeatureExtraction, 51 | opt.SequenceModeling, opt.Prediction) 52 | self.save_opt_log(opt) 53 | 54 | self.build_model(opt) 55 | 56 | def dataloader(self, opt): 57 | src_train_data = opt.src_train_data 58 | src_select_data = opt.src_select_data 59 | src_batch_ratio = opt.src_batch_ratio 60 | src_train_dataset = Batch_Balanced_Dataset(opt, src_train_data, src_select_data, 61 | src_batch_ratio) 62 | 63 | tar_train_data = opt.tar_train_data 64 | tar_select_data = opt.tar_select_data 65 | tar_batch_ratio = opt.tar_batch_ratio 66 | tar_train_dataset = Batch_Balanced_Dataset(opt, tar_train_data, tar_select_data, 67 | tar_batch_ratio) 68 | 69 | AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 70 | 71 | valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt) 72 | valid_loader = torch.utils.data.DataLoader( 73 | valid_dataset, batch_size=opt.batch_size, 74 | shuffle=True, # 'True' to check training progress with validation function. 
75 | num_workers=int(opt.workers), 76 | collate_fn=AlignCollate_valid, pin_memory=True) 77 | return src_train_dataset, tar_train_dataset, valid_loader 78 | 79 | def _optimizer(self, opt): 80 | # filter that only require gradient decent 81 | filtered_parameters = [] 82 | params_num = [] 83 | for p in filter(lambda p: p.requires_grad, self.model.parameters()): 84 | filtered_parameters.append(p) 85 | params_num.append(np.prod(p.size())) 86 | print('Trainable params num : ', sum(params_num)) 87 | # setup optimizer 88 | if opt.optimizer.lower() == 'sgd': 89 | self.optimizer = optim.SGD(self.model.parameters(), lr=opt.lr, momentum=opt.momentum, 90 | weight_decay=opt.weight_decay) 91 | self.d_inst_opt = optim.SGD(self.local_discriminator.parameters(), 92 | lr=opt.lr, momentum=opt.momentum, 93 | weight_decay=opt.weight_decay) 94 | elif opt.optimizer.lower() == 'adam': 95 | self.optimizer = AdamW(self.model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), 96 | weight_decay=opt.weight_decay) 97 | self.d_inst_opt = AdamW(self.local_discriminator.parameters(), 98 | betas=(opt.beta1, opt.beta2), 99 | weight_decay=opt.weight_decay) 100 | elif opt.optimizer.lower() == 'radam': 101 | self.optimizer = RAdam(self.model.parameters(), lr=opt.lr, 102 | betas=(opt.beta1, opt.beta2), 103 | weight_decay=opt.weight_decay) 104 | self.d_inst_opt = RAdam(self.local_discriminator.parameters(), 105 | betas=(opt.beta1, opt.beta2), 106 | weight_decay=opt.weight_decay) 107 | else: 108 | self.optimizer = optim.Adadelta(filtered_parameters, lr=0.1 * opt.lr, rho=opt.rho, 109 | eps=opt.eps) 110 | self.d_inst_opt = optim.Adadelta(self.local_discriminator.parameters(), 111 | lr=opt.lr, 112 | rho=opt.rho, 113 | eps=opt.eps) 114 | 115 | print("Optimizer:") 116 | print(self.optimizer) 117 | 118 | def build_model(self, opt): 119 | """建立模型""" 120 | """DataLoder""" 121 | 122 | print('-' * 80) 123 | 124 | """ Define Model """ 125 | self.model = Model(opt) 126 | # Initialize domain classifiers here. 
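# The "local" discriminator classifies the domain (source vs. target) of each per-character
# decoder feature, hence fc_size=256 to match the default --hidden_size. d_cls_inst comes
# from modules/domain_adapt.py and is assumed here to wrap a small binary classifier behind
# a gradient-reversal layer whose strength is adjusted later via set_beta(); its exact
# architecture is not reproduced in this file. A minimal sketch of such a classifier, for
# orientation only (the class name and layer sizes below are illustrative, not the real module):
#
#     class LocalDomainClassifier(nn.Module):
#         def __init__(self, fc_size=256):
#             super().__init__()
#             self.fc = nn.Sequential(nn.Linear(fc_size, 128), nn.ReLU(), nn.Linear(128, 1))
#
#         def forward(self, feat):      # feat: (N, fc_size)
#             return self.fc(feat)      # raw logits, paired with BCEWithLogitsLoss below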
127 | self.local_discriminator = d_cls_inst(fc_size=256) 128 | 129 | self.weight_initializer() 130 | 131 | self.model = torch.nn.DataParallel(self.model).to(device) 132 | self.local_discriminator = torch.nn.DataParallel(self.local_discriminator).to(device) 133 | 134 | """ Define Loss """ 135 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) 136 | self.D_criterion = torch.nn.BCEWithLogitsLoss().to(device) 137 | 138 | """ Trainer """ 139 | self._optimizer(opt) 140 | 141 | def train(self, opt): 142 | 143 | # np.random.seed(opt.RNG_SEED) 144 | 145 | # src, tar dataloaders 146 | src_dataset, tar_dataset, valid_loader = self.dataloader(opt) 147 | src_dataset_size = src_dataset.total_data_size 148 | tar_dataset_size = tar_dataset.total_data_size 149 | train_size = max([src_dataset_size, tar_dataset_size]) 150 | iters_per_epoch = int(train_size / opt.batch_size) 151 | 152 | self.model.train() 153 | # self.global_discriminator.train() 154 | self.local_discriminator.train() 155 | start_iter = 0 156 | 157 | if opt.continue_model != '': 158 | self.load(opt.continue_model) 159 | print(" [*] Load SUCCESS") 160 | # TODO 关于学习率设置问题 161 | # if opt.decay_flag and start_iter > (opt.num_iter // 2): 162 | # self.d_image_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) * ( 163 | # start_iter - opt.num_iter // 2) 164 | # self.d_inst_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) * ( 165 | # start_iter - opt.num_iter // 2) 166 | 167 | # loss averager 168 | cls_loss_avg = Averager() 169 | sim_loss_avg = Averager() 170 | loss_avg = Averager() 171 | 172 | # training loop 173 | print('training start !') 174 | start_time = time.time() 175 | best_accuracy = -1 176 | best_norm_ED = 1e+6 177 | # i = start_iter 178 | gamma = 0 179 | omega = 1 180 | epoch = 0 181 | for step in range(start_iter, opt.num_iter + 1): 182 | epoch = step // iters_per_epoch 183 | if opt.decay_flag and step > (opt.num_iter // 2): 184 | # self.d_image_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) 185 | self.d_inst_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) 186 | 187 | src_image, src_labels = src_dataset.get_batch() 188 | src_image = src_image.to(device) 189 | src_text, src_length = self.converter.encode(src_labels, 190 | batch_max_length=opt.batch_max_length) 191 | 192 | tar_image, tar_labels = tar_dataset.get_batch() 193 | tar_image = tar_image.to(device) 194 | tar_text, tar_length = self.converter.encode(tar_labels, 195 | batch_max_length=opt.batch_max_length) 196 | 197 | # Set gradient to zero... 
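# Adversarial objective used below: source character features are scored against domain
# label 0 and target features against label 1 with BCEWithLogitsLoss, while the reversal
# factor beta (refreshed every 2000 steps via set_beta) lets the recognizer be trained to
# make the two domains indistinguishable; the domain term is weighted by omega, which
# decays over training:
#     loss = src_cls_loss.mean() + omega * d_inst_loss.mean()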
198 | self.model.zero_grad() 199 | # Domain classifiers 200 | # self.global_discriminator.zero_grad() 201 | self.local_discriminator.zero_grad() 202 | 203 | # Attention # align with Attention.forward 204 | src_preds, src_global_feature, src_local_feature = self.model(src_image, 205 | src_text[:, :-1]) 206 | # src_global_feature = self.model.visual_feature 207 | # src_local_feature = self.model.Prediction.context_history 208 | target = src_text[:, 1:] # without [GO] Symbol 209 | src_cls_loss = self.criterion(src_preds.view(-1, src_preds.shape[-1]), 210 | target.contiguous().view(-1)) 211 | src_global_feature = src_global_feature.view(src_global_feature.shape[0], -1) 212 | src_local_feature = src_local_feature.view(-1, src_local_feature.shape[-1]) 213 | # TODO 去除对tar_text 的依赖 214 | tar_preds, tar_global_feature, tar_local_feature = self.model(tar_image, 215 | tar_text[:, :-1], 216 | is_train=False) 217 | # tar_global_feature = self.model.visual_feature 218 | # tar_local_feature = self.model.Prediction.context_history 219 | tar_global_feature = tar_global_feature.view(tar_global_feature.shape[0], -1) 220 | tar_local_feature = tar_local_feature.view(-1, tar_local_feature.shape[-1]) 221 | 222 | # Add domain adaption elements 223 | # setup hyperparameter 224 | if step % 2000 == 0: 225 | p = float(step + start_iter) / opt.num_iter 226 | gamma = 2. / (1. + np.exp(-10 * p)) - 1 227 | omega = 1 - 1. / (1. + np.exp(-10 * p)) 228 | # self.global_discriminator.module.set_beta(gamma) 229 | self.local_discriminator.module.set_beta(gamma) 230 | 231 | # src_d_img_score = self.global_discriminator(src_global_feature) 232 | src_d_inst_score = self.local_discriminator(src_local_feature) 233 | # tar_d_img_score = self.global_discriminator(tar_global_feature) 234 | tar_d_inst_score = self.local_discriminator(tar_local_feature) 235 | 236 | # src_d_img_loss = self.D_criterion(src_d_img_score, 237 | # torch.zeros_like(src_d_img_score).to(device)) 238 | src_d_inst_loss = self.D_criterion(src_d_inst_score, 239 | torch.zeros_like(src_d_inst_score).to(device)) 240 | # tar_d_img_loss = self.D_criterion(tar_d_img_score, 241 | # torch.ones_like(tar_d_img_score).to(device)) 242 | tar_d_inst_loss = self.D_criterion(tar_d_inst_score, 243 | torch.ones_like(tar_d_inst_score).to(device)) 244 | # d_img_loss = src_d_img_loss + tar_d_img_loss 245 | d_inst_loss = src_d_inst_loss + tar_d_inst_loss 246 | 247 | # Add domain loss 248 | loss = src_cls_loss.mean() + omega * (d_inst_loss.mean()) 249 | loss_avg.add(loss) 250 | cls_loss_avg.add(src_cls_loss) 251 | sim_loss_avg.add(d_inst_loss) 252 | 253 | # frcnn backward 254 | loss.backward() 255 | # clip_gradient(self.model, 10.) 
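# For reference, the schedule computed above gives, with p = (step + start_iter) / num_iter:
#     p = 0.0  ->  gamma ~ 0.000, omega ~ 0.500
#     p = 0.5  ->  gamma ~ 0.987, omega ~ 0.007
#     p = 1.0  ->  gamma ~ 1.000, omega ~ 0.000
# i.e. the reversal strength ramps up while the weight of the domain loss fades out; both
# values are only refreshed every 2000 steps.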
256 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 257 | opt.grad_clip) # gradient clipping with 5 (Default) 258 | # frcnn optimizer update 259 | self.optimizer.step() 260 | # domain optimizer update 261 | self.d_inst_opt.step() 262 | # self.d_image_opt.step() 263 | 264 | # validation part 265 | if step % opt.valInterval == 0 and step != 0: 266 | 267 | elapsed_time = time.time() - start_time 268 | print( 269 | f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} CLS_Loss: {cls_loss_avg.val():0.5f} SIMI_Loss: {sim_loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}') 270 | # for log 271 | with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log: 272 | log.write( 273 | f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n') 274 | loss_avg.reset() 275 | cls_loss_avg.reset() 276 | sim_loss_avg.reset() 277 | 278 | self.model.eval() 279 | with torch.no_grad(): 280 | valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation( 281 | self.model, self.criterion, valid_loader, self.converter, opt) 282 | 283 | self.print_prediction_result(preds, labels, log) 284 | 285 | valid_log = f'[{step}/{opt.num_iter}] valid loss: {valid_loss:0.5f}' 286 | valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}' 287 | print(valid_log) 288 | log.write(valid_log + '\n') 289 | 290 | self.model.train() 291 | 292 | self.local_discriminator.train() 293 | 294 | # keep best accuracy model 295 | 296 | if current_accuracy > best_accuracy: 297 | best_accuracy = current_accuracy 298 | save_name = f'./saved_models/{opt.experiment_name}/best_accuracy.pth' 299 | self.save(opt, save_name) 300 | if current_norm_ED < best_norm_ED: 301 | best_norm_ED = current_norm_ED 302 | save_name = f'./saved_models/{opt.experiment_name}/best_norm_ED.pth' 303 | self.save(opt, save_name) 304 | 305 | best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}' 306 | print(best_model_log) 307 | log.write(best_model_log + '\n') 308 | 309 | # save model per 1e+5 iter. 
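# The checkpoint written below bundles the recognizer, the local discriminator and both
# optimizer states ('model', 'local_discriminator', 'optimizer', 'd_inst_opt'), so an
# interrupted run can be resumed with --continue_model; load() restores whichever of these
# keys are present.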
310 | if (step + 1) % 1e+5 == 0: 311 | save_name = f'./saved_models/{opt.experiment_name}/iter_{step+1}.pth' 312 | self.save(opt, save_name) 313 | 314 | def load(self, saved_model): 315 | params = torch.load(saved_model) 316 | 317 | if 'model' not in params: 318 | self.model.load_state_dict(params) 319 | else: 320 | self.model.load_state_dict(params['model']) 321 | 322 | if 'local_discriminator' in params: 323 | self.local_discriminator.load_state_dict(params['local_discriminator']) 324 | else: 325 | print(params.keys()) 326 | if 'optimizer' in params: 327 | self.optimizer.load_state_dict(params['optimizer']) 328 | lr = self.optimizer.param_groups[0]['lr'] 329 | 330 | if 'd_inst_opt' in params: 331 | self.d_inst_opt.load_state_dict(params['d_inst_opt']) 332 | 333 | def save(self, opt, save_name): 334 | 335 | params = {} 336 | 337 | params['model'] = self.model.state_dict() 338 | # params['global_discriminator'] = self.global_discriminator.state_dict() 339 | params['local_discriminator'] = self.local_discriminator.state_dict() 340 | 341 | # for training 342 | params['optimizer'] = self.optimizer.state_dict() 343 | 344 | params['d_inst_opt'] = self.d_inst_opt.state_dict() 345 | 346 | torch.save(params, save_name) 347 | print('Successfully save model: {}'.format(save_name)) 348 | 349 | def weight_initializer(self): 350 | # weight initialization 351 | for name, param in self.model.named_parameters(): 352 | if 'localization_fc2' in name: 353 | print(f'Skip {name} as it is already initialized') 354 | continue 355 | try: 356 | if 'bias' in name: 357 | init.constant_(param, 0.0) 358 | elif 'weight' in name: 359 | init.kaiming_normal_(param) 360 | except Exception as e: # for batchnorm. 361 | if 'weight' in name: 362 | param.data.fill_(1) 363 | continue 364 | 365 | def save_opt_log(self, opt): 366 | """ final options """ 367 | # print(opt) 368 | with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file: 369 | opt_log = '------------ Options -------------\n' 370 | args = vars(opt) 371 | for k, v in args.items(): 372 | opt_log += f'{str(k)}: {str(v)}\n' 373 | opt_log += '---------------------------------------\n' 374 | print(opt_log) 375 | opt_file.write(opt_log) 376 | 377 | def print_prediction_result(self, preds, labels, fp_log): 378 | """ 379 | fp-logwenjian 380 | :param preds: 381 | :param labels: 382 | :param fp_log: 日志文件指针 383 | :return: 384 | """ 385 | for pred, gt in zip(preds[:5], labels[:5]): 386 | if 'Attn' in opt.Prediction: 387 | pred = pred[:pred.find('[s]')] 388 | gt = gt[:gt.find('[s]')] 389 | print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}') 390 | fp_log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n') 391 | 392 | 393 | if __name__ == '__main__': 394 | parser = argparse.ArgumentParser() 395 | parser.add_argument('--experiment_name', help='Where to store logs and models') 396 | parser.add_argument('--src_train_data', required=True, help='path to training dataset') 397 | parser.add_argument('--tar_train_data', required=True, help='path to training dataset') 398 | parser.add_argument('--valid_data', required=True, help='path to validation dataset') 399 | parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting') 400 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 401 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 402 | parser.add_argument('--num_iter', type=int, default=300000, 403 | help='number of iterations to train for') 404 | 
parser.add_argument('--valInterval', type=int, default=500, 405 | help='Interval between each validation') 406 | parser.add_argument('--continue_model', default='', help="path to model to continue training") 407 | parser.add_argument('--adam', action='store_true', 408 | help='Whether to use adam (default is Adadelta)') 409 | 410 | # Optimization options 411 | parser.add_argument('--optimizer', type=str, default='adadelta', 412 | help='optimizer type: sgd | adam | radam | adadelta (default)') 413 | parser.add_argument('--lr', type=float, default=0.1, 414 | help='learning rate, default=0.1') 415 | parser.add_argument('--decay_flag', action='store_true', help='for learning rate decay') 416 | parser.add_argument('--use_tfboard', action='store_true', help='use_tfboard') 417 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9') 418 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for adam. default=0.999') 419 | # parser.add_argument('--weight_decay', type=float, default=0.9, help='weight_decay for adam. default=0.9') 420 | parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], 421 | help='Decrease learning rate at these epochs.') 422 | parser.add_argument('--gamma', type=float, default=0.1, 423 | help='LR is multiplied by gamma on schedule.') 424 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 425 | help='momentum') 426 | parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, 427 | metavar='W', help='weight decay (default: 1e-4)') 428 | 429 | parser.add_argument('--rho', type=float, default=0.95, 430 | help='decay rate rho for Adadelta. default=0.95') 431 | parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8') 432 | parser.add_argument('--grad_clip', type=float, default=5, 433 | help='gradient clipping value. 
default=5') 434 | 435 | """ Data processing """ 436 | parser.add_argument('--src_select_data', type=str, default='MJ-ST', 437 | help='select training data (default is MJ-ST, which means MJ and ST used as training data)') 438 | parser.add_argument('--src_batch_ratio', type=str, default='0.5-0.5', 439 | help='assign ratio for each selected data in the batch') 440 | parser.add_argument('--tar_select_data', type=str, default='real_data', 441 | help='select training data (default is real_data, which means MJ and ST used as training data)') 442 | parser.add_argument('--tar_batch_ratio', type=str, default='1', 443 | help='assign ratio for each selected data in the batch') 444 | parser.add_argument('--total_data_usage_ratio', type=str, default='1.0', 445 | help='total data usage ratio, this ratio is multiplied to total number of data.') 446 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum-label-length') 447 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 448 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 449 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 450 | parser.add_argument('--char_dict', type=str, default=None, 451 | help="path to char dict: dataset/iam/char_dict.txt") 452 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', 453 | help='character label') 454 | parser.add_argument('--sensitive', action='store_true', help='for sensitive character mode') 455 | parser.add_argument('--filtering_special_chars', action='store_true', 456 | help='for sensitive character mode') 457 | parser.add_argument('--PAD', action='store_true', 458 | help='whether to keep ratio then pad for image resize') 459 | parser.add_argument('--data_filtering_off', action='store_true', 460 | help='for data_filtering_off mode') 461 | """ Model Architecture """ 462 | parser.add_argument('--Transformation', type=str, required=True, 463 | help='Transformation stage. None|TPS') 464 | parser.add_argument('--FeatureExtraction', type=str, required=True, 465 | help='FeatureExtraction stage. VGG|RCNN|ResNet') 466 | parser.add_argument('--SequenceModeling', type=str, required=True, 467 | help='SequenceModeling stage. None|BiLSTM') 468 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. 
CTC|Attn') 469 | parser.add_argument('--num_fiducial', type=int, default=20, 470 | help='number of fiducial points of TPS-STN') 471 | parser.add_argument('--input_channel', type=int, default=1, 472 | help='the number of input channel of Feature extractor') 473 | parser.add_argument('--output_channel', type=int, default=512, 474 | help='the number of output channel of Feature extractor') 475 | parser.add_argument('--hidden_size', type=int, default=256, 476 | help='the size of the LSTM hidden state') 477 | 478 | opt = parser.parse_args() 479 | 480 | if not opt.experiment_name: 481 | opt.experiment_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 482 | opt.experiment_name += f'-Seed{opt.manualSeed}' 483 | else: 484 | experiment_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 485 | experiment_name += f'-Seed{opt.manualSeed}' 486 | opt.experiment_name = experiment_name + opt.experiment_name 487 | # print(opt.experiment_name) 488 | 489 | os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True) 490 | 491 | """ Seed and GPU setting """ 492 | # print("Random Seed: ", opt.manualSeed) 493 | random.seed(opt.manualSeed) 494 | np.random.seed(opt.manualSeed) 495 | torch.manual_seed(opt.manualSeed) 496 | torch.cuda.manual_seed(opt.manualSeed) 497 | 498 | cudnn.benchmark = True 499 | cudnn.deterministic = True 500 | opt.num_gpu = torch.cuda.device_count() 501 | # print('device count', opt.num_gpu) 502 | if opt.num_gpu > 1: 503 | print('------ Use multi-GPU setting ------') 504 | print('if you stuck too long time with multi-GPU setting, try to set --workers 0') 505 | # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1 506 | opt.workers = opt.workers * opt.num_gpu 507 | 508 | """ previous version 509 | print('To equlize batch stats to 1-GPU setting, the batch_size is multiplied with num_gpu and multiplied batch_size is ', opt.batch_size) 510 | opt.batch_size = opt.batch_size * opt.num_gpu 511 | print('To equalize the number of epochs to 1-GPU setting, num_iter is divided with num_gpu by default.') 512 | If you dont care about it, just commnet out these line.) 
513 | opt.num_iter = int(opt.num_iter / opt.num_gpu) 514 | """ 515 | train = trainer(opt) 516 | train.train(opt) 517 | -------------------------------------------------------------------------------- /train_da_global_local_selected.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import string 5 | import time 6 | 7 | import numpy as np 8 | import torch 9 | import torch.backends.cudnn as cudnn 10 | import torch.nn.init as init 11 | import torch.optim as optim 12 | import torch.utils.data 13 | 14 | from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset 15 | from modules.domain_adapt import d_cls_inst 16 | from modules.radam import AdamW, RAdam 17 | from seqda_model import Model 18 | from test import validation 19 | from utils import AttnLabelConverter, Averager, load_char_dict 20 | 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | 24 | def filter_local_features(opt, 25 | source_context_history, source_prediction, 26 | target_context_history, target_prediction): 27 | feature_dim = source_context_history.size()[-1] 28 | 29 | source_feature = source_context_history.reshape(-1, feature_dim) 30 | target_feature = target_context_history.reshape(-1, feature_dim) 31 | 32 | # print(type(pred_class),pred_class) 33 | source_pred_score, source_pred_class = source_prediction.max(-1) 34 | target_pred_score, target_pred_class = target_prediction.max(-1) 35 | source_valid_char_index = (source_pred_score.reshape(-1, ) > opt.pc).nonzero().reshape(-1, ) 36 | source_valid_char_feature = source_feature.reshape(-1, feature_dim).index_select(0, 37 | source_valid_char_index) 38 | target_valid_char_index = (target_pred_score.reshape(-1, ) > opt.pc).nonzero().reshape(-1, ) 39 | target_valid_char_feature = target_feature.reshape(-1, feature_dim).index_select(0, 40 | target_valid_char_index) 41 | 42 | return source_valid_char_feature, target_valid_char_feature 43 | 44 | 45 | class trainer(object): 46 | def __init__(self, opt): 47 | 48 | opt.src_select_data = opt.src_select_data.split('-') 49 | opt.src_batch_ratio = opt.src_batch_ratio.split('-') 50 | opt.tar_select_data = opt.tar_select_data.split('-') 51 | opt.tar_batch_ratio = opt.tar_batch_ratio.split('-') 52 | 53 | """ vocab / character number configuration """ 54 | if opt.sensitive: 55 | # opt.character += 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' 56 | opt.character = string.printable[:-6] # same with ASTER setting (use 94 char). 
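# string.printable is 100 characters long (digits + ascii_letters + punctuation followed by
# six whitespace characters); slicing off the last six drops the whitespace block
# ' \t\n\r\x0b\x0c' and leaves exactly the 94 visible characters referred to above.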
57 | 58 | if opt.char_dict is not None: 59 | opt.character = load_char_dict(opt.char_dict)[3:-2] # 去除Attention 和 CTC引入的一些特殊符号 60 | 61 | """ model configuration """ 62 | 63 | self.converter = AttnLabelConverter(opt.character) 64 | opt.num_class = len(self.converter.character) 65 | 66 | if opt.rgb: 67 | opt.input_channel = 3 68 | self.opt = opt 69 | print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, 70 | opt.output_channel, 71 | opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, 72 | opt.FeatureExtraction, 73 | opt.SequenceModeling, opt.Prediction) 74 | self.save_opt_log(opt) 75 | 76 | self.build_model(opt) 77 | 78 | def dataloader(self, opt): 79 | src_train_data = opt.src_train_data 80 | src_select_data = opt.src_select_data 81 | src_batch_ratio = opt.src_batch_ratio 82 | src_train_dataset = Batch_Balanced_Dataset(opt, src_train_data, src_select_data, 83 | src_batch_ratio) 84 | 85 | tar_train_data = opt.tar_train_data 86 | tar_select_data = opt.tar_select_data 87 | tar_batch_ratio = opt.tar_batch_ratio 88 | tar_train_dataset = Batch_Balanced_Dataset(opt, tar_train_data, tar_select_data, 89 | tar_batch_ratio) 90 | 91 | AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) 92 | 93 | valid_dataset = hierarchical_dataset(root=opt.valid_data, opt=opt) 94 | valid_loader = torch.utils.data.DataLoader( 95 | valid_dataset, batch_size=opt.batch_size, 96 | shuffle=True, # 'True' to check training progress with validation function. 97 | num_workers=int(opt.workers), 98 | collate_fn=AlignCollate_valid, pin_memory=True) 99 | return src_train_dataset, tar_train_dataset, valid_loader 100 | 101 | def _optimizer(self, opt): 102 | # filter that only require gradient decent 103 | filtered_parameters = [] 104 | params_num = [] 105 | for p in filter(lambda p: p.requires_grad, self.model.parameters()): 106 | filtered_parameters.append(p) 107 | params_num.append(np.prod(p.size())) 108 | print('Trainable params num : ', sum(params_num)) 109 | # setup optimizer 110 | if opt.optimizer.lower() == 'sgd': 111 | self.optimizer = optim.SGD(self.model.parameters(), lr=opt.lr, momentum=opt.momentum, 112 | weight_decay=opt.weight_decay) 113 | self.d_image_opt = optim.SGD(self.global_discriminator.parameters(), lr=opt.lr, 114 | momentum=opt.momentum, 115 | weight_decay=opt.weight_decay) 116 | self.d_inst_opt = optim.SGD(self.local_discriminator.parameters(), 117 | lr=opt.lr, momentum=opt.momentum, 118 | weight_decay=opt.weight_decay) 119 | elif opt.optimizer.lower() == 'adam': 120 | self.optimizer = AdamW(self.model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), 121 | weight_decay=opt.weight_decay) 122 | self.d_image_opt = AdamW(self.global_discriminator.parameters(), lr=opt.lr, 123 | betas=(opt.beta1, opt.beta2), 124 | weight_decay=opt.weight_decay) 125 | self.d_inst_opt = AdamW(self.local_discriminator.parameters(), 126 | betas=(opt.beta1, opt.beta2), 127 | weight_decay=opt.weight_decay) 128 | elif opt.optimizer.lower() == 'radam': 129 | self.optimizer = RAdam(self.model.parameters(), lr=opt.lr, 130 | betas=(opt.beta1, opt.beta2), 131 | weight_decay=opt.weight_decay) 132 | self.d_image_opt = RAdam(self.global_discriminator.parameters(), lr=opt.lr, 133 | betas=(opt.beta1, opt.beta2), 134 | weight_decay=opt.weight_decay) 135 | self.d_inst_opt = RAdam(self.local_discriminator.parameters(), 136 | betas=(opt.beta1, opt.beta2), 137 | weight_decay=opt.weight_decay) 138 | 139 | 140 | else: 141 | self.optimizer = 
optim.Adadelta(filtered_parameters, lr=0.1 * opt.lr, rho=opt.rho, 142 | eps=opt.eps) 143 | self.d_image_opt = optim.Adadelta(self.global_discriminator.parameters(), 144 | lr=opt.lr, 145 | rho=opt.rho, 146 | eps=opt.eps) 147 | self.d_inst_opt = optim.Adadelta(self.local_discriminator.parameters(), 148 | lr=opt.lr, 149 | rho=opt.rho, 150 | eps=opt.eps) 151 | 152 | print("Optimizer:") 153 | print(self.optimizer) 154 | 155 | def build_model(self, opt): 156 | """建立模型""" 157 | 158 | print('-' * 80) 159 | 160 | """ Define Model """ 161 | self.model = Model(opt) 162 | # Initialize domain classifiers here. 163 | self.global_discriminator = d_cls_inst(fc_size=13312) 164 | self.local_discriminator = d_cls_inst(fc_size=256) 165 | 166 | self.weight_initializer() 167 | self.model = torch.nn.DataParallel(self.model).to(device) 168 | self.global_discriminator = torch.nn.DataParallel(self.global_discriminator).to(device) 169 | self.local_discriminator = torch.nn.DataParallel(self.local_discriminator).to(device) 170 | 171 | """ Define Loss """ 172 | if 'CTC' in opt.Prediction: 173 | self.criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) 174 | else: 175 | # ignore [GO] token = ignore index 0 176 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) 177 | self.D_criterion = torch.nn.BCEWithLogitsLoss().to(device) 178 | 179 | """ Trainer """ 180 | self._optimizer(opt) 181 | 182 | def train(self, opt): 183 | # src, tar dataloaders 184 | src_dataset, tar_dataset, valid_loader = self.dataloader(opt) 185 | src_dataset_size = src_dataset.total_data_size 186 | tar_dataset_size = tar_dataset.total_data_size 187 | train_size = max([src_dataset_size, tar_dataset_size]) 188 | iters_per_epoch = int(train_size / opt.batch_size) 189 | 190 | # Modify train size. Make sure both are of same size. 191 | # Modify training loop to continue giving src loss after tar is done. 192 | 193 | self.model.train() 194 | self.global_discriminator.train() 195 | self.local_discriminator.train() 196 | start_iter = 0 197 | 198 | if opt.continue_model != '': 199 | self.load(opt.continue_model) 200 | print(" [*] Load SUCCESS") 201 | 202 | # loss averager 203 | cls_loss_avg = Averager() 204 | sim_loss_avg = Averager() 205 | loss_avg = Averager() 206 | 207 | # training loop 208 | print('training start !') 209 | start_time = time.time() 210 | best_accuracy = -1 211 | best_norm_ED = 1e+6 212 | # i = start_iter 213 | gamma = 0 214 | omega = 1 215 | epoch = 0 216 | for step in range(start_iter, opt.num_iter + 1): 217 | epoch = step // iters_per_epoch 218 | if opt.decay_flag and step > (opt.num_iter // 2): 219 | self.d_image_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) 220 | self.d_inst_opt.param_groups[0]['lr'] -= (opt.lr / (opt.num_iter // 2)) 221 | 222 | src_image, src_labels = src_dataset.get_batch() 223 | src_image = src_image.to(device) 224 | src_text, src_length = self.converter.encode(src_labels, 225 | batch_max_length=opt.batch_max_length) 226 | 227 | tar_image, tar_labels = tar_dataset.get_batch() 228 | tar_image = tar_image.to(device) 229 | tar_text, tar_length = self.converter.encode(tar_labels, 230 | batch_max_length=opt.batch_max_length) 231 | 232 | # Set gradient to zero... 
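# This variant combines two adversarial signals: a global discriminator on the flattened
# visual feature map (fc_size=13312 in build_model above) and a local discriminator on the
# per-character decoder features, where the local features of both domains are first passed
# through filter_local_features() so that only positions whose top prediction score exceeds
# --pc contribute to the domain loss.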
233 | self.model.zero_grad() 234 | # Domain classifiers 235 | self.global_discriminator.zero_grad() 236 | self.local_discriminator.zero_grad() 237 | 238 | # Attention # align with Attention.forward 239 | src_preds, src_global_feature, src_local_feature = self.model(src_image, 240 | src_text[:, :-1]) 241 | # src_global_feature = self.model.visual_feature 242 | # src_local_feature = self.model.Prediction.context_history 243 | target = src_text[:, 1:] # without [GO] Symbol 244 | src_cls_loss = self.criterion(src_preds.view(-1, src_preds.shape[-1]), 245 | target.contiguous().view(-1)) 246 | src_global_feature = src_global_feature.view(src_global_feature.shape[0], -1) 247 | src_local_feature = src_local_feature.view(-1, src_local_feature.shape[-1]) 248 | 249 | tar_preds, tar_global_feature, tar_local_feature = self.model(tar_image, 250 | tar_text[:, :-1], 251 | is_train=False) 252 | # tar_global_feature = self.model.visual_feature 253 | # tar_local_feature = self.model.Prediction.context_history 254 | tar_global_feature = tar_global_feature.view(tar_global_feature.shape[0], -1) 255 | tar_local_feature = tar_local_feature.view(-1, tar_local_feature.shape[-1]) 256 | 257 | src_local_feature, tar_local_feature = filter_local_features(opt, src_local_feature, 258 | src_preds, 259 | tar_local_feature, 260 | tar_preds) 261 | 262 | # Add domain adaption elements 263 | # setup hyperparameter 264 | if step % 2000 == 0: 265 | p = float(step + start_iter) / opt.num_iter 266 | gamma = 2. / (1. + np.exp(-10 * p)) - 1 267 | omega = 1 - 1. / (1. + np.exp(-10 * p)) 268 | self.global_discriminator.module.set_beta(gamma) 269 | self.local_discriminator.module.set_beta(gamma) 270 | 271 | src_d_img_score = self.global_discriminator(src_global_feature) 272 | src_d_inst_score = self.local_discriminator(src_local_feature) 273 | tar_d_img_score = self.global_discriminator(tar_global_feature) 274 | tar_d_inst_score = self.local_discriminator(tar_local_feature) 275 | 276 | src_d_img_loss = self.D_criterion(src_d_img_score, 277 | torch.zeros_like(src_d_img_score).to(device)) 278 | src_d_inst_loss = self.D_criterion(src_d_inst_score, 279 | torch.zeros_like(src_d_inst_score).to(device)) 280 | tar_d_img_loss = self.D_criterion(tar_d_img_score, 281 | torch.ones_like(tar_d_img_score).to(device)) 282 | tar_d_inst_loss = self.D_criterion(tar_d_inst_score, 283 | torch.ones_like(tar_d_inst_score).to(device)) 284 | d_img_loss = src_d_img_loss + tar_d_img_loss 285 | d_inst_loss = src_d_inst_loss + tar_d_inst_loss 286 | 287 | # Add domain loss 288 | loss = src_cls_loss.mean() + omega * (d_img_loss.mean() + d_inst_loss.mean()) 289 | loss_avg.add(loss) 290 | cls_loss_avg.add(src_cls_loss) 291 | sim_loss_avg.add(d_img_loss + d_inst_loss) 292 | 293 | # frcnn backward 294 | loss.backward() 295 | # clip_gradient(self.model, 10.) 
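# A single backward pass through the combined loss populates gradients for the recognizer
# and for both discriminators; clip_grad_norm_ below is applied to the recognizer's
# parameters only, after which optimizer.step() updates the recognizer while d_inst_opt and
# d_image_opt update the local and global discriminators respectively.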
296 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), 297 | opt.grad_clip) # gradient clipping with 5 (Default) 298 | # frcnn optimizer update 299 | self.optimizer.step() 300 | # domain optimizer update 301 | self.d_inst_opt.step() 302 | self.d_image_opt.step() 303 | 304 | # validation part 305 | if step % opt.valInterval == 0: 306 | 307 | elapsed_time = time.time() - start_time 308 | print( 309 | f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} CLS_Loss: {cls_loss_avg.val():0.5f} SIMI_Loss: {sim_loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}') 310 | # for log 311 | with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a') as log: 312 | log.write( 313 | f'[{step}/{opt.num_iter}] Loss: {loss_avg.val():0.5f} elapsed_time: {elapsed_time:0.5f}\n') 314 | loss_avg.reset() 315 | cls_loss_avg.reset() 316 | sim_loss_avg.reset() 317 | 318 | self.model.eval() 319 | with torch.no_grad(): 320 | valid_loss, current_accuracy, current_norm_ED, preds, labels, infer_time, length_of_data = validation( 321 | self.model, self.criterion, valid_loader, self.converter, opt) 322 | 323 | self.print_prediction_result(preds, labels, log) 324 | 325 | valid_log = f'[{step}/{opt.num_iter}] valid loss: {valid_loss:0.5f}' 326 | valid_log += f' accuracy: {current_accuracy:0.3f}, norm_ED: {current_norm_ED:0.2f}' 327 | print(valid_log) 328 | log.write(valid_log + '\n') 329 | 330 | self.model.train() 331 | self.global_discriminator.train() 332 | self.local_discriminator.train() 333 | 334 | # keep best accuracy model 335 | 336 | if current_accuracy > best_accuracy: 337 | best_accuracy = current_accuracy 338 | save_name = f'./saved_models/{opt.experiment_name}/best_accuracy.pth' 339 | self.save(opt, save_name) 340 | if current_norm_ED < best_norm_ED: 341 | best_norm_ED = current_norm_ED 342 | save_name = f'./saved_models/{opt.experiment_name}/best_norm_ED.pth' 343 | self.save(opt, save_name) 344 | 345 | best_model_log = f'best_accuracy: {best_accuracy:0.3f}, best_norm_ED: {best_norm_ED:0.2f}' 346 | print(best_model_log) 347 | log.write(best_model_log + '\n') 348 | 349 | # save model per 1e+5 iter. 
350 | if (step + 1) % 1e+5 == 0: 351 | save_name = f'./saved_models/{opt.experiment_name}/iter_{step+1}.pth' 352 | self.save(opt, save_name) 353 | 354 | def load(self, saved_model): 355 | params = torch.load(saved_model) 356 | 357 | if 'model' not in params: 358 | self.model.load_state_dict(params) 359 | else: 360 | self.model.load_state_dict(params['model']) 361 | if 'global_discriminator' in params: 362 | self.global_discriminator.load_state_dict(params['global_discriminator']) 363 | if 'local_discriminator' in params: 364 | self.local_discriminator.load_state_dict(params['local_discriminator']) 365 | else: 366 | print(params.keys()) 367 | if 'optimizer' in params: 368 | self.optimizer.load_state_dict(params['optimizer']) 369 | lr = self.optimizer.param_groups[0]['lr'] 370 | if 'd_image_opt' in params: 371 | self.d_image_opt.load_state_dict(params['d_image_opt']) 372 | if 'd_inst_opt' in params: 373 | self.d_inst_opt.load_state_dict(params['d_inst_opt']) 374 | 375 | def save(self, opt, save_name): 376 | 377 | params = {} 378 | 379 | params['model'] = self.model.state_dict() 380 | params['global_discriminator'] = self.global_discriminator.state_dict() 381 | params['local_discriminator'] = self.local_discriminator.state_dict() 382 | 383 | # for training 384 | params['optimizer'] = self.optimizer.state_dict() 385 | params['d_image_opt'] = self.d_image_opt.state_dict() 386 | params['d_inst_opt'] = self.d_inst_opt.state_dict() 387 | # params['pooling_mode'] = opt.pooling_mode 388 | # params['class_agnostic'] = opt.class_agnostic 389 | 390 | torch.save(params, save_name) 391 | print('Successfully save model: {}'.format(save_name)) 392 | 393 | def weight_initializer(self): 394 | # weight initialization 395 | for name, param in self.model.named_parameters(): 396 | if 'localization_fc2' in name: 397 | print(f'Skip {name} as it is already initialized') 398 | continue 399 | try: 400 | if 'bias' in name: 401 | init.constant_(param, 0.0) 402 | elif 'weight' in name: 403 | init.kaiming_normal_(param) 404 | except Exception as e: # for batchnorm. 
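# kaiming_normal_ raises for 1-D parameters such as BatchNorm weight vectors (fan-in cannot
# be computed for tensors with fewer than two dimensions), so those weights fall through to
# here and are simply filled with 1; their biases were already set to 0 in the try branch.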
405 | if 'weight' in name: 406 | param.data.fill_(1) 407 | continue 408 | 
409 | def save_opt_log(self, opt): 410 | """ final options """ 411 | # print(opt) 412 | with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a') as opt_file: 413 | opt_log = '------------ Options -------------\n' 414 | args = vars(opt) 415 | for k, v in args.items(): 416 | opt_log += f'{str(k)}: {str(v)}\n' 417 | opt_log += '---------------------------------------\n' 418 | print(opt_log) 419 | opt_file.write(opt_log) 420 | 
421 | def print_prediction_result(self, preds, labels, fp_log): 422 | """ 423 | Print the first few predictions next to their ground-truth labels and write them to the log file. 424 | :param preds: decoded prediction strings 425 | :param labels: ground-truth label strings 426 | :param fp_log: log file handle 427 | :return: 428 | """ 429 | for pred, gt in zip(preds[:5], labels[:5]): 430 | if 'Attn' in opt.Prediction: 431 | pred = pred[:pred.find('[s]')] 432 | gt = gt[:gt.find('[s]')] 433 | print(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}') 434 | fp_log.write(f'{pred:20s}, gt: {gt:20s}, {str(pred == gt)}\n') 435 | 436 | 
437 | if __name__ == '__main__': 438 | parser = argparse.ArgumentParser() 439 | parser.add_argument('--experiment_name', help='Where to store logs and models') 440 | parser.add_argument('--src_train_data', required=True, help='path to source-domain training dataset') 441 | parser.add_argument('--tar_train_data', required=True, help='path to target-domain training dataset') 442 | parser.add_argument('--valid_data', required=True, help='path to validation dataset') 443 | parser.add_argument('--manualSeed', type=int, default=1111, help='for random seed setting') 444 | parser.add_argument('--workers', type=int, help='number of data loading workers', default=4) 445 | parser.add_argument('--batch_size', type=int, default=192, help='input batch size') 446 | parser.add_argument('--num_iter', type=int, default=300000, 447 | help='number of iterations to train for') 448 | parser.add_argument('--valInterval', type=int, default=500, 449 | help='Interval between each validation') 450 | parser.add_argument('--continue_model', default='', help="path to model to continue training") 451 | parser.add_argument('--adam', action='store_true', 452 | help='Whether to use adam (default is Adadelta)') 453 | 
454 | # # Optimization options 455 | parser.add_argument('--optimizer', type=str, default='adadelta', 456 | help='optimizer type: adam, Radam, Adadelta') 457 | parser.add_argument('--lr', type=float, default=0.1, 458 | help='learning rate, default=0.1 for adam') 459 | parser.add_argument('--decay_flag', action='store_true', help='for learning rate decay') 460 | parser.add_argument('--use_tfboard', action='store_true', help='use tensorboard logging') 461 | parser.add_argument('--beta1', type=float, default=0.9, help='beta1 for adam. default=0.9') 462 | parser.add_argument('--beta2', type=float, default=0.999, help='beta2 for adam. default=0.999') 463 | # parser.add_argument('--weight_decay', type=float, default=0.9, help='weight_decay for adam. default=0.9') 464 | parser.add_argument('--schedule', type=int, nargs='+', default=[150, 225], 465 | help='Decrease learning rate at these epochs.') 466 | parser.add_argument('--pc', type=float, default=0.1, 467 | help='confidence threshold: 0, 0.1, 0.2, 0.4, 0.8.') 468 | parser.add_argument('--gamma', type=float, default=0.1, 469 | help='LR is multiplied by gamma on schedule.') 470 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 471 | help='momentum') 472 | parser.add_argument('--weight_decay', '--wd', default=1e-4, type=float, 473 | metavar='W', help='weight decay (default: 1e-4)') 474 | 475 | parser.add_argument('--rho', type=float, default=0.95, 476 | help='decay rate rho for Adadelta. default=0.95') 477 | parser.add_argument('--eps', type=float, default=1e-8, help='eps for Adadelta. default=1e-8') 478 | parser.add_argument('--grad_clip', type=float, default=5, 479 | help='gradient clipping value. default=5') 480 | 
481 | """ Data processing """ 482 | parser.add_argument('--src_select_data', type=str, default='MJ-ST', 483 | help='select source training data (default is MJ-ST, which means MJ and ST are used as source training data)') 484 | parser.add_argument('--src_batch_ratio', type=str, default='0.5-0.5', 485 | help='assign ratio for each selected data in the batch') 486 | parser.add_argument('--tar_select_data', type=str, default='real_data', 487 | help='select target training data (default is real_data)') 488 | parser.add_argument('--tar_batch_ratio', type=str, default='1', 489 | help='assign ratio for each selected data in the batch') 490 | parser.add_argument('--total_data_usage_ratio', type=str, default='1.0', 491 | help='total data usage ratio; this ratio is multiplied by the total amount of data.') 492 | parser.add_argument('--batch_max_length', type=int, default=25, help='maximum label length') 493 | parser.add_argument('--imgH', type=int, default=32, help='the height of the input image') 494 | parser.add_argument('--imgW', type=int, default=100, help='the width of the input image') 495 | parser.add_argument('--rgb', action='store_true', help='use rgb input') 496 | parser.add_argument('--char_dict', type=str, default=None, 497 | help="path to char dict: dataset/iam/char_dict.txt") 498 | parser.add_argument('--character', type=str, default='0123456789abcdefghijklmnopqrstuvwxyz', 499 | help='character set used for labels') 500 | parser.add_argument('--sensitive', action='store_true', help='for case-sensitive character mode') 501 | parser.add_argument('--filtering_special_chars', action='store_true', 502 | help='filter special characters out of the labels') 503 | parser.add_argument('--PAD', action='store_true', 504 | help='whether to keep ratio then pad for image resize') 505 | parser.add_argument('--data_filtering_off', action='store_true', 506 | help='for data_filtering_off mode') 507 | 
""" Model Architecture """ 508 | parser.add_argument('--Transformation', type=str, required=True, 509 | help='Transformation stage. None|TPS') 510 | parser.add_argument('--FeatureExtraction', type=str, required=True, 511 | help='FeatureExtraction stage. VGG|RCNN|ResNet') 512 | parser.add_argument('--SequenceModeling', type=str, required=True, 513 | help='SequenceModeling stage. None|BiLSTM') 514 | parser.add_argument('--Prediction', type=str, required=True, help='Prediction stage. CTC|Attn')
515 | parser.add_argument('--num_fiducial', type=int, default=20, 516 | help='number of fiducial points of TPS-STN') 517 | parser.add_argument('--input_channel', type=int, default=1, 518 | help='the number of input channels of the feature extractor') 519 | parser.add_argument('--output_channel', type=int, default=512, 520 | help='the number of output channels of the feature extractor') 521 | parser.add_argument('--hidden_size', type=int, default=256, 522 | help='the size of the LSTM hidden state') 523 | 
524 | opt = parser.parse_args() 525 | 526 | if not opt.experiment_name: 527 | opt.experiment_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 528 | opt.experiment_name += f'-Seed{opt.manualSeed}' 529 | else: 530 | experiment_name = f'{opt.Transformation}-{opt.FeatureExtraction}-{opt.SequenceModeling}-{opt.Prediction}' 531 | experiment_name += f'-Seed{opt.manualSeed}' 532 | opt.experiment_name = experiment_name + opt.experiment_name 533 | # print(opt.experiment_name) 534 | 535 | os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True) 536 | 
537 | """ Seed and GPU setting """ 538 | # print("Random Seed: ", opt.manualSeed) 539 | random.seed(opt.manualSeed) 540 | np.random.seed(opt.manualSeed) 541 | torch.manual_seed(opt.manualSeed) 542 | torch.cuda.manual_seed(opt.manualSeed) 543 | 544 | cudnn.benchmark = True 545 | cudnn.deterministic = True 546 | opt.num_gpu = torch.cuda.device_count() 547 | # print('device count', opt.num_gpu) 548 | if opt.num_gpu > 1: 549 | print('------ Use multi-GPU setting ------') 550 | print('If you get stuck for a long time with the multi-GPU setting, try --workers 0') 551 | # check multi-GPU issue https://github.com/clovaai/deep-text-recognition-benchmark/issues/1 552 | opt.workers = opt.workers * opt.num_gpu 553 | 
554 | """ previous version 555 | print('To equalize batch stats to the 1-GPU setting, batch_size is multiplied by num_gpu; the multiplied batch_size is ', opt.batch_size) 556 | opt.batch_size = opt.batch_size * opt.num_gpu 557 | print('To equalize the number of epochs to the 1-GPU setting, num_iter is divided by num_gpu by default.') 558 | If you don't care about this, just comment out these lines. 559 | opt.num_iter = int(opt.num_iter / opt.num_gpu) 560 | """ 561 | train = trainer(opt) 562 | train.train(opt) 563 | --------------------------------------------------------------------------------
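Note: the checkpoints written by save() above bundle the recognizer ('model'), both domain discriminators, and the three optimizer states into one dict, while load() also accepts a bare state_dict. Below is a minimal sketch of inspecting such a checkpoint before resuming training with --continue_model; the checkpoint path is hypothetical, and the key names mirror those used in save()/load() above.

import torch

# Hypothetical path; substitute any checkpoint produced by trainer.save().
ckpt_path = 'saved_models/TPS-ResNet-BiLSTM-Attn-Seed1111/iter_100000.pth'
ckpt = torch.load(ckpt_path, map_location='cpu')

if 'model' not in ckpt:
    # Bare recognizer state_dict (first branch handled by trainer.load()).
    print('bare state_dict with', len(ckpt), 'parameter tensors')
else:
    # Bundled format written by trainer.save().
    for key in ('model', 'global_discriminator', 'local_discriminator',
                'optimizer', 'd_image_opt', 'd_inst_opt'):
        print(f"{key}: {'present' if key in ckpt else 'missing'}")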