├── utils
├── __init__.py
├── utils.py
├── data_iterator.py
└── latex2gtd_v2_2.py
├── README.md
├── model
├── encoder.py
├── encoder_decoder.py
└── decoder.py
└── train.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TDv2
2 | The source code for TDv2, as described in the paper:
3 |
4 | TDv2: A Novel Tree-Structured Decoder for Offline Mathematical Expression Recognition.
5 |
6 | ## Note
7 | Due to the company's confidentiality regulations, the original experimental code cannot be taken off the company's intranet, so this repository is a re-implementation of that code. More details will be added in the future.
8 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from torch import nn
4 | from .latex2gtd_v2_2 import list2node, node2list_shuffle
5 |
6 | # load dictionary
def load_dict(dictFile):
    """Load a whitespace-separated ``token id`` dictionary file.

    Every non-blank line is split on whitespace; the first field becomes the
    key and the second field is converted with ``int`` and stored as its
    value. Blank lines (for example a trailing empty line) are skipped
    instead of raising ``IndexError``.

    Args:
        dictFile: Path of the dictionary file to read.

    Returns:
        dict mapping each token string to its integer id.

    Raises:
        IndexError: If a non-blank line has fewer than two fields.
        ValueError: If the second field of a line is not an integer.
    """
    lexicon = {}
    with open(dictFile) as fp:
        # Stream the file instead of materialising every line first.
        for line in fp:
            fields = line.split()
            if not fields:
                continue
            lexicon[fields[0]] = int(fields[1])
    return lexicon
15 |
16 | # create batch
def prepare_data(params, images_x, seqs_y, seqs_key, object2id, relation2id, shuffle=False):
    """Pad a list of images and label sequences into dense batch arrays.

    Args:
        params: Configuration dict; only ``params['input_channels']`` is read.
        images_x: Sequence of image arrays indexed as (channel, height, width).
        seqs_y: Sequence of label sequences. Each entry ``line`` is indexed as
            ``line[0]`` child symbol, ``line[1]`` child position,
            ``line[2]`` parent symbol, ``line[3]`` parent position and
            ``line[4]`` relation name.
        seqs_key: Per-sample keys; only zipped alongside the other inputs.
        object2id: Mapping from symbol string to integer id.
        relation2id: Mapping from relation name to integer id.
        shuffle: When true, the label sequence is rebuilt through
            ``list2node`` / ``node2list_shuffle`` before being encoded.

    Returns:
        Tuple ``(x, x_mask, childs, y_mask, parents, relations, pathes,
        c_pos, p_pos)``. ``x`` is float32 of shape (batch, channels, max_h,
        max_w) padded with -1; ``x_mask`` marks valid pixels; the per-step
        arrays are laid out as (max_len, batch) and ``pathes`` as
        (max_len, batch, len(relation2id)).
    """
    heights_x = [img.shape[1] for img in images_x]
    widths_x = [img.shape[2] for img in images_x]
    lengths_y = [len(seq) for seq in seqs_y]

    n_samples = len(heights_x)
    max_height_x = np.max(heights_x)
    max_width_x = np.max(widths_x)
    maxlen_y = np.max(lengths_y)
    num_relation = len(relation2id)

    step_shape = (maxlen_y, n_samples)
    image_shape = (n_samples, max_height_x, max_width_x)

    # Image padding value is -1; valid pixels are overwritten below.
    x = np.full((n_samples, params['input_channels'], max_height_x, max_width_x),
                -1, dtype=np.float32)
    childs = np.zeros(step_shape, dtype=np.int64)
    parents = np.zeros(step_shape, dtype=np.int64)
    c_pos = np.zeros(step_shape, dtype=np.int64)
    p_pos = np.zeros(step_shape, dtype=np.int64)
    relations = np.zeros(step_shape, dtype=np.int64)
    pathes = np.zeros(step_shape + (num_relation,), dtype=np.float32)

    x_mask = np.zeros(image_shape, dtype=np.float32)
    y_mask = np.zeros(step_shape, dtype=np.float32)

    for idx, (s_x, s_y, s_key) in enumerate(zip(images_x, seqs_y, seqs_key)):
        height, width = heights_x[idx], widths_x[idx]
        # Invert and scale the 0..255 pixel values.
        x[idx, :, :height, :width] = (255 - s_x) / 255.
        x_mask[idx, :height, :width] = 1.

        if shuffle:
            tree = list2node(s_y)
            s_y = []
            node2list_shuffle('', 1, 'Start', tree, s_y, True)

        for step, line in enumerate(s_y):
            c_id = object2id[line[0]]
            p_id = object2id[line[2]]
            re_id = relation2id[line[4]]
            parent_position = int(line[3])

            childs[step, idx] = c_id
            parents[step, idx] = p_id
            c_pos[step, idx] = int(line[1])
            p_pos[step, idx] = parent_position
            relations[step, idx] = re_id

            # A new node starts with only the last relation channel set ...
            pathes[step, idx, -1] = 1
            if step > 0:
                # ... and its parent's row swaps that channel for the relation
                # that links the parent to this node.
                parent_row = parent_position - 1
                pathes[parent_row, idx, -1] = 0
                pathes[parent_row, idx, re_id] = 1

        y_mask[:lengths_y[idx], idx] = 1.

    return x, x_mask, childs, y_mask, parents, relations, pathes, c_pos, p_pos
61 |
62 |
63 | # init model params
def weight_init(m):
    """Initialise one module in place; intended for ``nn.Module.apply``.

    ``nn.Conv2d`` and ``nn.Linear`` modules get Xavier-uniform weights and,
    when they have a bias, a zero bias. Every other module type is left
    untouched.

    Args:
        m: The module handed over by ``apply``.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.xavier_uniform_(m.weight.data)
    # Layers built with bias=False expose ``bias`` as None; check explicitly
    # rather than hiding every possible error behind a bare except.
    if m.bias is not None:
        nn.init.constant_(m.bias.data, 0.)
78 |
79 | # compute metric
def cmp_result(rec, label):
    """Edit distance between a recognised sequence and its reference.

    Substitutions, insertions and deletions each cost 1.

    Args:
        rec: Recognised sequence (any indexable sequence of comparable items).
        label: Reference sequence.

    Returns:
        Tuple ``(distance, len(label))`` where ``distance`` is taken from an
        int32 NumPy table.
    """
    n_label, n_rec = len(label), len(rec)
    dist_mat = np.zeros((n_label + 1, n_rec + 1), dtype='int32')
    # Borders: turning an empty prefix into a prefix of length k costs k.
    dist_mat[0, :] = np.arange(n_rec + 1)
    dist_mat[:, 0] = np.arange(n_label + 1)

    for row in range(1, n_label + 1):
        label_item = label[row - 1]
        for col in range(1, n_rec + 1):
            candidates = (
                dist_mat[row - 1, col - 1] + (label_item != rec[col - 1]),  # match / substitute
                dist_mat[row, col - 1] + 1,                                 # insertion
                dist_mat[row - 1, col] + 1,                                 # deletion
            )
            dist_mat[row, col] = min(candidates)

    return dist_mat[n_label, n_rec], n_label
93 |
94 |
95 |
96 |
97 |
--------------------------------------------------------------------------------
/model/encoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
# DenseNet-B style bottleneck layer.
# As implemented below: conv1 (1x1, nChannels -> 4*growthRate) -> BN -> ReLU
# -> conv2 (3x3, 4*growthRate -> growthRate) -> BN -> ReLU, with optional
# dropout after each ReLU; the result is concatenated to the input.
class Bottleneck(nn.Module):
    """Bottleneck dense layer that appends ``growthRate`` channels to its input."""

    def __init__(self, nChannels, growthRate, use_dropout):
        """Build the two convolution / batch-norm stages.

        Args:
            nChannels: Number of input channels.
            growthRate: Number of channels this layer adds.
            use_dropout: Whether to apply dropout (p=0.2) after each stage.
        """
        super().__init__()
        inter_channels = 4 * growthRate
        self.bn1 = nn.BatchNorm2d(inter_channels)
        self.conv1 = nn.Conv2d(nChannels, inter_channels, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(growthRate)
        self.conv2 = nn.Conv2d(inter_channels, growthRate, kernel_size=3,
                               padding=1, bias=False)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return ``x`` concatenated (channel axis) with the new features."""
        features = x
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            features = F.relu(norm(conv(features)), inplace=True)
            if self.use_dropout:
                features = self.dropout(features)
        return torch.cat((x, features), 1)
