├── utils
│   ├── __init__.py
│   ├── utils.py
│   ├── data_iterator.py
│   └── latex2gtd_v2_2.py
├── README.md
├── model
│   ├── encoder.py
│   ├── encoder_decoder.py
│   └── decoder.py
└── train.py

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# TDv2
The source code of TDv2 from the paper:

TDv2: A Novel Tree-Structured Decoder for Offline Mathematical Expression Recognition.

## Note
Due to the company's confidentiality regulations, the original experimental code cannot leave the company intranet, so this repository is a reproduction of that code. More details will be added in the future.

--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import random
import numpy as np
from torch import nn
from .latex2gtd_v2_2 import list2node, node2list_shuffle

# load dictionary
def load_dict(dictFile):
    lexicon = {}
    with open(dictFile) as fp:
        stuff = fp.readlines()
    for l in stuff:
        w = l.strip().split()
        lexicon[w[0]] = int(w[1])
    return lexicon

# create batch
def prepare_data(params, images_x, seqs_y, seqs_key, object2id, relation2id, shuffle=False):
    heights_x = [s.shape[1] for s in images_x]
    widths_x = [s.shape[2] for s in images_x]
    lengths_y = [len(s) for s in seqs_y]

    n_samples = len(heights_x)
    max_height_x = np.max(heights_x)
    max_width_x = np.max(widths_x)
    maxlen_y = np.max(lengths_y)
    num_relation = len(relation2id)

    x = np.zeros((n_samples, params['input_channels'], max_height_x, max_width_x)).astype(np.float32) - 1
    childs = np.zeros((maxlen_y, n_samples)).astype(np.int64)
    parents = np.zeros((maxlen_y, n_samples)).astype(np.int64)
    c_pos = np.zeros((maxlen_y, n_samples)).astype(np.int64)
    p_pos = np.zeros((maxlen_y, n_samples)).astype(np.int64)
    relations = np.zeros((maxlen_y, n_samples)).astype(np.int64)
    pathes = np.zeros((maxlen_y, n_samples, num_relation)).astype(np.float32)

    x_mask = np.zeros((n_samples, max_height_x, max_width_x)).astype(np.float32)
    y_mask = np.zeros((maxlen_y, n_samples)).astype(np.float32)

    for idx, (s_x, s_y, s_key) in enumerate(zip(images_x, seqs_y, seqs_key)):
        x[idx, :, :heights_x[idx], :widths_x[idx]] = (255 - s_x) / 255.
        x_mask[idx, :heights_x[idx], :widths_x[idx]] = 1.
        if shuffle:
            tree = list2node(s_y)
            s_y = []
            node2list_shuffle('', 1, 'Start', tree, s_y, True)
        for i, line in enumerate(s_y):
            # GTD row: [child, child_idx, parent, parent_idx, relation]
            c_id, p_id, re_id = object2id[line[0]], object2id[line[2]], relation2id[line[4]]
            childs[i, idx] = c_id
            parents[i, idx] = p_id
            c_pos[i, idx] = int(line[1])
            p_pos[i, idx] = int(line[3])
            relations[i, idx] = re_id
            # pathes[t] is the multi-hot set of relations the node decoded at
            # step t emits as a parent; the last channel marks 'end'
            pathes[i, idx, -1] = 1
            if i > 0:
                pathes[int(line[3])-1, idx, -1] = 0
                pathes[int(line[3])-1, idx, re_id] = 1

        y_mask[:lengths_y[idx], idx] = 1.

    return x, x_mask, childs, y_mask, parents, relations, pathes, c_pos, p_pos


# init model params
def weight_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.)

    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.)

# compute metric
def cmp_result(rec, label):
    # Levenshtein distance between the predicted sequence and the label
    dist_mat = np.zeros((len(label)+1, len(rec)+1), dtype='int32')
    dist_mat[0, :] = range(len(rec) + 1)
    dist_mat[:, 0] = range(len(label) + 1)
    for i in range(1, len(label) + 1):
        for j in range(1, len(rec) + 1):
            hit_score = dist_mat[i-1, j-1] + (label[i-1] != rec[j-1])
            ins_score = dist_mat[i, j-1] + 1
            del_score = dist_mat[i-1, j] + 1
            dist_mat[i, j] = min(hit_score, ins_score, del_score)

    dist = dist_mat[len(label), len(rec)]
    return dist, len(label)
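
cmp_result implements the standard Levenshtein distance, so its output can be checked by hand. A minimal sanity check (the token lists below are made up for illustration, not taken from the dataset):

    rec = ['a', '+', 'c', 'd']            # hypothetical prediction
    label = ['a', '+', 'b']               # hypothetical ground truth
    dist, length = cmp_result(rec, label)
    print(dist, length)                   # 2 3: one substitution plus one insertion
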
--------------------------------------------------------------------------------
/model/encoder.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# DenseNet-B block
# conv1(nChannels, 4*growthRate, 1x1) -> BN -> ReLU -> conv2(4*growthRate, growthRate, 3x3) -> BN -> ReLU
class Bottleneck(nn.Module):
    def __init__(self, nChannels, growthRate, use_dropout):
        super(Bottleneck, self).__init__()
        interChannels = 4 * growthRate
        self.bn1 = nn.BatchNorm2d(interChannels)
        self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(growthRate)
        self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3,
                               padding=1, bias=False)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        if self.use_dropout:
            out = self.dropout(out)
        out = F.relu(self.bn2(self.conv2(out)), inplace=True)
        if self.use_dropout:
            out = self.dropout(out)
        out = torch.cat((x, out), 1)
        return out


# plain DenseNet layer
# BN -> ReLU -> conv(nChannels, growthRate, 3x3)
class SingleLayer(nn.Module):
    def __init__(self, nChannels, growthRate, use_dropout):
        super(SingleLayer, self).__init__()
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3,
                               padding=1, bias=False)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        out = self.conv1(F.relu(self.bn1(x), inplace=True))
        if self.use_dropout:
            out = self.dropout(out)
        out = torch.cat((x, out), 1)
        return out


# DenseNet-C transition
# conv(nChannels, nOutChannels, 1x1) -> BN -> ReLU -> avg_pool(2, 2)
class Transition(nn.Module):
    def __init__(self, nChannels, nOutChannels, use_dropout):
        super(Transition, self).__init__()
        self.bn1 = nn.BatchNorm2d(nOutChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1,
                               bias=False)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        if self.use_dropout:
            out = self.dropout(out)
        out = F.avg_pool2d(out, 2, ceil_mode=True)
        return out


class DenseNet(nn.Module):
    def __init__(self, growthRate, reduction, bottleneck, use_dropout):
        super(DenseNet, self).__init__()
        nDenseBlocks = 22
        nChannels = 2 * growthRate  # 48
        self.conv1 = nn.Conv2d(1, nChannels, kernel_size=7, padding=3, stride=2, bias=False)

        self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans1 = Transition(nChannels, nOutChannels, use_dropout)

        nChannels = nOutChannels
        self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout)
        nChannels += nDenseBlocks * growthRate
        nOutChannels = int(math.floor(nChannels * reduction))
        self.trans2 = Transition(nChannels, nOutChannels, use_dropout)

        nChannels = nOutChannels
        self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout)

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout):
        layers = []
        for _ in range(int(nDenseBlocks)):
            if bottleneck:
                layers.append(Bottleneck(nChannels, growthRate, use_dropout))
            else:
                layers.append(SingleLayer(nChannels, growthRate, use_dropout))
            nChannels += growthRate
        return nn.Sequential(*layers)

    def forward(self, x, x_mask):
        out = self.conv1(x)
        out_mask = x_mask[:, 0::2, 0::2]
        out = F.relu(out, inplace=True)
        out = F.max_pool2d(out, 2, ceil_mode=True)
        out_mask = out_mask[:, 0::2, 0::2]
        out = self.dense1(out)
        out = self.trans1(out)
        out_mask = out_mask[:, 0::2, 0::2]
        out = self.dense2(out)
        out = self.trans2(out)
        out_mask = out_mask[:, 0::2, 0::2]
        out = self.dense3(out)
        return out, out_mask
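
A quick shape sketch for the encoder, using the hyper-parameters from train.py. The strided conv1, the max pool, and the two Transition average pools give a total downsampling of 16, and with growthRate=24, reduction=0.5 and 22 layers per block the channel count evolves 48 -> 576 -> 288 -> 816 -> 408 -> 936, matching params['D'] = 936:

    import torch
    from model.encoder import DenseNet

    encoder = DenseNet(growthRate=24, reduction=0.5, bottleneck=True, use_dropout=False)
    x = torch.zeros(1, 1, 128, 320)    # (B, C, H, W) dummy image batch
    x_mask = torch.ones(1, 128, 320)
    ctx, ctx_mask = encoder(x, x_mask)
    print(ctx.shape, ctx_mask.shape)   # torch.Size([1, 936, 8, 20]) torch.Size([1, 8, 20])
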
--------------------------------------------------------------------------------
/utils/data_iterator.py:
--------------------------------------------------------------------------------
import pickle as pkl
import random
import math

class BatchBucket():
    def __init__(self, max_h, max_w, max_l, max_img_size, max_batch_size,
                 feature_file, label_file, shuffle=True):
        self._max_img_size = max_img_size
        self._max_batch_size = max_batch_size
        self._shuffle = shuffle

        with open(feature_file, 'rb') as fp:
            self._features = pkl.load(fp)
        with open(label_file, 'rb') as fp:
            self._labels = pkl.load(fp)
        self._data_info_list = [(uid, fea.shape[1], fea.shape[2], len(self._labels[uid]))
                                for uid, fea in self._features.items()]

        self._keys = self._calc_keys(max_h, max_w, max_l)

    def _calc_keys(self, max_h, max_w, max_l):
        # compute the true maxima over the dataset
        _, h_info, w_info, l_info = zip(*self._data_info_list)
        mh, mw, ml = max(h_info), max(w_info), max(l_info)
        max_h, max_w, max_l = min(max_h, mh), min(max_w, mw), min(max_l, ml)

        # partition the (height, width, length) grid according to the true maxima
        keys = []
        init_h = 100 if 100 < max_h else max_h
        init_w = 100 if 100 < max_w else max_w
        # init_l = 100 if 100 < max_l else max_l
        init_l = max_l
        # step sizes of the grid partition
        h_step = 50
        w_step = 100
        l_step = 20
        h = init_h
        while h <= max_h:
            w = init_w
            while w <= max_w:
                l = init_l
                while l <= max_l:
                    keys.append([h, w, l, h * w * l, 0, []])
                    if l < max_l and l + l_step > max_l:
                        l = max_l
                    else:
                        l += l_step
                if w < max_w and w + max(int((w*0.3 // 10) * 10), w_step) > max_w:
                    w = max_w
                else:
                    w = w + max(int((w*0.3 // 10) * 10), w_step)
            if h < max_h and h + max(int((h*0.5 // 10) * 10), h_step) > max_h:
                h = max_h
            else:
                h = h + max(int((h*0.5 // 10) * 10), h_step)
        keys = sorted(keys, key=lambda area: area[3])

        # assign each sample to its corresponding bucket
        # and count the samples that fall into each bucket
        unused_num = 0
        for uid, h, w, l in self._data_info_list:
            flag = False
            for i, key in enumerate(keys):
                hh, ww, ll, _, _, subset = key
                if h <= hh and w <= ww and l <= ll:
                    keys[i][-2] += 1
                    subset.append(uid)
                    flag = True
                    break
            if not flag:
                print(uid, h, w, l)
                unused_num += 1
        print(f'The number of all samples: {len(self._data_info_list)}')
        print(f'The number of unused samples: {unused_num}')
        # filter out buckets that received no samples
        keys = list(filter(lambda temp_k: temp_k[-2] > 0, keys))
        # compute the batch size of each bucket
        # total_batch, total_sample = 0, 0
        for i, key in enumerate(keys):
            h, w, l, _, sample_num, _ = key
            batch_size = int(self._max_img_size / (h * w))
            batch_size = max(1, min(batch_size, self._max_batch_size))
            keys[i][3] = batch_size
            print(f'bucket [{h}, {w}, {l}], batch={batch_size}, sample={sample_num}')
        # return the bucket keys
        return keys

    def _reset_batches(self):
        self._batches = []
        # shuffle the sample order within each bucket
        for _, _, _, batch_size, sample_num, uids in self._keys:
            if self._shuffle:
                random.shuffle(uids)
            batch_num = math.ceil(sample_num / batch_size)
            for i in range(batch_num):
                start = i * batch_size
                end = start + batch_size if start + batch_size < sample_num else sample_num
                self._batches.append(uids[start:end])
        if self._shuffle:
            random.shuffle(self._batches)

    def get_batches(self):
        batches = []
        self._reset_batches()
        for uid_batch in self._batches:
            fea_batch = [self._features[uid] for uid in uid_batch]
            label_batch = [self._labels[uid] for uid in uid_batch]
            batches.append((fea_batch, label_batch, uid_batch))
        print(f'The number of Bucket(subset) {len(self._keys)}')
        print(f'The number of Batches {len(batches)}')
        print(f'The number of Samples {len(self._data_info_list)}')

        return batches

if __name__ == '__main__':
    data_path = 'data/'
    train_datasets = ['train_images.pkl', 'train_labels.pkl', 'train_relations.pkl']
    train_data_iterator = BatchBucket(600, 2100, 200, 800000, 16,
                                      data_path+train_datasets[0], data_path+train_datasets[1])
    train_batches = train_data_iterator.get_batches()
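
BatchBucket expects two pickles: one mapping uid -> grayscale image array of shape (C, H, W), and one mapping uid -> GTD rows of the form used by prepare_data. A sketch of that layout, inferred from how BatchBucket and prepare_data read the files (the uids and contents here are illustrative only):

    import pickle as pkl
    import numpy as np

    features = {  # uid -> (C, H, W) uint8 image; prepare_data normalizes via (255 - x) / 255
        'sample_0': np.zeros((1, 80, 240), dtype=np.uint8),
    }
    labels = {    # uid -> GTD rows [child, child_idx, parent, parent_idx, relation]
        'sample_0': [['a', 1, 'root', 0, 'start'],
                     ['=', 2, 'a', 1, 'right']],
    }
    with open('data/train_images.pkl', 'wb') as fp:
        pkl.dump(features, fp)
    with open('data/train_labels.pkl', 'wb') as fp:
        pkl.dump(labels, fp)
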
--------------------------------------------------------------------------------
/model/encoder_decoder.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from .encoder import DenseNet
from .decoder import Decoder

class Encoder_Decoder(nn.Module):
    def __init__(self, params):
        super(Encoder_Decoder, self).__init__()
        self.encoder = DenseNet(growthRate=params['growthRate'],
                                reduction=params['reduction'],
                                bottleneck=params['bottleneck'],
                                use_dropout=params['use_dropout'])
        self.init_context = nn.Linear(params['D'], params['n'])
        self.decoder = Decoder(params)
        self.object_criterion = nn.CrossEntropyLoss(reduction='none')
        self.relation_criterion = nn.BCEWithLogitsLoss(reduction='none')
        self.object_pix_criterion = nn.NLLLoss(reduction='none')
        self.param_n = params['n']

    def forward(self, params, x, x_mask, C_y,
                P_y, C_re, P_re, P_position, y_mask, re_mask, length):
        # encoder
        ctx, ctx_mask = self.encoder(x, x_mask)
        ctx_mean = (ctx * ctx_mask[:, None, :, :]).sum(3).sum(2) \
                   / ctx_mask.sum(2).sum(1)[:, None]
        init_state = torch.tanh(self.init_context(ctx_mean))
        # decoder
        predict_objects, predict_relations, predict_objects_pix = self.decoder(ctx, ctx_mask,
            C_y, P_y, y_mask, P_re, P_position, init_state, length)

        # loss
        predict_objects = predict_objects.view(-1, predict_objects.shape[2])
        object_loss = self.object_criterion(predict_objects, C_y.view(-1))
        object_loss = object_loss.view(C_y.shape[0], C_y.shape[1])
        object_loss = ((object_loss * y_mask).sum(0) / y_mask.sum(0)).mean()

        predict_objects_pix = predict_objects_pix.view(-1, predict_objects_pix.shape[2])
        object_pix_loss = self.object_pix_criterion(predict_objects_pix, C_y.view(-1))
        object_pix_loss = object_pix_loss.view(C_y.shape[0], C_y.shape[1])
        object_pix_loss = ((object_pix_loss * y_mask).sum(0) / y_mask.sum(0)).mean()

        relation_loss = predict_relations.view(-1, predict_relations.shape[2])
        relation_loss = self.relation_criterion(relation_loss, C_re.view(-1, C_re.shape[2]))
        relation_loss = relation_loss.view(C_re.shape[0], C_re.shape[1], C_re.shape[2])
        relation_loss = (relation_loss * re_mask[:, :, None]).sum(2).sum(0) / re_mask.sum(0)
        relation_loss = relation_loss.mean()

        loss = params['lc_lambda'] * object_loss + \
               params['lr_lambda'] * relation_loss + \
               params['lc_lambda_pix'] * object_pix_loss

        return loss, object_loss, relation_loss

    def greedy_inference(self, x, x_mask, max_length, p_y, p_re, p_mask):

        ctx, ctx_mask = self.encoder(x, x_mask)
        ctx_mean = (ctx * ctx_mask[:, None, :, :]).sum(3).sum(2) \
                   / ctx_mask.sum(2).sum(1)[:, None]
        init_state = torch.tanh(self.init_context(ctx_mean))

        B, H, W = ctx_mask.shape
        attention_past = torch.zeros(B, 1, H, W).cuda()

        ctx_key_object = self.decoder.conv_key_object(ctx).permute(0, 2, 3, 1)
        ctx_key_relation = self.decoder.conv_key_relation(ctx).permute(0, 2, 3, 1)

        relation_table = torch.zeros(B, max_length, 9).to(torch.long)
        relation_table_static = torch.zeros(B, max_length, 9).to(torch.long)
        predict_relation_static = torch.zeros(B, max_length, 9)
        P_masks = torch.zeros(B, max_length).cuda()
        predict_childs = torch.zeros(max_length, B).to(torch.long).cuda()

        ht = init_state
        ht_relation = init_state
        for i in range(max_length):
            predict_child, ht, attention, ct = self.decoder.get_child(ctx, ctx_key_object, ctx_mask,
                attention_past, p_y, p_mask, p_re, ht)
            predict_childs[i] = torch.argmax(predict_child, dim=1)
            predict_childs[i] *= p_mask.to(torch.long)
            attention_past = attention[:, None, :, :] + attention_past

            predict_relation, ht_relation = self.decoder.get_relation(ctx, ctx_key_relation, ctx_mask,
                predict_childs[i], ht_relation, ct)

            P_masks[:, i] = p_mask

            predict_relation_static[:, i, :] = predict_relation
            relation_table[:, i, :] = (predict_relation > 0)
            relation_table_static[:, i, :] = (predict_relation > 0)

            relation_table[:, :, 8] = relation_table[:, :, :8].sum(2)

            # backtrack to the most recently decoded node that still has an
            # unexpanded relation; it becomes the next parent
            find_number = 0
            for ii in range(B):
                if p_mask[ii] < 0.5:
                    continue
                ji = i
                find_flag = 0
                while ji >= 0:
                    if relation_table[ii, ji, 8] > 0:
                        for iii in range(9):
                            if relation_table[ii, ji, iii] != 0:
                                p_re[ii] = iii
                                p_y[ii] = predict_childs[ji, ii]
                                relation_table[ii, ji, iii] = 0
                                relation_table[ii, ji, 8] -= 1
                                find_flag = 1
                                break
                    if find_flag:
                        break
                    ji -= 1
                find_number += find_flag
                if not find_flag:
                    p_mask[ii] = 0.

            if find_number == 0:
                break
        return predict_childs, P_masks, relation_table_static, predict_relation_static
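
greedy_inference is the decoding entry point used at validation time. A hedged usage sketch: the params dict is the one built in train.py, and treating the last object/relation ids as the root object and start relation is an assumption inferred from relation2gtd (train.py instead takes the start tokens from prepare_data's first row):

    import torch
    from model.encoder_decoder import Encoder_Decoder

    model = Encoder_Decoder(params).cuda()     # params as defined in train.py
    model.eval()
    x = torch.zeros(1, 1, 128, 320).cuda()     # dummy normalized image batch
    x_mask = torch.ones(1, 128, 320).cuda()
    p_y = torch.full((1,), params['K'] - 1, dtype=torch.long).cuda()    # assumed root-object id
    p_re = torch.full((1,), params['Kre'] - 1, dtype=torch.long).cuda() # assumed start-relation id
    p_mask = torch.ones(1).cuda()
    with torch.no_grad():
        childs, masks, relations, _ = model.greedy_inference(x, x_mask, 200, p_y, p_re, p_mask)
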
--------------------------------------------------------------------------------
/model/decoder.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

class Decoder(nn.Module):
    def __init__(self, params):
        super(Decoder, self).__init__()
        self.object_embedding = nn.Embedding(params['K'], params['m'])
        self.relation_embedding = nn.Embedding(params['Kre'], params['re_m'])

        self.gru00 = nn.GRUCell(params['m']+params['re_m'], params['n'])
        self.gru01 = nn.GRUCell(params['D'], params['n'])
        self.gru10 = nn.GRUCell(params['D'], params['n'])
        self.gru11 = nn.GRUCell(params['D'], params['n'])

        self.conv_key_object = nn.Conv2d(params['D'], params['dim_attention'], kernel_size=1)
        self.object_attention = CoverageAttention(params)
        self.object_probility = ObjectProbility(params)

        self.conv_key_relation = nn.Conv2d(params['D'], params['dim_attention'], kernel_size=1)
        self.relation_attention = Attention(params)
        self.relation_probility = RelationProbility(params)

        self.pix_object_probility = PixObjectProbility(params)

        self.param_K = params['K']
        self.param_Kre = params['Kre']
        self.param_n = params['n']

    def forward(self, ctx_val, ctx_mask, C_ys, P_ys, P_y_masks,
                P_res, P_positions, init_state, length):

        B, H, W = ctx_mask.shape
        attention_past = torch.zeros(B, 1, H, W).cuda()
        predict_childs = torch.zeros(length, B, self.param_K).cuda()
        predict_childs_pix = torch.zeros(length, B, self.param_K).cuda()
        predict_relations = torch.zeros(length, B, self.param_Kre).cuda()

        ht = init_state
        ht_relation = init_state

        ctx_key_object = self.conv_key_object(ctx_val).permute(0, 2, 3, 1)
        ctx_key_relation = self.conv_key_relation(ctx_val).permute(0, 2, 3, 1)

        predict_features = self.pix_object_probility(ctx_val)
        for i in range(length):
            predict_childs[i], ht, attention, ct = self.get_child(ctx_val, ctx_key_object, ctx_mask,
                attention_past, P_ys[i], P_y_masks[i], P_res[i], ht)
            attention_past = attention[:, None, :, :] + attention_past

            predict_childs_pix[i] = torch.log((attention.detach()[:, :, :, None] * predict_features).sum(2).sum(1))
            predict_relations[i], ht_relation = self.get_relation(ctx_val, ctx_key_relation,
                ctx_mask, C_ys[i], ht_relation, ct)

        return predict_childs, predict_relations, predict_childs_pix

    def get_child(self, ctx_val, ctx_key, ctx_mask, attention_past, p_y, p_y_mask, p_re, ht):
        p_y = self.object_embedding(p_y)
        p_re = self.relation_embedding(p_re)
        p = torch.cat([p_y, p_re], dim=1)
        ht_hat = self.gru00(p, ht)
        ht_hat = p_y_mask[:, None] * ht_hat + (1 - p_y_mask)[:, None] * ht

        ct, attention = self.object_attention(ctx_val, ctx_key, ctx_mask, attention_past, ht_hat)

        ht = self.gru01(ct, ht_hat)
        ht = p_y_mask[:, None] * ht + (1 - p_y_mask)[:, None] * ht_hat

        predict_child = self.object_probility(ct, ht, p_y, p_re)

        return predict_child, ht, attention, ct

    def get_relation(self, ctx_val, ctx_key, ctx_mask, c_y, ht, ct):
        c_y = self.object_embedding(c_y)
        ht_query = self.gru10(ct, ht)
        ct, _ = self.relation_attention(ctx_val, ctx_key, ctx_mask, ht_query)
        ht = self.gru11(ct, ht_query)

        predict_relation = self.relation_probility(ct, c_y, ht)
        return predict_relation, ht

class CoverageAttention(nn.Module):
    def __init__(self, params):
        super(CoverageAttention, self).__init__()
        self.fc_query = nn.Linear(params['n'], params['dim_attention'], bias=False)
        self.conv_att_past = nn.Conv2d(1, 512, kernel_size=11, bias=False, padding=5)
        self.fc_att_past = nn.Linear(512, params['dim_attention'])
        self.fc_attention = nn.Linear(params['dim_attention'], 1)

    def forward(self, ctx_val, ctx_key, ctx_mask, attention_past, ht_query):

        ht_query = self.fc_query(ht_query)

        attention_past = self.conv_att_past(attention_past).permute(0, 2, 3, 1)
        attention_past = self.fc_att_past(attention_past)  # (batch, H, W, dim_att)

        attention_score = torch.tanh(ctx_key + ht_query[:, None, None, :] + attention_past)
        attention_score = self.fc_attention(attention_score).squeeze(3)

        # masked softmax over the spatial grid
        attention_score = attention_score - attention_score.max()
        attention_score = torch.exp(attention_score) * ctx_mask
        attention_score = attention_score / (attention_score.sum(2).sum(1)[:, None, None] + 1e-10)

        ct = (ctx_val * attention_score[:, None, :, :]).sum(3).sum(2)

        return ct, attention_score

class Attention(nn.Module):
    def __init__(self, params):
        super(Attention, self).__init__()
        self.fc_query = nn.Linear(params['n'], params['dim_attention'], bias=False)
        self.fc_attention = nn.Linear(params['dim_attention'], 1)

    def forward(self, ctx_val, ctx_key, ctx_mask, ht_query):

        ht_query = self.fc_query(ht_query)

        attention_score = torch.tanh(ctx_key + ht_query[:, None, None, :])
        attention_score = self.fc_attention(attention_score).squeeze(3)

        # masked softmax over the spatial grid
        attention_score = attention_score - attention_score.max()
        attention_score = torch.exp(attention_score) * ctx_mask
        attention_score = attention_score / (attention_score.sum(2).sum(1)[:, None, None] + 1e-10)

        ct = (ctx_val * attention_score[:, None, :, :]).sum(3).sum(2)

        return ct, attention_score

class ObjectProbility(nn.Module):
    def __init__(self, params):
        super(ObjectProbility, self).__init__()
        self.fc_ct = nn.Linear(params['D'], params['m'])
        self.fc_ht = nn.Linear(params['n'], params['m'])
        self.fc_p_y = nn.Linear(params['m'], params['m'])
        self.fc_p_re = nn.Linear(params['re_m'], params['m'])
        self.dropout = nn.Dropout(p=0.2)
        self.fc_probility = nn.Linear(int(params['m']/2), params['K'])

    def forward(self, ct, ht, p_y, p_re):
        out = self.fc_ct(ct) + self.fc_ht(ht) + self.fc_p_y(p_y) + self.fc_p_re(p_re)
        # maxout: pairs of units compete, halving the width from m to m/2
        out = out.view(out.shape[0], -1, 2)
        out = out.max(2)[0]
        out = self.dropout(out)
        out = self.fc_probility(out)
        return out

class PixObjectProbility(nn.Module):
    def __init__(self, params):
        super(PixObjectProbility, self).__init__()
        self.fc_ct = nn.Linear(params['D'], params['m'])
        self.dropout = nn.Dropout(p=0.2)
        self.fc_probility = nn.Linear(int(params['m']/2), params['K'])

    def forward(self, ctx):
        out = self.fc_ct(ctx.permute(0, 2, 3, 1))  # (B, H, W, D) --> (B, H, W, m)
        # maxout
        out = out.view(out.shape[0], out.shape[1], out.shape[2], -1, 2)  # (B, H, W, m/2, 2)
        out = out.max(4)[0]  # (B, H, W, m/2)
        out = self.dropout(out)
        out = self.fc_probility(out)  # (B, H, W, K)
        out = F.softmax(out, dim=3)
        return out

class RelationProbility(nn.Module):
    def __init__(self, params):
        super(RelationProbility, self).__init__()
        self.fc_ct = nn.Linear(params['D'], params['mre'])
        self.fc_c_y = nn.Linear(params['m'], params['mre'])
        self.fc_ht = nn.Linear(params['n'], params['mre'])
        self.dropout = nn.Dropout(p=0.2)
        self.fc_probility = nn.Linear(int(params['mre']/2), params['Kre'])

    def forward(self, ct, c_y, ht):
        out = self.fc_ct(ct) + self.fc_c_y(c_y) + self.fc_ht(ht)
        # maxout
        out = out.view(out.shape[0], -1, 2)
        out = out.max(2)[0]
        out = self.dropout(out)
        out = self.fc_probility(out)

        return out
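
All three probability heads use the same maxout trick before their final projection: the feature vector is reshaped so adjacent pairs of units compete, halving the width. A minimal illustration:

    import torch
    out = torch.tensor([[1., 5., 2., 0.]])   # (B, m) with m = 4
    out = out.view(out.shape[0], -1, 2)      # (B, m/2, 2) -> pairs (1, 5) and (2, 0)
    out = out.max(2)[0]                      # (B, m/2)
    print(out)                               # tensor([[5., 2.]])
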
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import time
import numpy as np
import pickle as pkl
import torch
from torch import optim, nn
from utils.utils import load_dict, prepare_data, weight_init, cmp_result
from model.encoder_decoder import Encoder_Decoder
from utils.data_iterator import BatchBucket
from utils.latex2gtd_v2_2 import list2node, tree2latex, relation2gtd

# configurations
init_param_flag = True  # whether to initialize params
reload_flag = False  # whether to reload pretrained params
data_path = 'data/'
work_path = 'results/'
dictionaries = ['dictionary_object.txt', 'dictionary_relation_noend.txt']
train_datasets = ['train_images.pkl', 'train_labels.pkl', 'train_relations.pkl']
valid_datasets = ['valid_images.pkl', 'valid_labels.pkl', 'valid_relations.pkl']
valid_outputs = ['valid_results.txt']
model_params = ['SimTree_best_params.pkl', 'SimTree_last_params.pkl']

# train settings
maxlen = 200
max_epochs = 5000
lrate = 1
my_eps = 1e-6
decay_c = 1e-4
clip_c = 100.

# early stop
estop = False
halfLrFlag = 0
bad_counter = 0
patience = 15
validStart = 0  # decoding tends to overflow before the model has learned, which wastes time
finish_after = 10000000

# model architecture
# this block should eventually move into the model code
params = {}
params['n'] = 256
params['m'] = 256
params['re_m'] = 64
params['dim_attention'] = 512
params['D'] = 936
params['K'] = 107
params['Kre'] = 9
params['mre'] = 256
params['maxlen'] = maxlen

params['growthRate'] = 24
params['reduction'] = 0.5
params['bottleneck'] = True
params['use_dropout'] = True
params['input_channels'] = 1

params['lc_lambda'] = 1.
params['lr_lambda'] = 1.
params['lc_lambda_pix'] = 1.

symbol2id = load_dict(data_path+dictionaries[0])
print('total chars', len(symbol2id))
id2symbol = {}
for symbol, symbol_id in symbol2id.items():
    id2symbol[symbol_id] = symbol

relation2id = load_dict(data_path+dictionaries[1])
print('total relations', len(relation2id))
id2relation = {}
for relation, relation_id in relation2id.items():
    id2relation[relation_id] = relation

train_data_iterator = BatchBucket(600, 2100, 200, 800000, 16,
                                  data_path+train_datasets[0], data_path+train_datasets[1])
valid_data_iterator = BatchBucket(9999, 9999, 9999, 999999, 1,
                                  data_path+valid_datasets[0], data_path+valid_datasets[1])
valid = valid_data_iterator.get_batches()
with open(data_path + valid_datasets[1], 'rb') as fp:
    valid_gtds = pkl.load(fp)

# display
uidx = 0
object_loss_s = 0.
relation_loss_s = 0.
loss_s = 0.

ud_s = 0
validFreq = -1
saveFreq = -1
sampleFreq = -1
dispFreq = 100
WER = 100

# initialize model
SimTree_model = Encoder_Decoder(params)
if init_param_flag:
    SimTree_model.apply(weight_init)
if reload_flag:
    print('Loading pretrained model ...')
    SimTree_model.load_state_dict(torch.load(work_path+model_params[1], map_location=lambda storage, loc: storage))
SimTree_model.cuda()

optimizer = optim.Adadelta(SimTree_model.parameters(), lr=lrate, eps=my_eps, weight_decay=decay_c)
print('Optimization')

# statistics
history_errs = []
for eidx in range(max_epochs):
    n_samples = 0
    ud_epoch = time.time()
    train = train_data_iterator.get_batches()
    if validFreq == -1:
        validFreq = len(train)
    if saveFreq == -1:
        saveFreq = len(train)
    if sampleFreq == -1:
        sampleFreq = len(train)

    for x, y, key in train:
        SimTree_model.train()
        ud_start = time.time()
        n_samples += len(x)
        uidx += 1
        x, x_mask, C_y, y_mask, P_y, P_re, C_re, lp, rp = \
            prepare_data(params, x, y, key, symbol2id, relation2id, shuffle=False)

        length = C_y.shape[0]
        x = torch.from_numpy(x).cuda()
        x_mask = torch.from_numpy(x_mask).cuda()
        C_y = torch.from_numpy(C_y).to(torch.long).cuda()
        y_mask = torch.from_numpy(y_mask).cuda()
        P_y = torch.from_numpy(P_y).to(torch.long).cuda()
        P_re = torch.from_numpy(P_re).to(torch.long).cuda()
        P_position = torch.from_numpy(rp).to(torch.long).cuda()
        C_re = torch.from_numpy(C_re).cuda()

        loss, object_loss, relation_loss = SimTree_model(params, x, x_mask,
            C_y, P_y, C_re, P_re, P_position, y_mask, y_mask, length)

        object_loss_s += object_loss.item()
        relation_loss_s += relation_loss.item()
        loss_s += loss.item()

        # backward
        optimizer.zero_grad()
        loss.backward()
        if clip_c > 0.:
            torch.nn.utils.clip_grad_norm_(SimTree_model.parameters(), clip_c)

        # update
        optimizer.step()

        # display
        ud = time.time() - ud_start
        ud_s += ud

        if np.mod(uidx, dispFreq) == 0:
            ud_s /= 60.
            loss_s /= dispFreq
            object_loss_s /= dispFreq
            relation_loss_s /= dispFreq
            print(f'Epoch {eidx}, Update {uidx}, Cost_object {object_loss_s:.7}, ', end='')
            print(f'Cost_relation {relation_loss_s}, UD {ud_s:.3} lrate {lrate} eps {my_eps} bad_counter {bad_counter}')
            ud_s = 0
            loss_s = 0.
            object_loss_s = 0.
            relation_loss_s = 0.

        if np.mod(uidx, saveFreq) == 0:
            print('Saving latest model params ... ')
            torch.save(SimTree_model.state_dict(), work_path+model_params[1])

        # validation
        if np.mod(uidx, sampleFreq) == 0 and (eidx % 2) == 0:
            number_right = 0
            total_distance = 0
            total_length = 0
            latex_right = 0
            total_latex_distance = 0
            total_latex_length = 0
            total_number = 0

            print('begin sampling')
            ud_epoch_train = (time.time() - ud_epoch) / 60.
            print('epoch training cost time ...', ud_epoch_train)
            SimTree_model.eval()

            fp_results = open(work_path+valid_outputs[0], 'w')
            with torch.no_grad():
                valid_count_idx = 0
                for x, y, valid_key in valid:
                    x, x_mask, C_y, y_mask, P_y, P_re, C_re, lp, rp = \
                        prepare_data(params, x, y, valid_key, symbol2id, relation2id)

                    L, B = C_y.shape[:2]
                    x = torch.from_numpy(x).cuda()
                    x_mask = torch.from_numpy(x_mask).cuda()
                    lengths_gt = (y_mask > 0.5).sum(0)
                    y_mask = torch.from_numpy(y_mask).cuda()
                    P_y = torch.from_numpy(P_y).to(torch.long).cuda()
                    P_re = torch.from_numpy(P_re).to(torch.long).cuda()

                    object_predicts, P_masks, relation_table_static, _ \
                        = SimTree_model.greedy_inference(x, x_mask, L+1, P_y[0], P_re[0], y_mask[0])
                    object_predicts, P_masks = object_predicts.cpu().numpy(), P_masks.cpu().numpy()
                    relation_table_static = relation_table_static.numpy()
                    for bi in range(B):
                        length_predict = min((P_masks[bi, :] > 0.5).sum(), P_masks.shape[1])
                        object_predict = object_predicts[:int(length_predict), bi]
                        relation_predict = relation_table_static[bi, :int(length_predict), :]
                        gtd = relation2gtd(object_predict, relation_predict, id2symbol, id2relation)
                        latex = tree2latex(list2node(gtd))

                        uid = valid_key[bi]
                        ground_truth_gtd = valid_gtds[uid]
                        ground_truth_latex = tree2latex(list2node(ground_truth_gtd))

                        child = [symbol2id[g[0]] for g in ground_truth_gtd if g[0] != '']
                        distance, length = cmp_result(object_predict, child)
                        total_number += 1

                        if distance == 0:
                            number_right += 1
                            fp_results.write(uid + '\tObject True\t')
                        else:
                            fp_results.write(uid + '\tObject False\t')

                        latex_distance, latex_length = cmp_result(ground_truth_latex, latex)
                        if latex_distance == 0:
                            latex_right += 1
                            fp_results.write('Latex True\n')
                        else:
                            fp_results.write('Latex False\n')

                        total_distance += distance
                        total_length += length
                        total_latex_distance += latex_distance
                        total_latex_length += latex_length

                        fp_results.write(' '.join(ground_truth_latex) + '\n')
                        fp_results.write(' '.join(latex) + '\n')

                        for c in child:
                            fp_results.write(id2symbol[c] + ' ')
                        fp_results.write('\n')

                        for ob_p in object_predict:
                            fp_results.write(id2symbol[ob_p] + ' ')
                        fp_results.write('\n')

            wer = total_distance / total_length * 100
            sacc = number_right / total_number * 100
            latex_wer = total_latex_distance / total_latex_length * 100
            latex_acc = latex_right / total_number * 100
            fp_results.close()

            ud_epoch = (time.time() - ud_epoch) / 60.
            print(f'valid set decode done, epoch cost time: {ud_epoch} min')
            print(f'WER {wer} SACC {sacc} Latex WER {latex_wer} Latex SACC {latex_acc}')

            if latex_wer <= WER:
                WER = latex_wer
                bad_counter = 0
                print('Saving best model params ... ')
                torch.save(SimTree_model.state_dict(), work_path+model_params[0])
            else:
                bad_counter += 1
                if bad_counter > patience:
                    if halfLrFlag == 2:
                        print('Early Stop!')
                        estop = True
                        break
                    else:
                        print('Lr decay and retrain!')
                        bad_counter = 0
                        lrate = lrate / 10.
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lrate
                        halfLrFlag += 1
        if uidx >= finish_after:
            print(f'Finishing after {uidx} iterations!')
            estop = True
            break

    print(f'Seen {n_samples} samples')

    if estop:
        break
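
load_dict parses each dictionary line as a whitespace-separated `symbol id` pair. An illustrative excerpt of what dictionary_object.txt could look like (these entries are made up, not the released dictionary; relation2gtd only suggests that the root token occupies the last id, K-1 = 106):

    a 0
    + 1
    \frac 2
    \sqrt 3
    ...
    root 106
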
--------------------------------------------------------------------------------
/utils/latex2gtd_v2_2.py:
--------------------------------------------------------------------------------
import numpy as np
import random

class Node():
    def __init__(self, x=0):
        self.x = x
        self.childs = []
        self.relations = []

def findnextbracket(latex: list, leftbracket='{'):
    if leftbracket == '{':
        rightbracket = '}'
    elif leftbracket == '[':
        rightbracket = ']'
    else:
        raise AssertionError('Unknown Bracket!')

    num = 0
    for li, l in enumerate(latex):
        if l == leftbracket:
            num += 1
        if l == rightbracket:
            num -= 1
        if num == 0:
            return li
    return -1

def findendmatrix(latex):
    num = 1
    for li, l in enumerate(latex):
        if l == '\\begin{matrix}':
            num += 1
        if l == '\\end{matrix}':
            num -= 1
        if num == 0:
            return li
    return -1

def latex2Tree(latex: list):
    '''
    input: latex --> list
    output: Node
    '''
    if len(latex) == 0:
        return Node('')

    cur_node = Node(latex[0])
    symbol = latex.pop(0)

    if symbol == '':
        if len(latex) > 0 and latex[0] == '_':
            latex.pop(0)
            assert latex[0] == '{', '_ not with {'
            li = findnextbracket(latex, leftbracket='{')
            sub_latex = latex[1:li]
            node = latex2Tree(sub_latex)
            cur_node.childs.append(node)
            cur_node.relations.append('sub')
            for _ in range(li+1):
                latex.pop(0)
        if len(latex) > 0 and latex[0] == '^':
            latex.pop(0)
            assert latex[0] == '{', '^ not with {'
            li = findnextbracket(latex, leftbracket='{')
            sub_latex = latex[1:li]
            node = latex2Tree(sub_latex)
            cur_node.childs.append(node)
            cur_node.relations.append('sup')
            for _ in range(li+1):
                latex.pop(0)

    elif symbol == '\\begin{matrix}':
        li = findendmatrix(latex)
        sub_latex = latex[:li]
        node = latex2Tree(sub_latex)
        cur_node.childs.append(node)
        cur_node.relations.append('Mstart')
        for _ in range(li+1):
            latex.pop(0)

    elif symbol in ['\\iint', '\\bigcup', '\\sum', '\\lim', '\\coprod']:
        if len(latex) > 0 and latex[0] == '_':
            latex.pop(0)
            assert latex[0] == '{', '_ not with {'
            li = findnextbracket(latex, leftbracket='{')
            sub_latex = latex[1:li]
            node = latex2Tree(sub_latex)
            cur_node.childs.append(node)
            cur_node.relations.append('below')
            for _ in range(li+1):
                latex.pop(0)
        if len(latex) > 0 and latex[0] == '^':
            latex.pop(0)
            assert latex[0] == '{', '^ not with {'
            li = findnextbracket(latex, leftbracket='{')
            sub_latex = latex[1:li]
            node = latex2Tree(sub_latex)
            cur_node.childs.append(node)
            cur_node.relations.append('above')
            for _ in range(li+1):
                latex.pop(0)

    elif symbol in ['\\dot', '\\ddot', '\\hat', '\\check', '\\grave', '\\acute',
                    '\\tilde', '\\breve', '\\bar', '\\vec', '\\widehat',
                    '\\overbrace', '\\widetilde', '\\overleftarrow',
                    '\\overrightarrow', '\\overline']:
        assert latex[0] == '{', 'CASE 3 above not with {'
        li = findnextbracket(latex, leftbracket='{')
        sub_latex = latex[1:li]
        node = latex2Tree(sub_latex)
        cur_node.childs.append(node)
        cur_node.relations.append('below')
        for _ in range(li+1):
            latex.pop(0)

    elif symbol in ['\\underline', '\\underbrace']:
        assert latex[0] == '{', 'CASE 3 above not with {'
        li = findnextbracket(latex, leftbracket='{')
        sub_latex = latex[1:li]
        node = latex2Tree(sub_latex)
        cur_node.childs.append(node)
        cur_node.relations.append('above')
        for _ in range(li+1):
            latex.pop(0)

    elif symbol in ['\\xrightarrow', '\\xleftarrow']:
        if latex[0] == '[':
            li = findnextbracket(latex, leftbracket='[')
            sub_latex = latex[1:li]
            node = latex2Tree(sub_latex)
            cur_node.childs.append(node)
            cur_node.relations.append('below')
            for _ in range(li+1):
                latex.pop(0)
        if latex[0] == '{':
            li = findnextbracket(latex, leftbracket='{')
            sub_latex = latex[1:li]
            node = latex2Tree(sub_latex)
            cur_node.childs.append(node)
            cur_node.relations.append('above')
            for _ in range(li+1):
                latex.pop(0)

    elif symbol == '\\frac':
        assert latex[0] == '{', '\\frac above not with {'
        li = findnextbracket(latex, leftbracket='{')
        sub_latex = latex[1:li]
        node = latex2Tree(sub_latex)
        cur_node.childs.append(node)
        cur_node.relations.append('above')
        for _ in range(li+1):
            latex.pop(0)
        assert latex[0] == '{', '\\frac below not with {'
        li = findnextbracket(latex, leftbracket='{')
        sub_latex = latex[1:li]
        node = latex2Tree(sub_latex)
        cur_node.childs.insert(-1, node)
        cur_node.relations.insert(-1, 'below')
        for _ in range(li+1):
            latex.pop(0)

    elif symbol == '\\sqrt':
        if latex[0] == '[':
            li = findnextbracket(latex, leftbracket='[')
            sub_latex = latex[1:li]
            node = latex2Tree(sub_latex)
            cur_node.childs.append(node)
            cur_node.relations.append('leftup')
            for _ in range(li+1):
                latex.pop(0)
        assert latex[0] == '{', '\\sqrt inside not with {'
        li = findnextbracket(latex, leftbracket='{')
        sub_latex = latex[1:li]
        node = latex2Tree(sub_latex)
        cur_node.childs.append(node)
        cur_node.relations.append('inside')
        for _ in range(li+1):
            latex.pop(0)

    else:
        if len(latex) > 0 and latex[0] == '_':
            latex.pop(0)
            assert latex[0] == '{', '_ not with {'
            li = findnextbracket(latex, leftbracket='{')
            sub_latex = latex[1:li]
            node = latex2Tree(sub_latex)
            cur_node.childs.append(node)
            cur_node.relations.append('sub')
            for _ in range(li+1):
                latex.pop(0)
        if len(latex) > 0 and latex[0] == '^':
            latex.pop(0)
            assert latex[0] == '{', '^ not with {'
            li = findnextbracket(latex, leftbracket='{')
            sub_latex = latex[1:li]
            node = latex2Tree(sub_latex)
            cur_node.childs.append(node)
            cur_node.relations.append('sup')
            for _ in range(li+1):
                latex.pop(0)

    if len(latex) > 0 and latex[0] == '\\\\':
        latex.pop(0)
        relation = 'nextline'
    elif len(latex) > 0:
        relation = 'right'
    else:
        relation = 'end'
    node = latex2Tree(latex)
    cur_node.childs.append(node)
    cur_node.relations.append(relation)

    return cur_node

index = 0
def node2list(parent, parent_index, relation, current, gtd, initial=None):
    if current is None or current.x == '':
        return
    global index
    if initial is not None:
        index = 1
    else:
        index = index + 1
    gtd.append([current.x, index, parent, parent_index, relation])
    parent_index = index
    for child, relation in zip(current.childs, current.relations):
        node2list(current.x, parent_index, relation, child, gtd)

def node2list_shuffle(parent, parent_index, relation, current, gtd, initial=None):
    if current is None:
        return
    global index
    if initial is not None:
        index = 1
    else:
        index = index + 1
    gtd.append([current.x, index, parent, parent_index, relation])
    parent_index = index
    zip_childs = list(zip(current.childs, current.relations))
    random.shuffle(zip_childs)
    for child, relation in zip_childs:
        node2list_shuffle(current.x, parent_index, relation, child, gtd)

def list2node(gtd):
    node_list = []
    root = Node('root')
    node_list.append(root)
    for g in gtd:
        child_node = Node(g[0])
        node_list.append(child_node)
        parent_node = node_list[g[3]]
        parent_node.childs.append(child_node)
        parent_node.relations.append(g[4])
    return node_list[1]


def tree2latex(root):
    symbol = root.x
    latex = [symbol]
    if symbol == '':
        return []
    elif symbol == '\\begin{matrix}':
        for child, relation in zip(root.childs, root.relations):
            if relation == 'Mstart':
                latex += tree2latex(child)
                latex.append('\\end{matrix}')
            elif relation == 'nextline':
                latex.append('\\\\')
                latex += tree2latex(child)
            else:
                latex += tree2latex(child)

    elif symbol == '\\frac':
        below_latex = []
        for child, relation in zip(root.childs, root.relations):
            if relation == 'below':
                below_latex.append('{')
                below_latex += tree2latex(child)
                below_latex.append('}')
            elif relation == 'above':
                latex.append('{')
                latex += tree2latex(child)
                latex.append('}')
                latex += below_latex
            elif relation == 'nextline':
                latex.append('\\\\')
                latex += tree2latex(child)
            else:
                latex += tree2latex(child)
    elif symbol in ['\\underline', '\\underbrace', '\\dot',
                    '\\ddot', '\\hat', '\\check', '\\grave', '\\acute',
                    '\\tilde', '\\breve', '\\bar', '\\vec', '\\widehat',
                    '\\overbrace', '\\widetilde', '\\overleftarrow',
                    '\\overrightarrow', '\\overline']:
        for child, relation in zip(root.childs, root.relations):
            if relation in ['above', 'below']:
                latex.append('{')
                latex += tree2latex(child)
                latex.append('}')
            elif relation == 'nextline':
                latex.append('\\\\')
                latex += tree2latex(child)
            else:
                latex += tree2latex(child)
    elif symbol == '\\sqrt':
        for child, relation in zip(root.childs, root.relations):
            if relation == 'leftup':
                latex.append('[')
                latex += tree2latex(child)
                latex.append(']')
            elif relation == 'inside':
                latex.append('{')
                latex += tree2latex(child)
                latex.append('}')
            elif relation == 'nextline':
                latex.append('\\\\')
                latex += tree2latex(child)
            else:
                latex += tree2latex(child)
    elif symbol in ['\\xrightarrow', '\\xleftarrow']:
        for child, relation in zip(root.childs, root.relations):
            if relation == 'below':
                latex.append('[')
                latex += tree2latex(child)
                latex.append(']')
            elif relation == 'above':
                latex.append('{')
                latex += tree2latex(child)
                latex.append('}')
            elif relation == 'nextline':
                latex.append('\\\\')
                latex += tree2latex(child)
            else:
                latex += tree2latex(child)
    elif symbol in ['\\iint', '\\bigcup', '\\sum', '\\lim', '\\coprod']:
        for child, relation in zip(root.childs, root.relations):
            if relation == 'below':
                latex.append('_')
                latex.append('{')
                latex += tree2latex(child)
                latex.append('}')
            elif relation == 'above':
                latex.append('^')
                latex.append('{')
                latex += tree2latex(child)
                latex.append('}')
            elif relation == 'nextline':
                latex.append('\\\\')
                latex += tree2latex(child)
            else:
                latex += tree2latex(child)
    else:
        for child, relation in zip(root.childs, root.relations):
            if relation == 'sub':
                latex.append('_')
                latex.append('{')
                latex += tree2latex(child)
                latex.append('}')
            elif relation == 'sup':
                latex.append('^')
                latex.append('{')
                latex += tree2latex(child)
                latex.append('}')
            elif relation == 'nextline':
                latex.append('\\\\')
                latex += tree2latex(child)
            else:
                latex += tree2latex(child)
    return latex

def relation2gtd(objects, relations, id2object, id2relation):
    gtd = [[] for o in objects]
    num_relation = len(relations[0])

    start_relation = np.array([0 for _ in range(num_relation)])
    start_relation[0] = 1
    relation_stack = [start_relation]
    parent_stack = [(len(id2object)-1, 0)]
    p_re = len(id2relation) - 1
    p_y = len(id2object) - 1
    p_id = 0

    for ci, c in enumerate(objects):
        gtd[ci].append(id2object[c])
        gtd[ci].append(ci+1)

        # pop back to the nearest node that still has an unexpanded relation
        find_flag = False
        while relation_stack != []:
            if relation_stack[-1][:num_relation-1].sum() > 0:
                for index_relation in range(num_relation):
                    if relation_stack[-1][index_relation] != 0:
                        p_re = index_relation
                        p_y, p_id = parent_stack[-1]
                        relation_stack[-1][index_relation] = 0
                        if relation_stack[-1][:num_relation-1].sum() == 0:
                            relation_stack.pop()
                            parent_stack.pop()
                        find_flag = True
                        break
            else:
                relation_stack.pop()
                parent_stack.pop()

            if find_flag:
                break

        if not find_flag:
            p_y = objects[ci-1]
            p_id = ci
            p_re = num_relation - 1
        gtd[ci].append(id2object[p_y])
        gtd[ci].append(p_id)
        gtd[ci].append(id2relation[p_re])

        relation_stack.append(relations[ci])
        parent_stack.append((c, ci+1))

    return gtd


if __name__ == '__main__':
    latex = 'a = \\frac { x } { y } + \\sqrt [ c ] { b }'
    tree = latex2Tree(latex.split())
    gtd = []
    # global index
    index = 0
    node2list('root', 0, 'start', tree, gtd, initial=True)
    print('original gtd:')
    for g in gtd:
        g = [str(item) for item in g]
        print('\t\t'.join(g))
    predict_tree = list2node(gtd)
    predict_latex = tree2latex(predict_tree)
    predict_latex = ' '.join(predict_latex)

    print('shuffled gtd:')
    gtd = []
    node2list_shuffle('root', 0, 'start', predict_tree, gtd, initial=True)
    for g in gtd:
        g = [str(item) for item in g]
        print('\t\t'.join(g))
    shuffle_tree = list2node(gtd)
    shuffle_latex = tree2latex(shuffle_tree)
    shuffle_latex = ' '.join(shuffle_latex)
    print(latex == predict_latex)
    print(latex)
    print(predict_latex)
    print(shuffle_latex)

--------------------------------------------------------------------------------
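
For reference, hand-tracing latex2Tree and node2list on the demo formula above should yield a GTD roughly like the following (columns: child, child_idx, parent, parent_idx, relation; the below-child of \frac is deliberately ordered before the above-child by the insert(-1, ...) calls):

    a       1   root    0   start
    =       2   a       1   right
    \frac   3   =       2   right
    y       4   \frac   3   below
    x       5   \frac   3   above
    +       6   \frac   3   right
    \sqrt   7   +       6   right
    c       8   \sqrt   7   leftup
    b       9   \sqrt   7   inside
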