30 |
31 |
# Plain DenseNet layer.
# As implemented below: ReLU -> conv (3x3, nChannels -> growthRate), optional
# dropout, then concatenation with the input.
class SingleLayer(nn.Module):
    """Dense layer without a bottleneck; appends ``growthRate`` channels."""

    def __init__(self, nChannels, growthRate, use_dropout):
        """Create the layer.

        Args:
            nChannels: Number of input channels.
            growthRate: Number of channels this layer adds.
            use_dropout: Whether to apply dropout (p=0.2) to the new features.
        """
        super().__init__()
        # NOTE: bn1 is registered (and so appears in the state dict) but
        # forward() never applies it.
        self.bn1 = nn.BatchNorm2d(nChannels)
        self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3,
                               padding=1, bias=False)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return the rectified input concatenated with the new features.

        The ReLU is applied in place, so ``x`` itself is overwritten with its
        rectified values and those are what get concatenated.
        """
        rectified = F.relu(x, inplace=True)
        new_features = self.conv1(rectified)
        if self.use_dropout:
            new_features = self.dropout(new_features)
        return torch.cat((x, new_features), 1)
49 |
50 |
# DenseNet transition (compression) layer.
# As implemented below: conv (1x1) -> BN -> ReLU -> optional dropout ->
# 2x2 average pooling with ceil_mode.
class Transition(nn.Module):
    """Reduce the channel count and halve the spatial resolution."""

    def __init__(self, nChannels, nOutChannels, use_dropout):
        """Create the layer.

        Args:
            nChannels: Number of input channels.
            nOutChannels: Number of output channels.
            use_dropout: Whether to apply dropout (p=0.2) before pooling.
        """
        super().__init__()
        self.bn1 = nn.BatchNorm2d(nOutChannels)
        self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x):
        """Return the compressed, pooled feature map."""
        reduced = self.bn1(self.conv1(x))
        reduced = F.relu(reduced, inplace=True)
        if self.use_dropout:
            reduced = self.dropout(reduced)
        # ceil_mode keeps the last row/column when the size is odd.
        return F.avg_pool2d(reduced, 2, ceil_mode=True)
68 |
69 |
class DenseNet(nn.Module):
    """Convolutional encoder: a strided stem followed by three dense blocks.

    Each dense block stacks 22 layers; the first two blocks are followed by a
    ``Transition``. The stem convolution expects a single input channel.
    """

    def __init__(self, growthRate, reduction, bottleneck, use_dropout):
        """Assemble the network.

        Args:
            growthRate: Channels added by every dense layer.
            reduction: Channel compression factor used by the transitions.
            bottleneck: Use ``Bottleneck`` layers when true, else ``SingleLayer``.
            use_dropout: Forwarded to every dense layer and transition.
        """
        super().__init__()
        nDenseBlocks = 22
        block_growth = nDenseBlocks * growthRate

        channels = 2 * growthRate
        self.conv1 = nn.Conv2d(1, channels, kernel_size=7, padding=3, stride=2, bias=False)

        self.dense1 = self._make_dense(channels, growthRate, nDenseBlocks, bottleneck, use_dropout)
        channels += block_growth
        compressed = int(math.floor(channels * reduction))
        self.trans1 = Transition(channels, compressed, use_dropout)

        channels = compressed
        self.dense2 = self._make_dense(channels, growthRate, nDenseBlocks, bottleneck, use_dropout)
        channels += block_growth
        compressed = int(math.floor(channels * reduction))
        self.trans2 = Transition(channels, compressed, use_dropout)

        channels = compressed
        self.dense3 = self._make_dense(channels, growthRate, nDenseBlocks, bottleneck, use_dropout)

    def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout):
        """Stack ``nDenseBlocks`` dense layers; layer k sees ``nChannels + k * growthRate`` inputs."""
        layer_cls = Bottleneck if bottleneck else SingleLayer
        layers = [
            layer_cls(nChannels + k * growthRate, growthRate, use_dropout)
            for k in range(int(nDenseBlocks))
        ]
        return nn.Sequential(*layers)

    def forward(self, x, x_mask):
        """Encode a batch of images.

        Args:
            x: Input images.
            x_mask: Validity mask indexed as (batch, height, width).

        Returns:
            Tuple ``(features, mask)``; the mask is sub-sampled by taking
            every second row and column four times, once per stride-2 stage
            (stem convolution, max pooling and the two transitions).
        """
        out = F.relu(self.conv1(x), inplace=True)
        out = F.max_pool2d(out, 2, ceil_mode=True)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.dense3(out)

        out_mask = x_mask
        for _ in range(4):
            out_mask = out_mask[:, 0::2, 0::2]
        return out, out_mask
115 |
--------------------------------------------------------------------------------
/utils/data_iterator.py:
--------------------------------------------------------------------------------
1 | import pickle as pkl
2 | import random
3 | import math
4 |
class BatchBucket():
    """Group samples into (height, width, length) buckets and cut them into batches.

    Features and labels are read from two pickle files holding dicts keyed by
    sample uid. Every sample goes into the smallest bucket that contains it,
    and each bucket gets its own batch size derived from ``max_img_size``.
    """

    def __init__(self, max_h, max_w, max_l, max_img_size, max_batch_size,
                 feature_file, lable_file, shuffle=True):
        """Load the data and build the bucket table.

        Args:
            max_h: Upper bound on image height.
            max_w: Upper bound on image width.
            max_l: Upper bound on label length.
            max_img_size: Pixel budget (height * width * batch) per batch.
            max_batch_size: Hard cap on the batch size.
            feature_file: Pickle file with ``{uid: image array}``.
            lable_file: Pickle file with ``{uid: label sequence}``.
            shuffle: Shuffle samples and batches every time batches are built.
        """
        self._max_img_size = max_img_size
        self._max_batch_size = max_batch_size
        self._shuffle = shuffle

        with open(feature_file, 'rb') as fp:
            self._features = pkl.load(fp)
        with open(lable_file, 'rb') as fp:
            self._labels = pkl.load(fp)

        # One (uid, height, width, label length) record per sample.
        self._data_info_list = []
        for uid, fea in self._features.items():
            record = (uid, fea.shape[1], fea.shape[2], len(self._labels[uid]))
            self._data_info_list.append(record)

        self._keys = self._calc_keys(max_h, max_w, max_l)

    @staticmethod
    def _next_edge(current, limit, increment):
        """Advance a grid edge, snapping to ``limit`` instead of jumping over it."""
        if current < limit and current + increment > limit:
            return limit
        return current + increment

    def _calc_keys(self, max_h, max_w, max_l):
        """Build the bucket table.

        Returns:
            List of ``[h, w, l, batch_size, sample_count, uids]`` entries,
            ordered by ``h * w * l``, containing only non-empty buckets.
        """
        # Clamp the requested limits to the largest values actually present.
        _, all_h, all_w, all_l = zip(*self._data_info_list)
        max_h = min(max_h, max(all_h))
        max_w = min(max_w, max(all_w))
        max_l = min(max_l, max(all_l))

        # Lay out the grid from the clamped limits.
        init_h = 100 if 100 < max_h else max_h
        init_w = 100 if 100 < max_w else max_w
        init_l = max_l
        # Minimum spacing between grid lines.
        h_step = 50
        w_step = 100
        l_step = 20

        keys = []
        h = init_h
        while h <= max_h:
            w = init_w
            while w <= max_w:
                l = init_l
                while l <= max_l:
                    keys.append([h, w, l, h * w * l, 0, []])
                    l = self._next_edge(l, max_l, l_step)
                w = self._next_edge(w, max_w, max(int((w*0.3 // 10) * 10), w_step))
            h = self._next_edge(h, max_h, max(int((h*0.5 // 10) * 10), h_step))
        keys.sort(key=lambda entry: entry[3])

        # Drop every sample into the first (smallest) bucket that fits it and
        # count how many samples land in each bucket.
        unused_num = 0
        for uid, h, w, l in self._data_info_list:
            for entry in keys:
                if h <= entry[0] and w <= entry[1] and l <= entry[2]:
                    entry[4] += 1
                    entry[5].append(uid)
                    break
            else:
                print(uid, h, w, l)
                unused_num += 1
        print(f'The number of all samples: {len(self._data_info_list)}')
        print(f'The number of unused samples: {unused_num}')

        # Discard the buckets that received no sample.
        keys = [entry for entry in keys if entry[4] > 0]

        # Replace the volume slot with the bucket's batch size.
        for entry in keys:
            h, w, l, _, sample_num, _ = entry
            batch_size = int(self._max_img_size / (h * w))
            batch_size = max(1, min(batch_size, self._max_batch_size))
            entry[3] = batch_size
            print(f'bucket [{h}, {w}, {l}], batch={batch_size}, sample={sample_num}')
        return keys

    def _reset_batches(self):
        """Rebuild ``self._batches`` as a list of uid lists, one per batch."""
        self._batches = []
        for _, _, _, batch_size, sample_num, uids in self._keys:
            # Shuffle the sample order inside the bucket.
            if self._shuffle:
                random.shuffle(uids)
            for start in range(0, sample_num, batch_size):
                stop = min(start + batch_size, sample_num)
                self._batches.append(uids[start:stop])
        if self._shuffle:
            random.shuffle(self._batches)

    def get_batches(self):
        """Return a fresh list of ``(features, labels, uids)`` batches."""
        self._reset_batches()
        batches = [
            ([self._features[uid] for uid in uid_batch],
             [self._labels[uid] for uid in uid_batch],
             uid_batch)
            for uid_batch in self._batches
        ]
        print(f'The number of Bucket(subset) {len(self._keys)}')
        print(f'The number of Batches {len(batches)}')
        print(f'The number of Samples {len(self._data_info_list)}')

        return batches
114 |
if __name__ == '__main__':
    # Manual check: build the training buckets and cut one set of batches.
    data_path = 'data/'
    train_datasets = ['train_images.pkl', 'train_labels.pkl', 'train_relations.pkl']
    feature_path = data_path + train_datasets[0]
    label_path = data_path + train_datasets[1]
    train_data_iterator = BatchBucket(600, 2100, 200, 800000, 16,
                                      feature_path, label_path)
    train_batches = train_data_iterator.get_batches()
--------------------------------------------------------------------------------
/model/encoder_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .encoder import DenseNet
5 | from .decoder import Decoder
6 |
class Encoder_Decoder(nn.Module):
    """DenseNet encoder plus tree decoder, with training losses and greedy decoding.

    ``params`` is the configuration dict. This class reads ``growthRate``,
    ``reduction``, ``bottleneck``, ``use_dropout``, ``D`` and ``n`` itself and
    hands the whole dict to ``Decoder``.
    """

    def __init__(self, params):
        super(Encoder_Decoder, self).__init__()
        self.encoder = DenseNet(growthRate=params['growthRate'],
                                reduction=params['reduction'],
                                bottleneck=params['bottleneck'],
                                use_dropout=params['use_dropout'])
        self.init_context = nn.Linear(params['D'], params['n'])
        self.decoder = Decoder(params)
        self.object_criterion = nn.CrossEntropyLoss(reduction='none')
        self.relation_criterion = nn.BCEWithLogitsLoss(reduction='none')
        self.object_pix_criterion = nn.NLLLoss(reduction='none')
        self.param_n = params['n']

    def _encode(self, x, x_mask):
        """Run the encoder and derive the initial decoder state.

        Returns:
            Tuple ``(ctx, ctx_mask, init_state)`` where ``init_state`` is
            ``tanh(init_context(masked mean of ctx over height and width))``.
        """
        ctx, ctx_mask = self.encoder(x, x_mask)
        ctx_mean = (ctx * ctx_mask[:, None, :, :]).sum(3).sum(2) \
            / ctx_mask.sum(2).sum(1)[:, None]
        init_state = torch.tanh(self.init_context(ctx_mean))
        return ctx, ctx_mask, init_state

    def forward(self, params, x, x_mask, C_y,
                P_y, C_re, P_re, P_position, y_mask, re_mask, length):
        """Compute the training losses for one batch.

        Args:
            params: Configuration dict; ``lc_lambda``, ``lr_lambda`` and
                ``lc_lambda_pix`` weight the three loss terms.
            x, x_mask: Padded images and their validity mask.
            C_y: Target child symbol ids, laid out as (length, batch).
            P_y: Parent symbol ids fed to the decoder.
            C_re: Multi-hot relation targets, (length, batch, relations).
            P_re: Parent relation ids fed to the decoder.
            P_position: Parent positions (forwarded to the decoder).
            y_mask: Step mask for the symbol losses.
            re_mask: Step mask for the relation loss.
            length: Number of decoding steps.

        Returns:
            Tuple ``(loss, object_loss, relation_loss)``; ``loss`` also
            includes the weighted pixel-level symbol loss.
        """
        # encoder
        ctx, ctx_mask, init_state = self._encode(x, x_mask)
        # decoder
        predict_objects, predict_relations, predict_objects_pix = self.decoder(
            ctx, ctx_mask, C_y, P_y, y_mask, P_re, P_position, init_state, length)

        # loss: each term is averaged over valid steps per sample, then over the batch
        predict_objects = predict_objects.view(-1, predict_objects.shape[2])
        object_loss = self.object_criterion(predict_objects, C_y.view(-1))
        object_loss = object_loss.view(C_y.shape[0], C_y.shape[1])
        object_loss = ((object_loss * y_mask).sum(0) / y_mask.sum(0)).mean()

        # The decoder already returns log-probabilities here, hence NLLLoss.
        predict_objects_pix = predict_objects_pix.view(-1, predict_objects_pix.shape[2])
        object_pix_loss = self.object_pix_criterion(predict_objects_pix, C_y.view(-1))
        object_pix_loss = object_pix_loss.view(C_y.shape[0], C_y.shape[1])
        object_pix_loss = ((object_pix_loss * y_mask).sum(0) / y_mask.sum(0)).mean()

        relation_loss = predict_relations.view(-1, predict_relations.shape[2])
        relation_loss = self.relation_criterion(relation_loss, C_re.view(-1, C_re.shape[2]))
        relation_loss = relation_loss.view(C_re.shape[0], C_re.shape[1], C_re.shape[2])
        relation_loss = (relation_loss * re_mask[:, :, None]).sum(2).sum(0) / re_mask.sum(0)
        relation_loss = relation_loss.mean()

        loss = params['lc_lambda'] * object_loss + \
            params['lr_lambda'] * relation_loss + \
            params['lc_lambda_pix'] * object_pix_loss

        return loss, object_loss, relation_loss

    def greedy_inference(self, x, x_mask, max_length, p_y, p_re, p_mask):
        """Greedily decode a tree for every sample in the batch.

        ``p_y``, ``p_re`` and ``p_mask`` hold the parent symbol, parent
        relation and active flag of the first step and are updated in place
        as decoding proceeds.

        Args:
            x, x_mask: Padded images and their validity mask.
            max_length: Maximum number of decoding steps.
            p_y: Parent symbol ids, one per sample.
            p_re: Parent relation ids, one per sample.
            p_mask: 1 for samples that are still being decoded, else 0.

        Returns:
            Tuple ``(predict_childs, P_masks, relation_table_static,
            predict_relation_static)``: predicted symbol ids
            (max_length, batch), the active mask at every step
            (batch, max_length), the thresholded relation predictions and the
            raw relation scores (both (batch, max_length, 9), kept on CPU).
        """
        ctx, ctx_mask, init_state = self._encode(x, x_mask)
        device = ctx.device

        B, H, W = ctx_mask.shape
        attention_past = torch.zeros(B, 1, H, W, device=device)

        ctx_key_object = self.decoder.conv_key_object(ctx).permute(0, 2, 3, 1)
        ctx_key_relation = self.decoder.conv_key_relation(ctx).permute(0, 2, 3, 1)

        # The width 9 has to match the decoder's relation output size
        # (params['Kre']); column 8 is reused below as a per-node counter of
        # relations that have not been expanded yet.
        relation_table = torch.zeros(B, max_length, 9).to(torch.long)
        relation_table_static = torch.zeros(B, max_length, 9).to(torch.long)
        predict_relation_static = torch.zeros(B, max_length, 9)
        P_masks = torch.zeros(B, max_length, device=device)
        predict_childs = torch.zeros(max_length, B, dtype=torch.long, device=device)

        ht = init_state
        # Both recurrent states start from the same initial state, exactly as
        # in Decoder.forward. Without this line the first get_relation call
        # reads an unassigned variable.
        ht_relation = init_state
        for i in range(max_length):
            predict_child, ht, attention, ct = self.decoder.get_child(
                ctx, ctx_key_object, ctx_mask, attention_past, p_y, p_mask, p_re, ht)
            predict_childs[i] = torch.argmax(predict_child, dim=1)
            predict_childs[i] *= p_mask.to(torch.long)
            attention_past = attention[:, None, :, :] + attention_past

            predict_relation, ht_relation = self.decoder.get_relation(
                ctx, ctx_key_relation, ctx_mask, predict_childs[i], ht_relation, ct)

            P_masks[:, i] = p_mask

            predict_relation_static[:, i, :] = predict_relation
            relation_table[:, i, :] = (predict_relation > 0)
            relation_table_static[:, i, :] = (predict_relation > 0)

            relation_table[:, :, 8] = relation_table[:, :, :8].sum(2)

            find_number = 0
            for ii in range(B):
                if p_mask[ii] < 0.5:
                    continue
                # Walk back from the current step to the most recent node that
                # still has an unexpanded relation; it becomes the parent of
                # the next step.
                ji = i
                find_flag = 0
                while ji >= 0:
                    if relation_table[ii, ji, 8] > 0:
                        for iii in range(9):
                            if relation_table[ii, ji, iii] != 0:
                                p_re[ii] = iii
                                p_y[ii] = predict_childs[ji, ii]
                                relation_table[ii, ji, iii] = 0
                                relation_table[ii, ji, 8] -= 1
                                find_flag = 1
                                break
                    if find_flag:
                        break
                    ji -= 1
                find_number += find_flag
                if not find_flag:
                    # Nothing left to expand: this sample is finished.
                    p_mask[ii] = 0.

            if find_number == 0:
                break
        return predict_childs, P_masks, relation_table_static, predict_relation_static
120 |
121 |
--------------------------------------------------------------------------------
/model/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class Decoder(nn.Module):
    """Tree decoder that predicts, per step, a child symbol and its relations.

    Reads the keys ``K``, ``Kre``, ``m``, ``re_m``, ``n``, ``D`` and
    ``dim_attention`` from ``params`` (the sub-modules read further keys).
    """

    def __init__(self, params):
        super(Decoder, self).__init__()
        self.object_embedding = nn.Embedding(params['K'], params['m'])
        self.relation_embedding = nn.Embedding(params['Kre'], params['re_m'])

        self.gru00 = nn.GRUCell(params['m']+params['re_m'], params['n'])
        self.gru01 = nn.GRUCell(params['D'], params['n'])
        self.gru10 = nn.GRUCell(params['D'], params['n'])
        self.gru11 = nn.GRUCell(params['D'], params['n'])

        self.conv_key_object = nn.Conv2d(params['D'], params['dim_attention'], kernel_size=1)
        self.object_attention = CoverageAttention(params)
        self.object_probility = ObjectProbility(params)

        self.conv_key_relation = nn.Conv2d(params['D'], params['dim_attention'], kernel_size=1)
        self.relation_attention = Attention(params)
        self.relation_probility = RelationProbility(params)

        self.pix_object_probility = PixObjectProbility(params)

        self.param_K = params['K']
        self.param_Kre = params['Kre']
        self.param_n = params['n']

    def forward(self, ctx_val, ctx_mask, C_ys, P_ys, P_y_masks,
                P_res, P_positions, init_state, length):
        """Teacher-forced decoding over ``length`` steps.

        Args:
            ctx_val: Encoder feature map.
            ctx_mask: Feature-map mask of shape (batch, H, W).
            C_ys: Ground-truth child ids per step (input to the relation branch).
            P_ys: Parent symbol ids per step.
            P_y_masks: Step masks; masked-out samples keep their previous state.
            P_res: Parent relation ids per step.
            P_positions: Not used by this method.
            init_state: Initial hidden state for both recurrent branches.
            length: Number of steps to decode.

        Returns:
            Tuple ``(predict_childs, predict_relations, predict_childs_pix)``
            of shapes (length, batch, K), (length, batch, Kre) and
            (length, batch, K). The last one holds log-probabilities.
        """
        B, H, W = ctx_mask.shape
        # Allocate on the same device as the encoder features instead of
        # unconditionally on the default CUDA device.
        device = ctx_val.device
        attention_past = torch.zeros(B, 1, H, W, device=device)
        predict_childs = torch.zeros(length, B, self.param_K, device=device)
        predict_childs_pix = torch.zeros(length, B, self.param_K, device=device)
        predict_relations = torch.zeros(length, B, self.param_Kre, device=device)

        ht = init_state
        ht_relation = init_state

        ctx_key_object = self.conv_key_object(ctx_val).permute(0, 2, 3, 1)
        ctx_key_relation = self.conv_key_relation(ctx_val).permute(0, 2, 3, 1)

        # Per-position symbol distribution, computed once for all steps.
        predict_features = self.pix_object_probility(ctx_val)
        for i in range(length):
            predict_childs[i], ht, attention, ct = self.get_child(
                ctx_val, ctx_key_object, ctx_mask,
                attention_past, P_ys[i], P_y_masks[i], P_res[i], ht)
            # Coverage: accumulate every attention map produced so far.
            attention_past = attention[:, None, :, :] + attention_past

            # Attention-weighted mix of the per-position distributions; the
            # attention is detached so this term does not train it.
            predict_childs_pix[i] = torch.log(
                (attention.detach()[:, :, :, None] * predict_features).sum(2).sum(1))
            predict_relations[i], ht_relation = self.get_relation(
                ctx_val, ctx_key_relation, ctx_mask, C_ys[i], ht_relation, ct)

        return predict_childs, predict_relations, predict_childs_pix

    def get_child(self, ctx_val, ctx_key, ctx_mask, attention_past, p_y, p_y_mask, p_re, ht):
        """Run one child-prediction step.

        Returns:
            Tuple ``(predict_child, ht, attention, ct)``: symbol scores, the
            new hidden state, the attention map and the context vector.
        """
        p_y = self.object_embedding(p_y)
        p_re = self.relation_embedding(p_re)
        p = torch.cat([p_y, p_re], dim=1)
        ht_hat = self.gru00(p, ht)
        # Samples whose mask is 0 keep their previous state.
        ht_hat = p_y_mask[:, None] * ht_hat + (1 - p_y_mask)[:, None] * ht

        ct, attention = self.object_attention(ctx_val, ctx_key, ctx_mask, attention_past, ht_hat)

        ht = self.gru01(ct, ht_hat)
        ht = p_y_mask[:, None] * ht + (1 - p_y_mask)[:, None] * ht_hat

        predict_child = self.object_probility(ct, ht, p_y, p_re)

        return predict_child, ht, attention, ct

    def get_relation(self, ctx_val, ctx_key, ctx_mask, c_y, ht, ct):
        """Run one relation-prediction step for child ``c_y``.

        Returns:
            Tuple ``(predict_relation, ht)``: relation scores and the new
            hidden state of the relation branch.
        """
        c_y = self.object_embedding(c_y)
        ht_query = self.gru10(ct, ht)
        ct, _ = self.relation_attention(ctx_val, ctx_key, ctx_mask, ht_query)
        ht = self.gru11(ct, ht_query)

        predict_relation = self.relation_probility(ct, c_y, ht)
        return predict_relation, ht
81 |
class CoverageAttention(nn.Module):
    """Additive spatial attention that also conditions on past attention maps.

    Reads ``n`` and ``dim_attention`` from ``params``.
    """

    def __init__(self, params):
        super().__init__()
        dim_att = params['dim_attention']
        self.fc_query = nn.Linear(params['n'], dim_att, bias=False)
        self.conv_att_past = nn.Conv2d(1, 512, kernel_size=11, bias=False, padding=5)
        self.fc_att_past = nn.Linear(512, dim_att)
        self.fc_attention = nn.Linear(dim_att, 1)

    def forward(self, ctx_val, ctx_key, ctx_mask, attention_past, ht_query):
        """Attend over the feature map.

        Args:
            ctx_val: Feature map, indexed as (batch, channels, H, W).
            ctx_key: Projected features, indexed as (batch, H, W, dim_att).
            ctx_mask: Validity mask, (batch, H, W).
            attention_past: Accumulated attention, (batch, 1, H, W).
            ht_query: Query state, (batch, n).

        Returns:
            Tuple ``(ct, attention_score)``: the attention-weighted sum of
            ``ctx_val`` over H and W, and the normalised (batch, H, W) map.
        """
        query = self.fc_query(ht_query)

        coverage = self.conv_att_past(attention_past).permute(0, 2, 3, 1)
        coverage = self.fc_att_past(coverage)  # (batch, H, W, dim_att)

        energy = torch.tanh(ctx_key + query[:, None, None, :] + coverage)
        energy = self.fc_attention(energy).squeeze(3)

        # Masked softmax: shift by the global maximum, zero out padded
        # positions, then normalise per sample.
        weights = torch.exp(energy - energy.max()) * ctx_mask
        per_sample_total = weights.sum(2).sum(1)[:, None, None] + 1e-10
        attention_score = weights / per_sample_total

        ct = (ctx_val * attention_score[:, None, :, :]).sum(3).sum(2)

        return ct, attention_score
107 |
class Attention(nn.Module):
    """Additive spatial attention (no coverage term).

    Reads ``n`` and ``dim_attention`` from ``params``.
    """

    def __init__(self, params):
        super().__init__()
        dim_att = params['dim_attention']
        self.fc_query = nn.Linear(params['n'], dim_att, bias=False)
        self.fc_attention = nn.Linear(dim_att, 1)

    def forward(self, ctx_val, ctx_key, ctx_mask, ht_query):
        """Attend over the feature map.

        Args:
            ctx_val: Feature map, indexed as (batch, channels, H, W).
            ctx_key: Projected features, indexed as (batch, H, W, dim_att).
            ctx_mask: Validity mask, (batch, H, W).
            ht_query: Query state, (batch, n).

        Returns:
            Tuple ``(ct, attention_score)``: the attention-weighted sum of
            ``ctx_val`` over H and W, and the normalised (batch, H, W) map.
        """
        query = self.fc_query(ht_query)

        energy = torch.tanh(ctx_key + query[:, None, None, :])
        energy = self.fc_attention(energy).squeeze(3)

        # Masked softmax: shift by the global maximum, zero out padded
        # positions, then normalise per sample.
        weights = torch.exp(energy - energy.max()) * ctx_mask
        per_sample_total = weights.sum(2).sum(1)[:, None, None] + 1e-10
        attention_score = weights / per_sample_total

        ct = (ctx_val * attention_score[:, None, :, :]).sum(3).sum(2)

        return ct, attention_score
128 |
class ObjectProbility(nn.Module):
    """Symbol classifier head: summed projections -> maxout -> dropout -> linear.

    Reads ``D``, ``n``, ``m``, ``re_m`` and ``K`` from ``params``.
    """

    def __init__(self, params):
        super().__init__()
        hidden = params['m']
        self.fc_ct = nn.Linear(params['D'], hidden)
        self.fc_ht = nn.Linear(params['n'], hidden)
        self.fc_p_y = nn.Linear(params['m'], hidden)
        self.fc_p_re = nn.Linear(params['re_m'], hidden)
        self.dropout = nn.Dropout(p=0.2)
        # Maxout halves the width before the final projection.
        self.fc_probility = nn.Linear(int(params['m']/2), params['K'])

    def forward(self, ct, ht, p_y, p_re):
        """Return unnormalised symbol scores of shape (batch, K).

        Args:
            ct: Context vector.
            ht: Hidden state.
            p_y: Parent symbol embedding.
            p_re: Parent relation embedding.
        """
        fused = self.fc_ct(ct) + self.fc_ht(ht) + self.fc_p_y(p_y) + self.fc_p_re(p_re)
        # maxout over consecutive pairs of units
        pairs = fused.view(fused.shape[0], -1, 2)
        pooled, _ = pairs.max(2)
        pooled = self.dropout(pooled)
        return self.fc_probility(pooled)
147 |
class PixObjectProbility(nn.Module):
    """Per-position symbol classifier applied to the whole feature map.

    Reads ``D``, ``m`` and ``K`` from ``params``.
    """

    def __init__(self, params):
        # The initialiser must name this class: passing another class to
        # super() raises TypeError because ``self`` is not an instance of it.
        super(PixObjectProbility, self).__init__()
        self.fc_ct = nn.Linear(params['D'], params['m'])
        self.dropout = nn.Dropout(p=0.2)
        # Maxout halves the width before the final projection.
        self.fc_probility = nn.Linear(int(params['m']/2), params['K'])

    def forward(self, ctx):
        """Return a symbol distribution for every spatial position.

        Args:
            ctx: Feature map of shape (B, D, H, W).

        Returns:
            Tensor of shape (B, H, W, K); the last axis is a softmax
            distribution.
        """
        out = self.fc_ct(ctx.permute(0, 2, 3, 1))  # (B, H, W, D) --> (B, H, W, m)
        # maxout over consecutive pairs of units
        out = out.view(out.shape[0], out.shape[1], out.shape[2], -1, 2)  # (B, H, W, m/2, 2)
        out = out.max(4)[0]  # (B, H, W, m/2)
        out = self.dropout(out)
        out = self.fc_probility(out)  # (B, H, W, K)
        out = F.softmax(out, dim=3)
        return out
164 |
class RelationProbility(nn.Module):
    """Relation classifier head: summed projections -> maxout -> dropout -> linear.

    Reads ``D``, ``m``, ``n``, ``mre`` and ``Kre`` from ``params``.
    """

    def __init__(self, params):
        super().__init__()
        hidden = params['mre']
        self.fc_ct = nn.Linear(params['D'], hidden)
        self.fc_c_y = nn.Linear(params['m'], hidden)
        self.fc_ht = nn.Linear(params['n'], hidden)
        self.dropout = nn.Dropout(p=0.2)
        # Maxout halves the width before the final projection.
        self.fc_probility = nn.Linear(int(params['mre']/2), params['Kre'])

    def forward(self, ct, c_y, ht):
        """Return unnormalised relation scores of shape (batch, Kre).

        Args:
            ct: Context vector.
            c_y: Child symbol embedding.
            ht: Hidden state.
        """
        fused = self.fc_ct(ct) + self.fc_c_y(c_y) + self.fc_ht(ht)
        # maxout over consecutive pairs of units
        pairs = fused.view(fused.shape[0], -1, 2)
        pooled, _ = pairs.max(2)
        pooled = self.dropout(pooled)
        return self.fc_probility(pooled)
183 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | import pickle as pkl
4 | import torch
5 | from torch import optim, nn
6 | from utils.utils import load_dict, prepare_data, weight_init, cmp_result
7 | from model.encoder_decoder import Encoder_Decoder
8 | from utils.data_iterator import BatchBucket
9 | from utils.latex2gtd_v2_2 import list2node, tree2latex, relation2gtd
10 |
11 | # configurations
12 | init_param_flag = True # whether init params
13 | reload_flag = False # whether to reload params
14 | data_path = 'data/'
15 | work_path = 'results/'
16 | dictionaries = ['dictionary_object.txt', 'dictionary_relation_noend.txt']
17 | train_datasets = ['train_images.pkl', 'train_labels.pkl', 'train_relations.pkl']
18 | valid_datasets = ['valid_images.pkl', 'valid_labels.pkl', 'valid_relations.pkl']
19 | valid_outputs = ['valid_results.txt']
20 | model_params = ['SimTree_best_params.pkl', 'SimTree_last_params.pkl']
21 |
22 | #train Settings
23 | maxlen = 200
24 | max_epochs = 5000
25 | lrate = 1
26 | my_eps = 1e-6
27 | decay_c = 1e-4
28 | clip_c = 100.
29 |
30 | # early stop
31 | estop = False
32 | halfLrFlag = 0
33 | bad_counter = 0
34 | patience = 15
35 | validStart = 0 # before the model has learned enough, decoding tends to overrun and waste time
36 | finish_after = 10000000
37 |
38 | # model architecture
39 | # this part should be moved into the model code
40 | params = {}
41 | params['n'] = 256
42 | params['m'] = 256
43 | params['re_m'] = 64
44 | params['dim_attention'] = 512
45 | params['D'] = 936
46 | params['K'] = 107
47 | params['Kre'] = 9
48 | params['mre'] = 256
49 | params['maxlen'] = maxlen
50 |
51 | params['growthRate'] = 24
52 | params['reduction'] = 0.5
53 | params['bottleneck'] = True
54 | params['use_dropout'] = True
55 | params['input_channels'] = 1
56 |
57 | params['lc_lambda'] = 1.
58 | params['lr_lambda'] = 1.
59 | params['lc_lambda_pix'] = 1.
60 |
61 | symbol2id = load_dict(data_path+dictionaries[0])
62 | print('total chars', len(symbol2id))
63 | id2symbol = {}
64 | for symbol, symbol_id in symbol2id.items():
65 | id2symbol[symbol_id] = symbol
66 |
67 | relation2id = load_dict(data_path+dictionaries[1])
68 | print('total relations', len(relation2id))
69 | id2relation = {}
70 | for relation, relation_id in relation2id.items():
71 | id2relation[relation_id] = relation
72 |
73 | train_data_iterator = BatchBucket(600, 2100, 200, 800000, 16,
74 | data_path+train_datasets[0], data_path+train_datasets[1])
75 | valid_data_iterator = BatchBucket(9999, 9999, 9999, 999999, 1,
76 | data_path+valid_datasets[0], data_path+valid_datasets[1])
77 | valid = valid_data_iterator.get_batches()
78 | with open(data_path + valid_datasets[1], 'rb') as fp:
79 | valid_gtds = pkl.load(fp)
80 |
81 | # display
82 | uidx = 0
83 | object_loss_s = 0.
84 | relation_loss_s = 0.
85 | loss_s = 0.
86 |
87 | ud_s = 0
88 | validFreq = -1
89 | saveFreq = -1
90 | sampleFreq = -1
91 | dispFreq = 100
92 | WER = 100
93 |
94 | # initialize model
95 | SimTree_model = Encoder_Decoder(params)
96 | if init_param_flag:
97 | SimTree_model.apply(weight_init)
98 | if reload_flag:
99 | print('Loading pretrained model ...')
100 | SimTree_model.load_state_dict(torch.load(work_path+model_params[1], map_location=lambda storage, loc:storage))
101 | SimTree_model.cuda()
102 |
103 | optimizer = optim.Adadelta(SimTree_model.parameters(), lr=lrate, eps=my_eps, weight_decay=decay_c)
104 | print('Optimization')
105 |
106 | # statistics
107 | history_errs = []
108 | for eidx in range(max_epochs):
109 | n_samples = 0
110 | ud_epoch = time.time()
111 | train = train_data_iterator.get_batches()
112 | if validFreq == -1:
113 | validFreq = len(train)
114 | if saveFreq == -1:
115 | saveFreq = len(train)
116 | if sampleFreq == -1:
117 | sampleFreq = len(train)
118 |
119 | for x, y, key in train:
120 | SimTree_model.train()
121 | ud_start = time.time()
122 | n_samples += len(x)
123 | uidx += 1
124 | x, x_mask, C_y, y_mask, P_y, P_re, C_re, lp, rp = \
125 | prepare_data(params, x, y, key, symbol2id, relation2id, shuffle=False)
126 |
127 | length = C_y.shape[0]
128 | x = torch.from_numpy(x).cuda()
129 | x_mask = torch.from_numpy(x_mask).cuda()
130 | C_y = torch.from_numpy(C_y).to(torch.long).cuda()
131 | y_mask = torch.from_numpy(y_mask).cuda()
132 | P_y = torch.from_numpy(P_y).to(torch.long).cuda()
133 | P_re = torch.from_numpy(P_re).to(torch.long).cuda()
134 | P_position = torch.from_numpy(rp).to(torch.long).cuda()
135 | C_re = torch.from_numpy(C_re).cuda()
136 |
137 | loss, object_loss, relation_loss = SimTree_model(params, x, x_mask,
138 | C_y, P_y, C_re, P_re, P_position, y_mask, y_mask, length)
139 |
140 | object_loss_s += object_loss.item()
141 | relation_loss_s += relation_loss.item()
142 | loss_s = loss.item()
143 |
144 | # backward
145 | optimizer.zero_grad()
146 | loss.backward()
147 | if clip_c > 0.:
148 | torch.nn.utils.clip_grad_norm_(SimTree_model.parameters(), clip_c)
149 |
150 | # update
151 | optimizer.step()
152 |
153 | # display
154 | ud = time.time() - ud_start
155 | ud_s += ud
156 |
157 | if np.mod(uidx, dispFreq) == 0:
158 | ud_s /= 60.
159 | loss_s /= dispFreq
160 | object_loss_s /= dispFreq
161 | relation_loss_s /= dispFreq
162 | print(f'Epoch {eidx}, Update {uidx} Cost_object {object_loss_s:.7}', end='')
163 | print(f'Cost_relation {relation_loss_s}, UD {ud_s:.3} lrate {lrate} eps {my_eps} bad_counter {bad_counter}')
164 | ud_s = 0
165 | loss_s = 0.
166 | object_loss_s = 0.
167 | relation_loss_s = 0.
168 |
169 | if np.mod(uidx, saveFreq) == 0:
170 | print('Saving latest model params ... ')
171 | torch.save(SimTree_model.state_dict(), work_path+model_params[1])
172 |
173 | # validation
174 | if np.mod(uidx, sampleFreq) == 0 and (eidx % 2) == 0:
175 | number_right = 0
176 | total_distance = 0
177 | total_length = 0
178 | latex_right = 0
179 | total_latex_distance = 0
180 | total_latex_length = 0
181 | total_number = 0
182 |
183 | print('begin sampling')
184 | ud_epoch_train = (time.time() - ud_epoch) / 60.
185 | print('epoch training cost time ...', ud_epoch_train)
186 | SimTree_model.eval()
187 |
188 | fp_results = open(work_path+valid_outputs[0], 'w')
189 | with torch.no_grad():
190 | valid_count_idx = 0
191 | for x, y, valid_key in valid:
192 | x, x_mask, C_y, y_mask, P_y, P_re, C_re, lp, rp = \
193 | prepare_data(params, x, y, valid_key, symbol2id, relation2id)
194 |
195 | L, B = C_y.shape[:2]
196 | x = torch.from_numpy(x).cuda()
197 | x_mask = torch.from_numpy(x_mask).cuda()
198 | lengths_gt = (y_mask > 0.5).sum(0)
199 | y_mask = torch.from_numpy(y_mask).cuda()
200 | P_y = torch.from_numpy(P_y).to(torch.long).cuda()
201 | P_re = torch.from_numpy(P_re).to(torch.long).cuda()
202 |
203 | object_predicts, P_masks, relation_table_static, _ \
204 | = SimTree_model.greedy_inference(x, x_mask, L+1, P_y[0], P_re[0], y_mask[0])
205 | object_predicts, P_masks = object_predicts.cpu().numpy(), P_masks.cpu().numpy()
206 | relation_table_static = relation_table_static.numpy()
207 | for bi in range(B):
208 | length_predict = min((P_masks[bi, :] > 0.5).sum(), P_masks.shape[1])
209 | object_predict = object_predicts[:int(length_predict), bi]
210 | relation_predict = relation_table_static[bi, :int(length_predict), :]
211 | gtd = relation2gtd(object_predict, relation_predict, id2symbol, id2relation)
212 | latex = tree2latex(list2node(gtd))
213 |
214 | uid = valid_key[bi]
215 | groud_truth_gtd = valid_gtds[uid]
216 | groud_truth_latex = tree2latex(list2node(groud_truth_gtd))
217 |
218 | child = [symbol2id[g[0]] for g in groud_truth_gtd if g[0] != '']
219 | distance, length = cmp_result(object_predict, child)
220 | total_number += 1
221 |
222 | if distance == 0:
223 | number_right += 1
224 | fp_results.write(uid + '\tObject True\t')
225 | else:
226 | fp_results.write(uid + '\tObject False\t')
227 |
228 | latex_distance, latex_length = cmp_result(groud_truth_latex, latex)
229 | if latex_distance == 0:
230 | latex_right += 1
231 | fp_results.write('Latex True\n')
232 | else:
233 | fp_results.write('Latex False\n')
234 |
235 | total_distance += distance
236 | total_length += length
237 | total_latex_distance += latex_distance
238 | total_latex_length += latex_length
239 |
240 | fp_results.write(' '.join(groud_truth_latex) + '\n')
241 | fp_results.write(' '.join(latex) + '\n')
242 |
243 | for c in child:
244 | fp_results.write(id2symbol[c] + ' ')
245 | fp_results.write('\n')
246 |
247 | for ob_p in object_predict:
248 | fp_results.write(id2symbol[ob_p] + ' ')
249 | fp_results.write('\n')
250 |
251 | wer = total_distance / total_length * 100
252 | sacc = number_right / total_number * 100
253 | latex_wer = total_latex_distance / total_latex_length * 100
254 | latex_acc = latex_right / total_number * 100
255 | fp_results.close()
256 |
257 | ud_epoch = (time.time() - ud_epoch) / 60.
258 | print(f'valid set decode done, epoch cost time: {ud_epoch} min')
259 | print(f'WER {wer} SACC {sacc} Latex WER {latex_wer} Latex SACC {latex_acc}')
260 |
261 | if latex_wer <= WER:
262 | WER = latex_wer
263 | bad_counter = 0
264 | print('Saving best model params ... ')
265 | torch.save(SimTree_model.state_dict(), work_path+model_params[0])
266 | else:
267 | bad_counter += 1
268 | if bad_counter > patience:
269 | if halfLrFlag == 2:
270 | print('Early Stop!')
271 | estop = True
272 | break
273 | else:
274 | print('Lr decay and retrain!')
275 | bad_counter = 0
276 | lrate = lrate / 10.
277 | for param_group in optimizer.param_groups:
278 | param_group['lr'] = lrate
279 | halfLrFlag += 1
280 | if uidx >= finish_after:
281 | print(f'Finishing after {uidx} iterations!')
282 | estop = True
283 | break
284 |
285 | print(f'Seen {n_samples} samples')
286 |
287 | if estop:
288 | break
--------------------------------------------------------------------------------
/utils/latex2gtd_v2_2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 |
class Node:
    """One symbol of an expression tree together with its outgoing edges.

    ``childs`` and ``relations`` are parallel lists: ``childs[i]`` is attached
    to this node through the relation name stored in ``relations[i]``.
    """

    def __init__(self, x=0):
        # Symbol carried by this node (0 when the caller passes nothing).
        self.x = x
        # Two distinct, initially empty lists filled in lock-step by callers.
        self.childs, self.relations = [], []
9 |
def findnextbracket(latex:list, leftbracket='{'):
    """Locate the token that closes the bracket group starting at ``latex[0]``.

    The token list is scanned from the left while tracking the nesting depth
    of ``leftbracket`` and its matching closer. The index of the first token
    after which the depth is zero is returned, or -1 if the scan reaches the
    end of the list without the depth being zero.

    Only '{' and '[' are accepted as ``leftbracket``; anything else raises
    ``AssertionError``.
    """
    for opener, closer in (('{', '}'), ('[', ']')):
        if leftbracket == opener:
            rightbracket = closer
            break
    else:
        raise AssertionError('Unkown Bracket!')

    depth = 0
    for position, token in enumerate(latex):
        if token == leftbracket:
            depth += 1
        elif token == rightbracket:
            depth -= 1
        if depth == 0:
            return position
    return -1
27 |
def findendmatrix(latex):
    """Return the index of the ``\\end{matrix}`` closing an already-open matrix.

    ``latex`` holds the tokens that follow a ``\\begin{matrix}``, so the scan
    starts with one matrix open. Nested begin/end pairs are skipped. Returns
    -1 when no closing token balances the count.
    """
    open_matrices = 1
    for position, token in enumerate(latex):
        if token == '\\begin{matrix}':
            open_matrices += 1
        elif token == '\\end{matrix}':
            open_matrices -= 1
            # The count can only reach zero right after a decrement.
            if open_matrices == 0:
                return position
    return -1
38 |
# Symbol classes recognised by latex2Tree (tuples keep the original ``==`` membership test).
_L2T_LIMITS = ('\\iint', '\\bigcup', '\\sum', '\\lim', '\\coprod')
_L2T_OVER = ('\\dot', '\\ddot', '\\hat', '\\check', '\\grave', '\\acute',
             '\\tilde', '\\breve', '\\bar', '\\vec', '\\widehat',
             '\\overbrace', '\\widetilde', '\\overleftarrow',
             '\\overrightarrow', '\\overline')
_L2T_UNDER = ('\\underline', '\\underbrace')
_L2T_ARROWS = ('\\xrightarrow', '\\xleftarrow')


def _l2t_expect_brace(latex, message):
    """Raise ``AssertionError(message)`` unless the next token is '{'."""
    # Explicit raise: unlike ``assert`` it survives ``python -O`` and it also
    # covers truncated input instead of failing with IndexError.
    if len(latex) == 0 or latex[0] != '{':
        raise AssertionError(message)


def _l2t_pop_group(latex, leftbracket='{'):
    """Parse the bracketed group at the head of *latex* and consume it."""
    li = findnextbracket(latex, leftbracket=leftbracket)
    node = latex2Tree(latex[1:li])
    # One slice delete instead of li+1 front pops. If no closing bracket was
    # found li is -1 and nothing is removed, as before.
    del latex[:li + 1]
    return node


def _l2t_add_child(cur_node, node, relation):
    """Append *node* to *cur_node* under *relation*."""
    cur_node.childs.append(node)
    cur_node.relations.append(relation)


def _l2t_attach_scripts(cur_node, latex, lower, upper):
    """Parse an optional ``_ {..}`` then an optional ``^ {..}`` after a symbol."""
    for marker, relation in (('_', lower), ('^', upper)):
        if len(latex) > 0 and latex[0] == marker:
            latex.pop(0)
            _l2t_expect_brace(latex, marker + ' not with {')
            _l2t_add_child(cur_node, _l2t_pop_group(latex), relation)


def latex2Tree(latex:list):
    '''
    Parse a tokenised LaTeX expression into a tree of Node objects.

    input: latex --> list of tokens; it is consumed in place.
    output: Node for the first token. Its structural arguments are attached
            with the relations 'sub', 'sup', 'below', 'above', 'leftup',
            'inside' or 'Mstart'; the rest of the expression is attached last
            with 'right', 'nextline' (after a row break) or 'end' (an empty
            Node('') terminator). An empty list yields Node('').

    Raises AssertionError when a script, accent, \\frac or \\sqrt is not
    followed by a '{' group (including when the input ends there).
    '''
    if len(latex) == 0:
        return Node('')

    symbol = latex.pop(0)
    cur_node = Node(symbol)

    if symbol == '\\begin{matrix}':
        li = findendmatrix(latex)
        _l2t_add_child(cur_node, latex2Tree(latex[:li]), 'Mstart')
        del latex[:li + 1]

    elif symbol in _L2T_LIMITS:
        _l2t_attach_scripts(cur_node, latex, 'below', 'above')

    elif symbol in _L2T_OVER:
        # The argument sits below the accent mark.
        _l2t_expect_brace(latex, 'CASE 3 above not with {')
        _l2t_add_child(cur_node, _l2t_pop_group(latex), 'below')

    elif symbol in _L2T_UNDER:
        # The argument sits above the underline/underbrace.
        _l2t_expect_brace(latex, 'CASE 3 above not with {')
        _l2t_add_child(cur_node, _l2t_pop_group(latex), 'above')

    elif symbol in _L2T_ARROWS:
        # Both groups are optional; the length guards keep an arrow at the end
        # of the input from raising IndexError.
        if len(latex) > 0 and latex[0] == '[':
            _l2t_add_child(cur_node, _l2t_pop_group(latex, leftbracket='['), 'below')
        if len(latex) > 0 and latex[0] == '{':
            _l2t_add_child(cur_node, _l2t_pop_group(latex), 'above')

    elif symbol == '\\frac':
        _l2t_expect_brace(latex, '\\frac above not with {')
        _l2t_add_child(cur_node, _l2t_pop_group(latex), 'above')
        _l2t_expect_brace(latex, '\\frac below not with {')
        denominator = _l2t_pop_group(latex)
        # insert(-1, ...) stores the denominator in front of the numerator
        # that was just appended, giving the order [below, above].
        cur_node.childs.insert(-1, denominator)
        cur_node.relations.insert(-1, 'below')

    elif symbol == '\\sqrt':
        if len(latex) > 0 and latex[0] == '[':
            _l2t_add_child(cur_node, _l2t_pop_group(latex, leftbracket='['), 'leftup')
        _l2t_expect_brace(latex, '\\sqrt inside not with {')
        _l2t_add_child(cur_node, _l2t_pop_group(latex), 'inside')

    else:
        # Ordinary symbols (including the empty token) take sub/superscripts.
        _l2t_attach_scripts(cur_node, latex, 'sub', 'sup')

    # Whatever is left becomes the last child of this node.
    if len(latex) > 0 and latex[0] == '\\\\':
        latex.pop(0)
        relation = 'nextline'
    elif len(latex) > 0:
        relation = 'right'
    else:
        relation = 'end'
    _l2t_add_child(cur_node, latex2Tree(latex), relation)

    return cur_node
215 |
# Module-wide running node counter shared by node2list and node2list_shuffle.
index = 0
def node2list(parent, parent_index, relation, current, gtd, initial=None):
    """Flatten the tree under ``current`` into ``gtd`` in depth-first order.

    Every visited node appends ``[symbol, index, parent_symbol, parent_index,
    relation]`` to ``gtd``. Nodes that are ``None`` or carry the empty symbol
    are skipped together with everything below them. Passing any non-None
    ``initial`` restarts the global ``index`` counter at 1; otherwise the
    counter continues from its current value.
    """
    global index
    if current is None or current.x == '':
        return
    index = 1 if initial is not None else index + 1
    # Remember this node's number; recursive calls advance the global counter.
    own_index = index
    gtd.append([current.x, own_index, parent, parent_index, relation])
    for child, child_relation in zip(current.childs, current.relations):
        node2list(current.x, own_index, child_relation, child, gtd)
229 |
def node2list_shuffle(parent, parent_index, relation, current, gtd, initial=None):
    """Flatten the tree under ``current`` into ``gtd``, visiting children in random order.

    Each visited node appends ``[symbol, index, parent_symbol, parent_index,
    relation]`` to ``gtd``. Only ``None`` nodes are skipped. A non-None
    ``initial`` restarts the global ``index`` counter at 1; otherwise the
    counter continues from its current value. The (child, relation) pairs of
    every node are shuffled with ``random.shuffle`` before recursing.
    """
    global index
    if current is None:
        return
    index = 1 if initial is not None else index + 1
    # Keep this node's number; the recursion below advances the global counter.
    own_index = index
    gtd.append([current.x, own_index, parent, parent_index, relation])
    edges = list(zip(current.childs, current.relations))
    random.shuffle(edges)
    for child, child_relation in edges:
        node2list_shuffle(current.x, own_index, child_relation, child, gtd)
244 |
def list2node(gtd):
    """Rebuild a Node tree from a flat gtd list.

    Each entry is read as ``[symbol, index, parent_symbol, parent_index,
    relation]``; only ``symbol`` (position 0), ``parent_index`` (position 3)
    and ``relation`` (position 4) are used. Entry ``i`` (0-based) becomes node
    number ``i + 1`` and number 0 is a synthetic 'root' node, so
    ``parent_index`` refers to these numbers.

    Returns the node built from the first entry. An empty ``gtd`` returns the
    empty node ``Node('')`` (the same value latex2Tree gives for empty input)
    instead of raising IndexError.
    """
    if len(gtd) == 0:
        return Node('')

    nodes = [Node('root')]
    for entry in gtd:
        child = Node(entry[0])
        # Register the child before resolving its parent, as numbering is by
        # position in gtd.
        nodes.append(child)
        parent = nodes[entry[3]]
        parent.childs.append(child)
        parent.relations.append(entry[4])
    return nodes[1]
256 |
257 |
258 |
259 |
# (tokens emitted before a child, tokens emitted after it)
_T2L_BRACE = (('{',), ('}',))
_T2L_BRACKET = (('[',), (']',))
_T2L_SUB = (('_', '{'), ('}',))
_T2L_SUP = (('^', '{'), ('}',))
_T2L_NEXTLINE = (('\\\\',), ())
_T2L_PLAIN = ((), ())

_T2L_DECORATIONS = ('\\underline', '\\underbrace', '\\dot',
                    '\\ddot', '\\hat', '\\check', '\\grave', '\\acute',
                    '\\tilde', '\\breve', '\\bar', '\\vec', '\\widehat',
                    '\\overbrace', '\\widetilde', '\\overleftarrow',
                    '\\overrightarrow', '\\overline')
_T2L_ARROWS = ('\\xrightarrow', '\\xleftarrow')
_T2L_LIMITS = ('\\iint', '\\bigcup', '\\sum', '\\lim', '\\coprod')


def _t2l_wrappers(symbol):
    """Return ``{relation: (before, after)}`` for the children of *symbol*."""
    if symbol == '\\begin{matrix}':
        table = {'Mstart': ((), ('\\end{matrix}',))}
    elif symbol in _T2L_DECORATIONS:
        table = {'above': _T2L_BRACE, 'below': _T2L_BRACE}
    elif symbol == '\\sqrt':
        table = {'leftup': _T2L_BRACKET, 'inside': _T2L_BRACE}
    elif symbol in _T2L_ARROWS:
        table = {'below': _T2L_BRACKET, 'above': _T2L_BRACE}
    elif symbol in _T2L_LIMITS:
        table = {'below': _T2L_SUB, 'above': _T2L_SUP}
    else:
        table = {'sub': _T2L_SUB, 'sup': _T2L_SUP}
    # A row break is rendered the same way under every symbol.
    table['nextline'] = _T2L_NEXTLINE
    return table


def _t2l_frac(root, latex):
    """Append the children of a ``\\frac`` node to *latex* and return it."""
    pending_below = []
    above_emitted = False
    for child, relation in zip(root.childs, root.relations):
        if relation == 'below':
            group = ['{'] + tree2latex(child) + ['}']
            if above_emitted:
                # Denominator listed after the numerator: emit it right away.
                # It used to be buffered and then silently dropped.
                latex += group
            else:
                # Denominator listed first: hold it until the numerator is out.
                pending_below += group
        elif relation == 'above':
            latex.append('{')
            latex += tree2latex(child)
            latex.append('}')
            latex += pending_below
            above_emitted = True
        elif relation == 'nextline':
            latex.append('\\\\')
            latex += tree2latex(child)
        else:
            latex += tree2latex(child)
    return latex


def tree2latex(root):
    """Serialise the tree rooted at ``root`` into a list of LaTeX tokens.

    ``root`` needs the attributes ``x`` (symbol), ``childs`` and ``relations``
    (parallel lists). Children are emitted in the order they are stored, each
    wrapped according to its relation and the parent symbol: for example
    'sub'/'sup' become ``_ { .. }`` / ``^ { .. }``, 'inside' of ``\\sqrt``
    becomes ``{ .. }``, and 'nextline' is preceded by ``\\\\``. Children with
    any other relation (such as 'right' or 'end') are emitted with no
    wrapping. A node whose symbol is the empty string yields ``[]``.

    For ``\\frac`` the numerator ('above') always precedes the denominator
    ('below') in the output, whichever of the two is stored first.
    """
    symbol = root.x
    if symbol == '':
        return []

    latex = [symbol]
    if symbol == '\\frac':
        return _t2l_frac(root, latex)

    wrappers = _t2l_wrappers(symbol)
    for child, relation in zip(root.childs, root.relations):
        before, after = wrappers.get(relation, _T2L_PLAIN)
        latex += before
        latex += tree2latex(child)
        latex += after
    return latex
373 |
def relation2gtd(objects, relations, id2object, id2relation):
    """Turn a symbol sequence plus per-symbol relation rows into a gtd list.

    For every entry of ``objects`` one row ``[symbol, index, parent_symbol,
    parent_index, relation_name]`` is produced, with ``index`` counting from 1.
    Parents are assigned with a stack: each processed symbol pushes its
    relation row, and each new symbol takes the lowest-numbered relation still
    marked in the row on top of the stack.

    Args:
        objects: sequence of symbol ids (keys into ``id2object``).
        relations: one row per entry of ``objects``; rows must support
            slicing, ``.sum()`` and item assignment (presumably a 2-D numpy
            array -- TODO confirm). ``relations[0]`` is read unconditionally,
            so an empty ``relations`` is not supported.
        id2object: lookup from symbol id to symbol; ``len(id2object)-1`` is
            used as the id of the virtual parent of the first symbol.
        id2relation: lookup from relation column index to relation name.

    Returns:
        A list with one 5-element list per entry of ``objects``.

    NOTE(review): rows are pushed with ``relations[ci]`` and later zeroed in
    place, so if that indexing returns a view (as numpy row indexing does) the
    caller's ``relations`` data is modified -- confirm this is intended.
    """
    gtd = [[] for o in objects]
    num_relation = len(relations[0])

    # Synthetic bottom-of-stack row: only relation column 0 is set, paired
    # with the virtual parent (len(id2object)-1, 0).
    start_relation = np.array([0 for _ in range(num_relation)])
    start_relation[0] = 1
    relation_stack = [start_relation]
    parent_stack = [(len(id2object)-1, 0)]
    # Parent relation / symbol id / index carried over between iterations.
    p_re = len(id2relation) - 1
    p_y = len(id2object)-1
    p_id = 0

    for ci, c in enumerate(objects):
        gtd[ci].append(id2object[c])
        gtd[ci].append(ci+1)

        find_flag = False
        while relation_stack != []:
            # Only the first num_relation-1 columns decide whether the row on
            # top still has a relation to hand out; the last column is ignored.
            if relation_stack[-1][:num_relation-1].sum() > 0:
                for index_relation in range(num_relation):
                    if relation_stack[-1][index_relation] != 0:
                        # Take the lowest-index marked relation and clear it.
                        p_re = index_relation
                        p_y, p_id = parent_stack[-1]
                        relation_stack[-1][index_relation] = 0
                        # Drop the row once nothing is left in those columns.
                        if relation_stack[-1][:num_relation-1].sum() == 0:
                            relation_stack.pop()
                            parent_stack.pop()
                        find_flag = 1
                        break
            else:
                # Exhausted row: discard it and look further down the stack.
                relation_stack.pop()
                parent_stack.pop()

            if find_flag:
                break

        if not find_flag:
            # Stack ran empty: attach to the previous symbol with the last
            # relation column. NOTE(review): for ci == 0 this reads
            # objects[-1]; it looks reachable only when num_relation is 1.
            p_y = objects[ci-1]
            p_id = ci
            p_re = num_relation - 1
        gtd[ci].append(id2object[p_y])
        gtd[ci].append(p_id)
        gtd[ci].append(id2relation[p_re])

        # The current symbol becomes a candidate parent for later symbols.
        relation_stack.append(relations[ci])
        parent_stack.append((c, ci+1))

    return gtd
422 |
423 |
if __name__ == '__main__':
    # Round-trip demo: LaTeX -> tree -> gtd -> tree -> LaTeX, then the same
    # again after shuffling the child order.
    # '\\sqrt' is written with an escaped backslash; the former '\s' was an
    # invalid escape sequence with the same runtime value.
    latex = 'a = \\frac { x } { y } + \\sqrt [ c ] { b }'
    tree = latex2Tree(latex.split())
    gtd = []
    # Reset the module-level node counter before flattening.
    index = 0
    node2list('root', 0, 'start', tree, gtd, initial=True)
    print('original gtd:')
    for g in gtd:
        g = [str(item) for item in g]
        print('\t\t'.join(g))
    predict_tree = list2node(gtd)
    predict_latex = tree2latex(predict_tree)
    predict_latex = ' '.join(predict_latex)

    print('shuffled gtd:')
    gtd = []
    node2list_shuffle('root', 0, 'start', predict_tree, gtd, initial=True)
    for g in gtd:
        g = [str(item) for item in g]
        print('\t\t'.join(g))
    shuffle_tree = list2node(gtd)
    shuffle_latex = tree2latex(shuffle_tree)
    shuffle_latex = ' '.join(shuffle_latex)
    print(latex == predict_latex)
    print(latex)
    print(predict_latex)
    print(shuffle_latex)
452 |
453 |
454 |
455 |
456 |
--------------------------------------------------------------------------------