├── README.md
├── codes
│   ├── data_iterator.py
│   ├── decoder.py
│   ├── encoder.py
│   ├── encoder_decoder.py
│   ├── gtd2latex.py
│   ├── latex2gtd.py
│   ├── prepare_label.py
│   ├── train.sh
│   ├── train_wap.py
│   ├── translate.py
│   └── utils.py
└── paper
    └── TD_camera_v1.pdf
/README.md:
--------------------------------------------------------------------------------
1 | # TreeDecoder
2 |
3 | The source code has been released. We will add documentation to make it clearer for those who are not familiar with deep learning and encoder-decoder models.
4 |
5 | The data will be released once it is prepared.
6 |
7 | * **Tree Decoder**: A Tree-Structured Decoder for Image-to-Markup Generation
8 |
9 | ## Citation
10 | If you find Tree Decoder useful in your research, please consider citing:
11 |
12 | @inproceedings{zhang2020treedecoder,
13 | title={A Tree-Structured Decoder for Image-to-Markup Generation},
14 | author={Zhang, Jianshu and Du, Jun and Yang, Yongxin and Song, Yi-Zhe and Wei, Si and Dai, Lirong},
15 | booktitle={ICML},
16 | pages={In Press},
17 | year={2020}
18 | }
19 |
--------------------------------------------------------------------------------
/codes/data_iterator.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import pickle as pkl
4 | import gzip
5 |
6 |
7 | def fopen(filename, mode='r'):
8 | if filename.endswith('.gz'):
9 | return gzip.open(filename, mode)
10 | return open(filename, mode)
11 |
12 | def dataIterator(feature_file,label_file,align_file,dictionary,redictionary,batch_size,batch_Imagesize,maxlen,maxImagesize):
13 |
14 | fp_feature=open(feature_file,'rb')
15 | features=pkl.load(fp_feature)
16 | fp_feature.close()
17 |
18 | fp_label=open(label_file,'rb')
19 | labels=pkl.load(fp_label)
20 | fp_label.close()
21 |
22 | fp_align=open(align_file,'rb')
23 | aligns=pkl.load(fp_align)
24 | fp_align.close()
25 |
26 | ltargets = {}
27 | rtargets = {}
28 | relations = {}
29 | lpositions = {}
30 | rpositions = {}
31 |
32 | # map word to int with dictionary
33 | for uid, label_lines in labels.items():
34 | lchar_list = []
35 | rchar_list = []
36 | relation_list = []
37 | lpos_list = []
38 | rpos_list = []
39 | for line_idx, line in enumerate(label_lines):
40 | parts = line.strip().split('\t')
41 | lchar = parts[0]
42 | lpos = parts[1]
43 | rchar = parts[2]
44 | rpos = parts[3]
45 | relation = parts[4]
46 |             if lchar in dictionary:
47 |                 lchar_list.append(dictionary[lchar])
48 |             else:
49 |                 print ('a symbol not in the dictionary !! formula',uid ,'symbol', lchar)
50 |                 sys.exit()
51 |             if rchar in dictionary:
52 |                 rchar_list.append(dictionary[rchar])
53 |             else:
54 |                 print ('a symbol not in the dictionary !! formula',uid ,'symbol', rchar)
55 |                 sys.exit()
56 |
57 | lpos_list.append(int(lpos))
58 | rpos_list.append(int(rpos))
59 |
60 | if line_idx != len(label_lines)-1:
61 |                 if relation in redictionary:
62 | relation_list.append(redictionary[relation])
63 | else:
64 | print ('a relation not in the redictionary !! formula',uid ,'relation', relation)
65 | sys.exit()
66 | else:
67 |                 relation_list.append(0) # any index works here; it stands in for the End relation
68 | ltargets[uid]=lchar_list
69 | rtargets[uid]=rchar_list
70 | relations[uid]=relation_list
71 | lpositions[uid] = lpos_list
72 | rpositions[uid] = rpos_list
73 |
74 | imageSize={}
75 | for uid,fea in features.items():
76 | if uid in ltargets:
77 | imageSize[uid]=fea.shape[1]*fea.shape[2]
78 | else:
79 | continue
80 |
81 |     imageSize = sorted(imageSize.items(), key=lambda d: d[1]) # sort by image size; returns a list of (uid, size) pairs
82 |
83 | feature_batch=[]
84 | llabel_batch=[]
85 | rlabel_batch=[]
86 | relabel_batch=[]
87 | align_batch=[]
88 | lpos_batch=[]
89 | rpos_batch=[]
90 |
91 | feature_total=[]
92 | llabel_total=[]
93 | rlabel_total=[]
94 | relabel_total=[]
95 | align_total=[]
96 | lpos_total=[]
97 | rpos_total=[]
98 |
99 | uidList=[]
100 |
101 | batch_image_size=0
102 | biggest_image_size=0
103 | i=0
104 | for uid,size in imageSize:
105 | if uid not in ltargets:
106 | continue
107 | if size>biggest_image_size:
108 | biggest_image_size=size
109 | fea=features[uid]
110 | llab=ltargets[uid]
111 | rlab=rtargets[uid]
112 | relab=relations[uid]
113 | ali=aligns[uid]
114 | lp=lpositions[uid]
115 | rp=rpositions[uid]
116 | batch_image_size=biggest_image_size*(i+1)
117 |         if len(llab)>maxlen:
118 |             print ('this sentence length is bigger than', maxlen, ', ignore it')
119 |         elif size>maxImagesize:
120 |             print ('this image size is bigger than', maxImagesize, ', ignore it')
121 | else:
122 | uidList.append(uid)
123 | if batch_image_size>batch_Imagesize or i==batch_size: # a batch is full
124 | feature_total.append(feature_batch)
125 | llabel_total.append(llabel_batch)
126 | rlabel_total.append(rlabel_batch)
127 | relabel_total.append(relabel_batch)
128 | align_total.append(align_batch)
129 | lpos_total.append(lpos_batch)
130 | rpos_total.append(rpos_batch)
131 |
132 | i=0
133 | biggest_image_size=size
134 | feature_batch=[]
135 | llabel_batch=[]
136 | rlabel_batch=[]
137 | relabel_batch=[]
138 | align_batch=[]
139 | lpos_batch=[]
140 | rpos_batch=[]
141 | feature_batch.append(fea)
142 | llabel_batch.append(llab)
143 | rlabel_batch.append(rlab)
144 | relabel_batch.append(relab)
145 | align_batch.append(ali)
146 | lpos_batch.append(lp)
147 | rpos_batch.append(rp)
148 | batch_image_size=biggest_image_size*(i+1)
149 | i+=1
150 | else:
151 | feature_batch.append(fea)
152 | llabel_batch.append(llab)
153 | rlabel_batch.append(rlab)
154 | relabel_batch.append(relab)
155 | align_batch.append(ali)
156 | lpos_batch.append(lp)
157 | rpos_batch.append(rp)
158 | i+=1
159 |
160 | # last batch
161 | feature_total.append(feature_batch)
162 | llabel_total.append(llabel_batch)
163 | rlabel_total.append(rlabel_batch)
164 | relabel_total.append(relabel_batch)
165 | align_total.append(align_batch)
166 | lpos_total.append(lpos_batch)
167 | rpos_total.append(rpos_batch)
168 |
169 |     print ('total', len(feature_total), 'batches of data loaded')
170 |
171 | return list(zip(feature_total,llabel_total,rlabel_total,relabel_total,align_total,lpos_total,rpos_total)),uidList
172 |
173 |
174 | def dataIterator_test(feature_file,dictionary,redictionary,batch_size,batch_Imagesize,maxImagesize):
175 |
176 | fp_feature=open(feature_file,'rb')
177 | features=pkl.load(fp_feature)
178 | fp_feature.close()
179 |
180 | imageSize={}
181 | for uid,fea in features.items():
182 | imageSize[uid]=fea.shape[1]*fea.shape[2]
183 |
184 |     imageSize = sorted(imageSize.items(), key=lambda d: d[1]) # sort by image size; returns a list of (uid, size) pairs
185 |
186 | feature_batch=[]
187 |
188 | feature_total=[]
189 |
190 | uidList=[]
191 |
192 | batch_image_size=0
193 | biggest_image_size=0
194 | i=0
195 | for uid,size in imageSize:
196 | if size>biggest_image_size:
197 | biggest_image_size=size
198 | fea=features[uid]
199 | batch_image_size=biggest_image_size*(i+1)
200 |         if size>maxImagesize:
201 |             print ('this image size is bigger than', maxImagesize, ', ignore it')
202 |         elif uid == '34_em_225':
203 |             print ('ignore this image:', uid)
204 | else:
205 | uidList.append(uid)
206 | if batch_image_size>batch_Imagesize or i==batch_size: # a batch is full
207 | feature_total.append(feature_batch)
208 |
209 | i=0
210 | biggest_image_size=size
211 | feature_batch=[]
212 | feature_batch.append(fea)
213 | batch_image_size=biggest_image_size*(i+1)
214 | i+=1
215 | else:
216 | feature_batch.append(fea)
217 | i+=1
218 |
219 | # last batch
220 | feature_total.append(feature_batch)
221 |
222 |     print ('total', len(feature_total), 'batches of data loaded')
223 |
224 | return feature_total, uidList
225 |
--------------------------------------------------------------------------------
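
A minimal usage sketch for dataIterator, mirroring the call in train_wap.py below. The pickle file names here are placeholders, not the released data; load_dict comes from utils.py (see train_wap.py's imports):

    from data_iterator import dataIterator
    from utils import load_dict

    worddicts = load_dict('106_dictionary.txt')          # symbol vocabulary
    reworddicts = load_dict('7relation_dictionary.txt')  # spatial-relation vocabulary

    # Each batch bundles image features with left/right symbol sequences,
    # relations, alignment matrices, and tree positions.
    train, train_uids = dataIterator(
        'train_images.pkl', 'train_label_gtd.pkl', 'train_label_align_gtd.pkl',
        worddicts, reworddicts,
        batch_size=8, batch_Imagesize=500000, maxlen=200, maxImagesize=500000)
    for x, ly, ry, re, ma, lp, rp in train:
        pass  # feed through prepare_data() and the model, as in train_wap.py
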
/codes/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # two layers of GRU
6 | class Gru_cond_layer(nn.Module):
7 | def __init__(self, params):
8 | super(Gru_cond_layer, self).__init__()
9 | self.fc_Wyz0 = nn.Linear(params['m'], params['n'])
10 | self.fc_Wyr0 = nn.Linear(params['m'], params['n'])
11 | self.fc_Wyh0 = nn.Linear(params['m'], params['n'])
12 | self.fc_Uhz0 = nn.Linear(params['n'], params['n'], bias=False)
13 | self.fc_Uhr0 = nn.Linear(params['n'], params['n'], bias=False)
14 | self.fc_Uhh0 = nn.Linear(params['n'], params['n'], bias=False)
15 |
16 | # attention for parent symbol
17 | self.conv_UaP = nn.Conv2d(params['D'], params['dim_attention'], kernel_size=1)
18 | self.fc_WaP = nn.Linear(params['n'], params['dim_attention'], bias=False)
19 | self.conv_QP = nn.Conv2d(1, 512, kernel_size=11, bias=False, padding=5)
20 | self.fc_UfP = nn.Linear(512, params['dim_attention'])
21 | self.fc_vaP = nn.Linear(params['dim_attention'], 1)
22 |
23 | # attention for memory
24 | # self.fc_Uamem = nn.Linear(params['m'], params['dim_attention'])
25 | # self.fc_Wamem = nn.Linear(params['D'], params['dim_attention'], bias=False)
26 | # # self.conv_Qmem = nn.Conv2d(1, 512, kernel_size=(3,1), bias=False, padding=(1,0))
27 | # # self.fc_Ufmem = nn.Linear(512, params['dim_attention'])
28 | # self.fc_vamem = nn.Linear(params['dim_attention'], 1)
29 |
30 | self.fc_Wyz1 = nn.Linear(params['D'], params['n'])
31 | self.fc_Wyr1 = nn.Linear(params['D'], params['n'])
32 | self.fc_Wyh1 = nn.Linear(params['D'], params['n'])
33 | self.fc_Uhz1 = nn.Linear(params['n'], params['n'], bias=False)
34 | self.fc_Uhr1 = nn.Linear(params['n'], params['n'], bias=False)
35 | self.fc_Uhh1 = nn.Linear(params['n'], params['n'], bias=False)
36 |
37 | # the first GRU layer
38 | self.fc_Wyz = nn.Linear(params['m'], params['n'])
39 | self.fc_Wyr = nn.Linear(params['m'], params['n'])
40 | self.fc_Wyh = nn.Linear(params['m'], params['n'])
41 |
42 | self.fc_Uhz = nn.Linear(params['n'], params['n'], bias=False)
43 | self.fc_Uhr = nn.Linear(params['n'], params['n'], bias=False)
44 | self.fc_Uhh = nn.Linear(params['n'], params['n'], bias=False)
45 |
46 | # attention for child symbol
47 | self.conv_UaC = nn.Conv2d(params['D'], params['dim_attention'], kernel_size=1)
48 | self.fc_WaC = nn.Linear(params['n'], params['dim_attention'], bias=False)
49 | self.conv_QC = nn.Conv2d(1, 512, kernel_size=11, bias=False, padding=5)
50 | self.fc_UfC = nn.Linear(512, params['dim_attention'])
51 | self.fc_vaC = nn.Linear(params['dim_attention'], 1)
52 |
53 | # the second GRU layer
54 | self.fc_Wcz = nn.Linear(params['D'], params['n'], bias=False)
55 | self.fc_Wcr = nn.Linear(params['D'], params['n'], bias=False)
56 | self.fc_Wch = nn.Linear(params['D'], params['n'], bias=False)
57 |
58 | self.fc_Uhz2 = nn.Linear(params['n'], params['n'])
59 | self.fc_Uhr2 = nn.Linear(params['n'], params['n'])
60 | self.fc_Uhh2 = nn.Linear(params['n'], params['n'])
61 |
62 | def forward(self, params, lembedding, rembedding, ly_mask=None,
63 | context=None, context_mask=None, init_state=None):
64 |
65 | n_steps = lembedding.shape[0] # seqs_y
66 | n_samples = lembedding.shape[1] # batch
67 |
68 | pctx_ = self.conv_UaC(context) # (batch,n',H,W)
69 | pctx_ = pctx_.permute(2, 3, 0, 1) # (H,W,batch,n')
70 | repctx_ = self.conv_UaP(context) # (batch,n',H,W)
71 | repctx_ = repctx_.permute(2, 3, 0, 1) # (H,W,batch,n')
72 | state_below_lz = self.fc_Wyz0(lembedding)
73 | state_below_lr = self.fc_Wyr0(lembedding)
74 | state_below_lh = self.fc_Wyh0(lembedding)
75 | state_below_z = self.fc_Wyz(rembedding)
76 | state_below_r = self.fc_Wyr(rembedding)
77 | state_below_h = self.fc_Wyh(rembedding)
78 |
79 | calpha_past = torch.zeros(n_samples, context.shape[2], context.shape[3]).cuda() # (batch,H,W)
80 | palpha_past = torch.zeros(n_samples, context.shape[2], context.shape[3]).cuda()
81 | h2t = init_state
82 | h2ts = torch.zeros(n_steps, n_samples, params['n']).cuda()
83 | h1ts = torch.zeros(n_steps, n_samples, params['n']).cuda()
84 | h01ts = torch.zeros(n_steps, n_samples, params['n']).cuda()
85 | ctCs = torch.zeros(n_steps, n_samples, params['D']).cuda()
86 | ctPs = torch.zeros(n_steps, n_samples, params['D']).cuda()
87 | cts = torch.zeros(n_steps, n_samples, 2*params['D']).cuda()
88 | calphas = (torch.zeros(n_steps, n_samples, context.shape[2], context.shape[3])).cuda()
89 | calpha_pasts = torch.zeros(n_steps, n_samples, context.shape[2], context.shape[3]).cuda()
90 | palphas = (torch.zeros(n_steps, n_samples, context.shape[2], context.shape[3])).cuda()
91 | palpha_pasts = torch.zeros(n_steps, n_samples, context.shape[2], context.shape[3]).cuda()
92 |
93 | for i in range(n_steps):
94 | h2t, h1t, h01t, ctC, ctP, ct, calpha, calpha_past, palpha, palpha_past = self._step_slice(ly_mask[i],
95 | context_mask, h2t, palpha_past, calpha_past,
96 | pctx_, repctx_, context, state_below_lz[i],
97 | state_below_lr[i], state_below_lh[i], state_below_z[i],
98 | state_below_r[i], state_below_h[i])
99 |
100 | h2ts[i] = h2t # (seqs_y,batch,n)
101 | h1ts[i] = h1t
102 | h01ts[i] = h01t
103 | ctCs[i] = ctC
104 | ctPs[i] = ctP
105 | cts[i] = ct # (seqs_y,batch,D)
106 | calphas[i] = calpha # (seqs_y,batch,H,W)
107 | calpha_pasts[i] = calpha_past # (seqs_y,batch,H,W)
108 | palphas[i] = palpha
109 | palpha_pasts[i] = palpha_past
110 | return h2ts, h1ts, h01ts, ctCs, ctPs, cts, calphas, calpha_pasts, palphas, palpha_pasts
111 |
112 | def parent_forward(self, params, lembedding, ly_mask=None, context=None, context_mask=None, init_state=None, palpha_past=None):
113 | state_below_lz = self.fc_Wyz0(lembedding)
114 | state_below_lr = self.fc_Wyr0(lembedding)
115 | state_below_lh = self.fc_Wyh0(lembedding)
116 | repctx_ = self.conv_UaP(context) # (batch,n',H,W)
117 | repctx_ = repctx_.permute(2, 3, 0, 1) # (H,W,batch,n')
118 | # if ly_mask is None:
119 | # ly_mask = torch.ones(embedding.shape[0]).cuda()
120 | if ly_mask is None:
121 | ly_mask = torch.ones(lembedding.shape[0]).cuda()
122 | h0s, ctPs, palphas, palpha_pasts = self._step_slice_parent(ly_mask, context_mask, init_state, palpha_past,
123 | repctx_, context, state_below_lz, state_below_lr, state_below_lh)
124 | return h0s, ctPs, palphas, palpha_pasts
125 |
126 | def child_forward(self, params, rembedding, ctxP, ly_mask=None,
127 | context=None, context_mask=None, init_state=None, calpha_past=None):
128 |
129 | pctx_ = self.conv_UaC(context) # (batch,n',H,W)
130 | pctx_ = pctx_.permute(2, 3, 0, 1) # (H,W,batch,n')
131 |
132 | state_below_z = self.fc_Wyz(rembedding)
133 | state_below_r = self.fc_Wyr(rembedding)
134 | state_below_h = self.fc_Wyh(rembedding)
135 |
136 | if ly_mask is None:
137 | ly_mask = torch.ones(rembedding.shape[0]).cuda()
138 | h2ts, h1ts, ctCs, ctPs, cts, calphas, calpha_pasts = \
139 | self._step_slice_child(ly_mask, context_mask, init_state, calpha_past,
140 | pctx_, context, state_below_z, state_below_r, state_below_h, ctxP)
141 | return h2ts, h1ts, ctCs, ctPs, cts, calphas, calpha_pasts
142 |
143 | # one step of two GRU layers
144 | def _step_slice(self, ly_mask, ctx_mask, h_, palpha_past_, calpha_past_,
145 | pctx_, repctx_, cc_, state_below_lz, state_below_lr, state_below_lh, state_below_z, state_below_r, state_below_h):
146 |
147 | z0 = torch.sigmoid(self.fc_Uhz0(h_) + state_below_lz) # (batch,n)
148 | r0 = torch.sigmoid(self.fc_Uhr0(h_) + state_below_lr) # (batch,n)
149 | h0_p = torch.tanh(self.fc_Uhh0(h_) * r0 + state_below_lh) # (batch,n)
150 | h0 = z0 * h_ + (1. - z0) * h0_p # (batch,n)
151 | h0 = ly_mask[:, None] * h0 + (1. - ly_mask)[:, None] * h_
152 |
153 | # attention for parent symbol
154 | query_parent = self.fc_WaP(h0)
155 | palpha_past__ = palpha_past_[:, None, :, :]
156 | cover_FP = self.conv_QP(palpha_past__).permute(2, 3, 0, 1) # (H,W,batch,n')
157 | pcover_vector = self.fc_UfP(cover_FP)
158 | pattention_score = torch.tanh(repctx_ + query_parent[None, None, :, :] + pcover_vector)
159 | palpha = self.fc_vaP(pattention_score)
160 | palpha = palpha - palpha.max()
161 | palpha = palpha.view(palpha.shape[0], palpha.shape[1], palpha.shape[2])
162 | palpha = torch.exp(palpha)
163 | if (ctx_mask is not None):
164 | palpha = palpha * ctx_mask.permute(1, 2, 0)
165 | palpha = palpha / (palpha.sum(1).sum(0)[None, None, :] + 1e-10) # (H,W,batch)
166 | palpha_past = palpha_past_ + palpha.permute(2, 0, 1) # (batch,H,W)
167 | ctP = (cc_ * palpha.permute(2, 0, 1)[:, None, :, :]).sum(3).sum(2)
168 |
169 | z01 = torch.sigmoid(self.fc_Uhz1(h0) + self.fc_Wyz1(ctP)) # zt (batch,n)
170 | r01 = torch.sigmoid(self.fc_Uhr1(h0) + self.fc_Wyr1(ctP)) # rt (batch,n)
171 | h01_p = torch.tanh(self.fc_Uhh1(h0) * r01 + self.fc_Wyh1(ctP)) # (batch,n)
172 | h01 = z01 * h0 + (1. - z01) * h01_p # (batch,n)
173 | h01 = ly_mask[:, None] * h01 + (1. - ly_mask)[:, None] * h0
174 | # the first GRU layer
175 | z1 = torch.sigmoid(self.fc_Uhz(h01) + state_below_z) # (batch,n)
176 | r1 = torch.sigmoid(self.fc_Uhr(h01) + state_below_r) # (batch,n)
177 | h1_p = torch.tanh(self.fc_Uhh(h01) * r1 + state_below_h) # (batch,n)
178 | h1 = z1 * h01 + (1. - z1) * h1_p # (batch,n)
179 | h1 = ly_mask[:, None] * h1 + (1. - ly_mask)[:, None] * h01
180 |
181 | # attention for child symbol
182 | query_child = self.fc_WaC(h1)
183 | calpha_past__ = calpha_past_[:, None, :, :] # (batch,1,H,W)
184 | cover_FC = self.conv_QC(calpha_past__).permute(2, 3, 0, 1) # (H,W,batch,n')
185 | ccover_vector = self.fc_UfC(cover_FC) # (H,W,batch,n')
186 | cattention_score = torch.tanh(pctx_ + query_child[None, None, :, :] + ccover_vector) # (H,W,batch,n')
187 | calpha = self.fc_vaC(cattention_score) # (H,W,batch,1)
188 | calpha = calpha - calpha.max()
189 | calpha = calpha.view(calpha.shape[0], calpha.shape[1], calpha.shape[2]) # (H,W,batch)
190 | calpha = torch.exp(calpha) # exp
191 | if (ctx_mask is not None):
192 | calpha = calpha * ctx_mask.permute(1, 2, 0)
193 |         calpha = calpha / (calpha.sum(1).sum(0)[None, None, :] + 1e-10)  # (H,W,batch); epsilon guards the denominator, matching palpha above
194 | calpha_past = calpha_past_ + calpha.permute(2, 0, 1) # (batch,H,W)
195 | ctC = (cc_ * calpha.permute(2, 0, 1)[:, None, :, :]).sum(3).sum(2) # current context, (batch,D)
196 |
197 | # the second GRU layer
198 | ct = torch.cat((ctC, ctP), 1)
199 | z2 = torch.sigmoid(self.fc_Uhz2(h1) + self.fc_Wcz(ctC)) # zt (batch,n)
200 | r2 = torch.sigmoid(self.fc_Uhr2(h1) + self.fc_Wcr(ctC)) # rt (batch,n)
201 | h2_p = torch.tanh(self.fc_Uhh2(h1) * r2 + self.fc_Wch(ctC)) # (batch,n)
202 | h2 = z2 * h1 + (1. - z2) * h2_p # (batch,n)
203 | h2 = ly_mask[:, None] * h2 + (1. - ly_mask)[:, None] * h1
204 |
205 | return h2, h1, h01, ctC, ctP, ct, calpha.permute(2, 0, 1), calpha_past, palpha.permute(2, 0, 1), palpha_past
206 |
207 | def _step_slice_parent(self, ly_mask, ctx_mask, h_, palpha_past_, repctx_, cc_, state_below_lz, state_below_lr, state_below_lh):
208 | z0 = torch.sigmoid(self.fc_Uhz0(h_) + state_below_lz) # (batch,n)
209 | r0 = torch.sigmoid(self.fc_Uhr0(h_) + state_below_lr) # (batch,n)
210 | h0_p = torch.tanh(self.fc_Uhh0(h_) * r0 + state_below_lh) # (batch,n)
211 | h0 = z0 * h_ + (1. - z0) * h0_p # (batch,n)
212 | h0 = ly_mask[:, None] * h0 + (1. - ly_mask)[:, None] * h_
213 | # attention for parent symbol
214 | query_parent = self.fc_WaP(h0)
215 | palpha_past__ = palpha_past_[:, None, :, :]
216 | cover_FP = self.conv_QP(palpha_past__).permute(2, 3, 0, 1) # (H,W,batch,n')
217 | pcover_vector = self.fc_UfP(cover_FP)
218 | pattention_score = torch.tanh(repctx_ + query_parent[None, None, :, :] + pcover_vector)
219 | palpha = self.fc_vaP(pattention_score)
220 | palpha = palpha - palpha.max()
221 | palpha = palpha.view(palpha.shape[0], palpha.shape[1], palpha.shape[2])
222 | palpha = torch.exp(palpha)
223 | if (ctx_mask is not None):
224 | palpha = palpha * ctx_mask.permute(1, 2, 0)
225 | palpha = palpha / (palpha.sum(1).sum(0)[None, None, :] + 1e-10) # (H,W,batch)
226 | palpha_past = palpha_past_ + palpha.permute(2, 0, 1) # (batch,H,W)
227 | ctP = (cc_ * palpha.permute(2, 0, 1)[:, None, :, :]).sum(3).sum(2)
228 |
229 | z01 = torch.sigmoid(self.fc_Uhz1(h0) + self.fc_Wyz1(ctP)) # zt (batch,n)
230 | r01 = torch.sigmoid(self.fc_Uhr1(h0) + self.fc_Wyr1(ctP)) # rt (batch,n)
231 | h01_p = torch.tanh(self.fc_Uhh1(h0) * r01 + self.fc_Wyh1(ctP)) # (batch,n)
232 | h01 = z01 * h0 + (1. - z01) * h01_p # (batch,n)
233 | h01 = ly_mask[:, None] * h01 + (1. - ly_mask)[:, None] * h0
234 |
235 | return h01, ctP, palpha.permute(2, 0, 1), palpha_past
236 |
237 | def _step_slice_child(self, ly_mask, ctx_mask, h_, calpha_past_,
238 | pctx_, cc_, state_below_z, state_below_r, state_below_h, ctP):
239 |
240 | # the first GRU layer
241 | z1 = torch.sigmoid(self.fc_Uhz(h_) + state_below_z) # (batch,n)
242 | r1 = torch.sigmoid(self.fc_Uhr(h_) + state_below_r) # (batch,n)
243 | h1_p = torch.tanh(self.fc_Uhh(h_) * r1 + state_below_h) # (batch,n)
244 | h1 = z1 * h_ + (1. - z1) * h1_p # (batch,n)
245 | h1 = ly_mask[:, None] * h1 + (1. - ly_mask)[:, None] * h_
246 |
247 | # attention for child symbol
248 | query_child = self.fc_WaC(h1)
249 | calpha_past__ = calpha_past_[:, None, :, :] # (batch,1,H,W)
250 | cover_FC = self.conv_QC(calpha_past__).permute(2, 3, 0, 1) # (H,W,batch,n')
251 | ccover_vector = self.fc_UfC(cover_FC) # (H,W,batch,n')
252 | cattention_score = torch.tanh(pctx_ + query_child[None, None, :, :] + ccover_vector) # (H,W,batch,n')
253 | calpha = self.fc_vaC(cattention_score) # (H,W,batch,1)
254 | calpha = calpha - calpha.max()
255 | calpha = calpha.view(calpha.shape[0], calpha.shape[1], calpha.shape[2]) # (H,W,batch)
256 | calpha = torch.exp(calpha) # exp
257 | if (ctx_mask is not None):
258 | calpha = calpha * ctx_mask.permute(1, 2, 0)
259 |         calpha = calpha / (calpha.sum(1).sum(0)[None, None, :] + 1e-10)  # (H,W,batch); epsilon guards the denominator, matching palpha above
260 | calpha_past = calpha_past_ + calpha.permute(2, 0, 1) # (batch,H,W)
261 | ctC = (cc_ * calpha.permute(2, 0, 1)[:, None, :, :]).sum(3).sum(2) # current context, (batch,D)
262 |
263 | # the second GRU layer
264 | ct = torch.cat((ctC, ctP), 1)
265 | z2 = torch.sigmoid(self.fc_Uhz2(h1) + self.fc_Wcz(ctC)) # zt (batch,n)
266 | r2 = torch.sigmoid(self.fc_Uhr2(h1) + self.fc_Wcr(ctC)) # rt (batch,n)
267 | h2_p = torch.tanh(self.fc_Uhh2(h1) * r2 + self.fc_Wch(ctC)) # (batch,n)
268 | h2 = z2 * h1 + (1. - z2) * h2_p # (batch,n)
269 | h2 = ly_mask[:, None] * h2 + (1. - ly_mask)[:, None] * h1
270 |
271 | return h2, h1, ctC, ctP, ct, calpha.permute(2, 0, 1), calpha_past
272 |
273 | # calculate probabilities
274 | class Gru_prob(nn.Module):
275 | def __init__(self, params):
276 | super(Gru_prob, self).__init__()
277 | self.fc_WctC = nn.Linear(params['D'], params['m'])
278 | self.fc_WhtC = nn.Linear(params['n'], params['m'])
279 | self.fc_WytC = nn.Linear(params['m'], params['m'])
280 | self.dropout = nn.Dropout(p=0.2)
281 | self.fc_W0C = nn.Linear(int(params['m'] / 2), params['K'])
282 | # self.fc_WctP = nn.Linear(params['D'], params['m'])
283 | self.fc_W0P = nn.Linear(int(params['m'] / 2), params['K'])
284 | self.fc_WctRe = nn.Linear(2*params['D'], params['mre'])
285 | self.fc_W0Re = nn.Linear(int(params['mre']), params['Kre'])
286 |
287 | def forward(self, ctCs, ctPs, cts, htCs, prevC, use_dropout):
288 | clogit = self.fc_WctC(ctCs) + self.fc_WhtC(htCs) + self.fc_WytC(prevC) # (seqs_y,batch,m)
289 | # maxout
290 | cshape = clogit.shape # (seqs_y,batch,m)
291 | cshape2 = int(cshape[2] / 2) # m/2
292 | cshape3 = 2
293 | clogit = clogit.view(cshape[0], cshape[1], cshape2, cshape3) # (seqs_y,batch,m) -> (seqs_y,batch,m/2,2)
294 | clogit = clogit.max(3)[0] # (seqs_y,batch,m/2)
295 | if use_dropout:
296 | clogit = self.dropout(clogit)
297 | cprob = self.fc_W0C(clogit) # (seqs_y,batch,K)
298 |
299 | plogit = self.fc_WctC(ctPs)
300 | # maxout
301 | pshape = plogit.shape # (seqs_y,batch,m)
302 | pshape2 = int(pshape[2] / 2) # m/2
303 | pshape3 = 2
304 | plogit = plogit.view(pshape[0], pshape[1], pshape2, pshape3) # (seqs_y,batch,m) -> (seqs_y,batch,m/2,2)
305 | plogit = plogit.max(3)[0] # (seqs_y,batch,m/2)
306 | if use_dropout:
307 | plogit = self.dropout(plogit)
308 | pprob = self.fc_W0P(plogit) # (seqs_y,batch,K)
309 |
310 | relogit = self.fc_WctRe(cts)
311 | if use_dropout:
312 | relogit = self.dropout(relogit)
313 | reprob = self.fc_W0Re(relogit)
314 |
315 | return cprob, pprob, reprob
316 |
--------------------------------------------------------------------------------
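
For readers unfamiliar with coverage attention, here is a self-contained sketch of one attention step that mirrors _step_slice_parent above: the accumulated past attention (alpha_past) is convolved into a coverage feature that discourages re-attending to already-covered regions. All sizes below are made up for illustration, masking is omitted, and a plain softmax stands in for the exp/normalize code above:

    import torch
    import torch.nn as nn

    B, D, H, W, n, natt = 2, 684, 8, 16, 256, 512
    ctx = torch.randn(B, D, H, W)      # encoder feature map
    h = torch.randn(B, n)              # decoder hidden state
    alpha_past = torch.zeros(B, H, W)  # accumulated past attention

    conv_Ua = nn.Conv2d(D, natt, kernel_size=1)
    fc_Wa = nn.Linear(n, natt, bias=False)
    conv_Q = nn.Conv2d(1, 512, kernel_size=11, bias=False, padding=5)
    fc_Uf = nn.Linear(512, natt)
    fc_va = nn.Linear(natt, 1)

    pctx = conv_Ua(ctx).permute(2, 3, 0, 1)                         # (H,W,B,natt)
    cover = fc_Uf(conv_Q(alpha_past[:, None]).permute(2, 3, 0, 1))  # coverage feature
    score = fc_va(torch.tanh(pctx + fc_Wa(h)[None, None] + cover))  # (H,W,B,1)
    alpha = torch.softmax(score.view(H * W, B), dim=0).view(H, W, B)
    alpha_past = alpha_past + alpha.permute(2, 0, 1)                # update coverage
    context = (ctx * alpha.permute(2, 0, 1)[:, None]).sum((2, 3))   # (B,D)
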
/codes/encoder.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | # DenseNet-B
8 | class Bottleneck(nn.Module):
9 | def __init__(self, nChannels, growthRate, use_dropout):
10 | super(Bottleneck, self).__init__()
11 | interChannels = 4 * growthRate
12 | self.bn1 = nn.BatchNorm2d(interChannels)
13 | self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False)
14 | self.bn2 = nn.BatchNorm2d(growthRate)
15 | self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False)
16 | self.use_dropout = use_dropout
17 | self.dropout = nn.Dropout(p=0.2)
18 |
19 | def forward(self, x):
20 | out = F.relu(self.bn1(self.conv1(x)), inplace=True)
21 | if self.use_dropout:
22 | out = self.dropout(out)
23 | out = F.relu(self.bn2(self.conv2(out)), inplace=True)
24 | if self.use_dropout:
25 | out = self.dropout(out)
26 | out = torch.cat((x, out), 1)
27 | return out
28 |
29 |
30 | # single layer
31 | class SingleLayer(nn.Module):
32 | def __init__(self, nChannels, growthRate, use_dropout):
33 | super(SingleLayer, self).__init__()
34 | self.bn1 = nn.BatchNorm2d(nChannels)
35 | self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False)
36 | self.use_dropout = use_dropout
37 | self.dropout = nn.Dropout(p=0.2)
38 |
39 | def forward(self, x):
40 | out = self.conv1(F.relu(x, inplace=True))
41 | if self.use_dropout:
42 | out = self.dropout(out)
43 | out = torch.cat((x, out), 1)
44 | return out
45 |
46 |
47 | # transition layer
48 | class Transition(nn.Module):
49 | def __init__(self, nChannels, nOutChannels, use_dropout):
50 | super(Transition, self).__init__()
51 | self.bn1 = nn.BatchNorm2d(nOutChannels)
52 | self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False)
53 | self.use_dropout = use_dropout
54 | self.dropout = nn.Dropout(p=0.2)
55 |
56 | def forward(self, x):
57 | out = F.relu(self.bn1(self.conv1(x)), inplace=True)
58 | if self.use_dropout:
59 | out = self.dropout(out)
60 | out = F.avg_pool2d(out, 2, ceil_mode=True)
61 | return out
62 |
63 |
64 | class DenseNet(nn.Module):
65 | def __init__(self, growthRate, reduction, bottleneck, use_dropout):
66 | super(DenseNet, self).__init__()
67 | nDenseBlocks = 16
68 | nChannels = 2 * growthRate
69 | self.conv1 = nn.Conv2d(1, nChannels, kernel_size=7, padding=3, stride=2, bias=False)
70 | self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout)
71 | nChannels += nDenseBlocks * growthRate
72 | nOutChannels = int(math.floor(nChannels * reduction))
73 | self.trans1 = Transition(nChannels, nOutChannels, use_dropout)
74 |
75 | nChannels = nOutChannels
76 | self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout)
77 | nChannels += nDenseBlocks * growthRate
78 | nOutChannels = int(math.floor(nChannels * reduction))
79 | self.trans2 = Transition(nChannels, nOutChannels, use_dropout)
80 |
81 | nChannels = nOutChannels
82 | self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout)
83 |
84 | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout):
85 | layers = []
86 | for i in range(int(nDenseBlocks)):
87 | if bottleneck:
88 | layers.append(Bottleneck(nChannels, growthRate, use_dropout))
89 | else:
90 | layers.append(SingleLayer(nChannels, growthRate, use_dropout))
91 | nChannels += growthRate
92 | return nn.Sequential(*layers)
93 |
94 | def forward(self, x, x_mask):
95 | out = self.conv1(x)
96 | out_mask = x_mask[:, 0::2, 0::2]
97 | out = F.relu(out, inplace=True)
98 | out = F.max_pool2d(out, 2, ceil_mode=True)
99 | out_mask = out_mask[:, 0::2, 0::2]
100 | out = self.dense1(out)
101 | out = self.trans1(out)
102 | out_mask = out_mask[:, 0::2, 0::2]
103 | out = self.dense2(out)
104 | out = self.trans2(out)
105 | out_mask = out_mask[:, 0::2, 0::2]
106 | out = self.dense3(out)
107 | return out, out_mask
108 |
--------------------------------------------------------------------------------
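
A quick shape check for the encoder (hyper-parameters copied from train_wap.py; the input is random, just to show that both the feature map and the mask are downsampled 16x and that the output channel count is 684, i.e. params['D']):

    import torch
    from encoder import DenseNet

    enc = DenseNet(growthRate=24, reduction=0.5, bottleneck=True, use_dropout=True)
    x = torch.randn(1, 1, 128, 384)   # (batch, channel, H, W) grayscale image
    x_mask = torch.ones(1, 128, 384)  # 1 = valid pixel, 0 = padding
    ctx, ctx_mask = enc(x, x_mask)
    print(ctx.shape, ctx_mask.shape)  # (1, 684, 8, 24) and (1, 8, 24)
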
/codes/encoder_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from encoder import DenseNet
5 | from decoder import Gru_cond_layer, Gru_prob
6 | import math
7 |
8 | # create gru init state
9 | class FcLayer(nn.Module):
10 | def __init__(self, nin, nout):
11 | super(FcLayer, self).__init__()
12 | self.fc = nn.Linear(nin, nout)
13 |
14 | def forward(self, x):
15 | out = torch.tanh(self.fc(x))
16 | return out
17 |
18 |
19 | # Embedding
20 | class My_Embedding(nn.Module):
21 | def __init__(self, params):
22 | super(My_Embedding, self).__init__()
23 | self.embedding = nn.Embedding(params['K'], params['m'])
24 | self.pos_embedding = torch.zeros(params['maxlen'], params['m']).cuda()
25 | nin = params['maxlen']
26 | nout = params['m']
27 | d_model = nout
28 | for pos in range(nin):
29 | for i in range(nout//2):
30 | self.pos_embedding[pos, 2*i] = math.sin(1.*pos/(10000**(2.*i/d_model)))
31 | self.pos_embedding[pos, 2*i+1] = math.cos(1.*pos/(10000**(2.*i/d_model)))
32 | def forward(self, params, ly, lp, ry):
33 |         if ly.sum() < 0.: # all -1: no previous symbol yet, use a zero embedding
34 | lemb = torch.zeros(1, params['m']).cuda() # (1,m)
35 | else:
36 | lemb = self.embedding(ly) # (seqs_y,batch,m) | (batch,m)
37 | if len(lemb.shape) == 3: # only for training stage
38 | lemb_shifted = torch.zeros([lemb.shape[0], lemb.shape[1], params['m']], dtype=torch.float32).cuda()
39 | lemb_shifted[1:] = lemb[:-1]
40 | lemb = lemb_shifted
41 |
42 | if lp.sum() < 1.: # pos=0
43 | Pemb = torch.zeros(1, params['m']).cuda() # (1,m)
44 | else:
45 | Pemb = self.pos_embedding[lp] # (seqs_y,batch,m) | (batch,m)
46 | if len(Pemb.shape) == 3: # only for training stage
47 | Pemb_shifted = torch.zeros([Pemb.shape[0], Pemb.shape[1], params['m']], dtype=torch.float32).cuda()
48 | Pemb_shifted[1:] = Pemb[:-1]
49 | Pemb = Pemb_shifted
50 |
51 |         if ry.sum() < 0.: # all -1: no previous symbol yet, use a zero embedding
52 | remb = torch.zeros(1, params['m']).cuda() # (1,m)
53 | else:
54 | remb = self.embedding(ry) # (seqs_y,batch,m) | (batch,m)
55 | if len(remb.shape) == 3: # only for training stage
56 | remb_shifted = torch.zeros([remb.shape[0], remb.shape[1], params['m']], dtype=torch.float32).cuda()
57 | remb_shifted[1:] = remb[1:]
58 | remb = remb_shifted
59 | return lemb, Pemb, remb
60 | def word_emb(self, params, y):
61 |         if y.sum() < 0.: # all -1: no previous symbol yet, use a zero embedding
62 | emb = torch.zeros(1, params['m']).cuda() # (1,m)
63 | else:
64 | emb = self.embedding(y) # (seqs_y,batch,m) | (batch,m)
65 | return emb
66 | def pos_emb(self, params, p):
67 |         if p.sum() < 1.: # pos=0: use a zero embedding
68 | Pemb = torch.zeros(1, params['m']).cuda() # (1,m)
69 | else:
70 | Pemb = self.pos_embedding[p] # (seqs_y,batch,m) | (batch,m)
71 | return Pemb
72 |
73 | class Encoder_Decoder(nn.Module):
74 | def __init__(self, params):
75 | super(Encoder_Decoder, self).__init__()
76 | self.encoder = DenseNet(growthRate=params['growthRate'], reduction=params['reduction'],
77 | bottleneck=params['bottleneck'], use_dropout=params['use_dropout'])
78 | self.init_GRU_model = FcLayer(params['D'], params['n'])
79 | self.emb_model = My_Embedding(params)
80 | self.gru_model = Gru_cond_layer(params)
81 | self.gru_prob_model = Gru_prob(params)
82 | self.fc_Uamem = nn.Linear(params['n'], params['dim_attention'])
83 | self.fc_Wamem = nn.Linear(params['n'], params['dim_attention'], bias=False)
84 | # self.conv_Qmem = nn.Conv2d(1, 512, kernel_size=(3,1), bias=False, padding=(1,0))
85 | # self.fc_Ufmem = nn.Linear(512, params['dim_attention'])
86 | self.fc_vamem = nn.Linear(params['dim_attention'], 1)
87 |         self.criterion = torch.nn.CrossEntropyLoss(reduction='none')  # per-element loss; masked and averaged manually in forward
88 |
89 | def forward(self, params, x, x_mask, ly, ly_mask, ry, ry_mask, re, re_mask, ma, ma_mask, lp, rp, one_step=False):
90 | # recover permute
91 | # ly = ly.permute(1, 0)
92 | # ly_mask = ly_mask.permute(1, 0)
93 | # ly = ly.permute(1, 0)
94 | # ly_mask = ly_mask.permute(1, 0)
95 |
96 | ma = ma.permute(2, 1, 0) # SeqY * Matt * batch
97 | ma_mask = ma_mask.permute(2, 1, 0)
98 |
99 | ctx, ctx_mask = self.encoder(x, x_mask)
100 |
101 | # init state
102 | ctx_mean = (ctx * ctx_mask[:, None, :, :]).sum(3).sum(2) / ctx_mask.sum(2).sum(1)[:, None] # (batch,D)
103 | init_state = self.init_GRU_model(ctx_mean) # (batch,n)
104 |
105 | # two GRU layers
106 | lemb, Pemb, remb = self.emb_model(params, ly, lp, ry) # (seqs_y,batch,m)
107 | # h2ts: (seqs_y,batch,n), cts: (seqs_y,batch,D), alphas: (seqs_y,batch,H,W)
108 | h2ts, h1ts, h01ts, ctCs, ctPs, cts, calphas, calpha_pasts, palphas, palpha_pasts= \
109 | self.gru_model(params, lemb, remb, ly_mask, ctx, ctx_mask, init_state=init_state)
110 |
111 | word_pos_memory = torch.cat((init_state[None, :, :], h2ts[:-1]), 0)
112 | # word_pos_memory = lemb + Pemb
113 | mempctx_ = self.fc_Uamem(word_pos_memory)
114 | memquery = self.fc_Wamem(h01ts)
115 | memattention_score = torch.tanh(mempctx_[None, :, :, :] + memquery[:, None, :, :])
116 | memalpha = self.fc_vamem(memattention_score)
117 | memalpha = memalpha - memalpha.max()
118 | memalpha = memalpha.view(memalpha.shape[0], memalpha.shape[1], memalpha.shape[2]) # SeqY * Matt * batch
119 | memalpha = torch.exp(memalpha)
120 | memalpha = memalpha * ma_mask # SeqY * Matt * batch
121 | memalpha = memalpha / (memalpha.sum(1)[:, None, :] + 1e-10)
122 | memalphas = memalpha + 1e-10
123 | cost_memalphas = - ma.float() * torch.log(memalphas) * ma_mask
124 | loss_memalphas = cost_memalphas.sum((0,1))
125 |
126 | # compute KL alpha
127 | calpha_sort_ = torch.cat((torch.zeros(1, calphas.shape[1], calphas.shape[2], calphas.shape[3]).cuda(), calphas), 0)
128 | n_gaps = calpha_sort_.shape[0]
129 | n_batch = calpha_sort_.shape[1]
130 | n_H = calpha_sort_.shape[2]
131 | n_W = calpha_sort_.shape[3]
132 | rp = rp.permute(1,0) # batch * SeqY
133 | rp_shape = rp.shape
134 | rp = rp + n_gaps * torch.arange(n_batch)[:, None].cuda()
135 |
136 | calpha_sort = calpha_sort_.permute(1,0,2,3)
137 | calpha_sort = torch.reshape(calpha_sort, (calpha_sort.shape[0]*calpha_sort.shape[1],calpha_sort.shape[2],calpha_sort.shape[3]))
138 | calpha_sort = calpha_sort[rp.flatten()]
139 | calpha_sort = torch.reshape(calpha_sort, (rp_shape[0], rp_shape[1], n_H, n_W))
140 | calpha_sort = calpha_sort.permute(1,0,2,3)
141 |
142 | calpha_sort = calpha_sort + 1e-10
143 | palphas = palphas + 1e-10
144 | cost_KL_alpha = calpha_sort * (torch.log(calpha_sort)-torch.log(palphas)) * ctx_mask[None, :, :, :]
145 | loss_KL = cost_KL_alpha.sum((0,2,3))
146 |
147 | cscores, pscores, rescores = self.gru_prob_model(ctCs, ctPs, cts, h2ts, remb, use_dropout=params['use_dropout']) # (seqs_y x batch,K)
148 |
149 | cscores = cscores.contiguous()
150 | cscores = cscores.view(-1, cscores.shape[2])
151 | ly = ly.contiguous()
152 | lpred_loss = self.criterion(cscores, ly.view(-1)) # (seqs_y x batch,)
153 | lpred_loss = lpred_loss.view(ly.shape[0], ly.shape[1]) # (seqs_y,batch)
154 | lpred_loss = (lpred_loss * ly_mask).sum(0) / (ly_mask.sum(0)+1e-10)
155 | lpred_loss = lpred_loss.mean()
156 |
157 | pscores = pscores.contiguous()
158 | pscores = pscores.view(-1, pscores.shape[2])
159 | ry = ry.contiguous()
160 | rpred_loss = self.criterion(pscores, ry.view(-1))
161 | rpred_loss = rpred_loss.view(ry.shape[0], ry.shape[1])
162 | rpred_loss = (rpred_loss * ry_mask).sum(0) / (ry_mask.sum(0)+1e-10)
163 | rpred_loss = rpred_loss.mean()
164 |
165 | rescores = rescores.contiguous()
166 | rescores = rescores.view(-1, rescores.shape[2])
167 | re = re.contiguous()
168 | repred_loss = self.criterion(rescores, re.view(-1))
169 | repred_loss = repred_loss.view(re.shape[0], re.shape[1])
170 | repred_loss = (repred_loss * re_mask).sum(0) / (re_mask.sum(0)+1e-10)
171 | repred_loss = repred_loss.mean()
172 |
173 | mem_loss = loss_memalphas / (ly_mask.sum(0)+1e-10)
174 | mem_loss = mem_loss.mean()
175 |
176 | KL_loss = loss_KL / (ly_mask.sum(0)+1e-10)
177 | KL_loss = KL_loss.mean()
178 |
179 | loss = params['ly_lambda']*lpred_loss + params['ry_lambda']*rpred_loss + \
180 | params['re_lambda']*repred_loss + params['rpos_lambda']*mem_loss + params['KL_lambda']*KL_loss
181 |
182 | return loss, lpred_loss, rpred_loss, repred_loss, mem_loss, KL_loss
183 |
184 | # decoding: encoder part
185 | def f_init(self, x, x_mask=None):
186 | if x_mask is None: # x_mask is actually no use here
187 | shape = x.shape
188 | x_mask = torch.ones(shape).cuda()
189 | ctx, _ctx_mask = self.encoder(x, x_mask)
190 | ctx_mean = ctx.mean(dim=3).mean(dim=2)
191 | init_state = self.init_GRU_model(ctx_mean) # (1,n)
192 | return init_state, ctx
193 |
194 | def f_next_parent(self, params, ly, lp, ctx, init_state, h1t, palpha_past, nextemb_memory, nextePmb_memory, initIdx):
195 | emb = self.emb_model.word_emb(params, ly)
196 | # Pemb = self.emb_model.pos_emb(params, lp)
197 | nextemb_memory[initIdx, :, :] = emb
198 | # ePmb_memory_ = emb + Pemb
199 | nextePmb_memory[initIdx, :, :] = init_state
200 |
201 | h01, ctP, palpha, next_palpha_past = self.gru_model.parent_forward(params, emb, context=ctx, init_state=init_state, palpha_past=palpha_past)
202 |
203 | mempctx_ = self.fc_Uamem(nextePmb_memory)
204 | memquery = self.fc_Wamem(h01)
205 | memattention_score = torch.tanh(mempctx_ + memquery[None, :, :])
206 | memalpha = self.fc_vamem(memattention_score)
207 | memalpha = memalpha - memalpha.max()
208 | memalpha = memalpha.view(memalpha.shape[0], memalpha.shape[1]) # Matt * batch
209 | memalpha = torch.exp(memalpha)
210 | mem_mask = torch.zeros(nextePmb_memory.shape[0], nextePmb_memory.shape[1]).cuda()
211 | mem_mask[:(initIdx+1), :] = 1
212 | memalpha = memalpha * mem_mask # Matt * batch
213 | memalpha = memalpha / (memalpha.sum(0) + 1e-10)
214 |
215 | Pmemalpha = memalpha.view(-1, memalpha.shape[1])
216 | Pmemalpha = Pmemalpha.permute(1, 0) # batch * Matt
217 | return h01, Pmemalpha, ctP, palpha, next_palpha_past, nextemb_memory, nextePmb_memory
218 |
219 | # decoding: decoder part
220 | def f_next_child(self, params, remb, ctP, ctx, init_state, calpha_past):
221 |
222 | next_state, h1t, ctC, ctP, ct, calpha, next_calpha_past = \
223 | self.gru_model.child_forward(params, remb, ctP, context=ctx, init_state=init_state, calpha_past=calpha_past)
224 |
225 | # reshape to suit GRU step code
226 | h2te = next_state.view(1, next_state.shape[0], next_state.shape[1])
227 | ctC = ctC.view(1, ctC.shape[0], ctC.shape[1])
228 | ctP = ctP.view(1, ctP.shape[0], ctP.shape[1])
229 | ct = ct.view(1, ct.shape[0], ct.shape[1])
230 |
231 | # calculate probabilities
232 | cscores, pscores, rescores = self.gru_prob_model(ctC, ctP, ct, h2te, remb, use_dropout=params['use_dropout'])
233 | cscores = cscores.view(-1, cscores.shape[2])
234 | next_lprobs = F.softmax(cscores, dim=1)
235 | rescores = rescores.view(-1, rescores.shape[2])
236 | next_reprobs = F.softmax(rescores, dim=1)
237 | next_re = torch.argmax(next_reprobs, dim=1)
238 |
239 | return next_lprobs, next_reprobs, next_state, h1t, calpha, next_calpha_past, next_re
240 |
--------------------------------------------------------------------------------
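
The sinusoidal position table built element-by-element in My_Embedding.__init__ can be sanity-checked against this vectorized equivalent (kept on CPU here, just for illustration): pe[pos, 2i] = sin(pos / 10000^(2i/m)) and pe[pos, 2i+1] = cos(pos / 10000^(2i/m)), with m = params['m']:

    import torch

    maxlen, m = 200, 256
    pos = torch.arange(maxlen, dtype=torch.float32)[:, None]     # (maxlen, 1)
    i = torch.arange(m // 2, dtype=torch.float32)[None, :]       # (1, m/2)
    angle = pos / torch.pow(torch.tensor(10000.0), 2.0 * i / m)  # (maxlen, m/2)
    pe = torch.zeros(maxlen, m)
    pe[:, 0::2] = torch.sin(angle)  # even dims: sine
    pe[:, 1::2] = torch.cos(angle)  # odd dims: cosine
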
/codes/gtd2latex.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import sys
5 | import pickle as pkl
6 | import numpy
7 |
8 |
9 | def convert(nodeid, gtd_list):
10 | isparent = False
11 | child_list = []
12 | for i in range(len(gtd_list)):
13 | if gtd_list[i][2] == nodeid:
14 | isparent = True
15 | child_list.append([gtd_list[i][0],gtd_list[i][1],gtd_list[i][3]])
16 | if not isparent:
17 | return [gtd_list[nodeid][0]]
18 | else:
19 | if gtd_list[nodeid][0] == '\\frac':
20 | return_string = [gtd_list[nodeid][0]]
21 | for i in range(len(child_list)):
22 | if child_list[i][2] == 'Above':
23 | return_string += ['{'] + convert(child_list[i][1], gtd_list) + ['}']
24 | for i in range(len(child_list)):
25 | if child_list[i][2] == 'Below':
26 | return_string += ['{'] + convert(child_list[i][1], gtd_list) + ['}']
27 | for i in range(len(child_list)):
28 | if child_list[i][2] == 'Right':
29 | return_string += convert(child_list[i][1], gtd_list)
30 | for i in range(len(child_list)):
31 | if child_list[i][2] not in ['Right','Above','Below']:
32 | return_string += ['illegal']
33 | else:
34 | return_string = [gtd_list[nodeid][0]]
35 | for i in range(len(child_list)):
36 | if child_list[i][2] == 'Inside':
37 | return_string += ['{'] + convert(child_list[i][1], gtd_list) + ['}']
38 | for i in range(len(child_list)):
39 | if child_list[i][2] in ['Sub','Below']:
40 | return_string += ['_','{'] + convert(child_list[i][1], gtd_list) + ['}']
41 | for i in range(len(child_list)):
42 | if child_list[i][2] in ['Sup','Above']:
43 | return_string += ['^','{'] + convert(child_list[i][1], gtd_list) + ['}']
44 | for i in range(len(child_list)):
45 | if child_list[i][2] in ['Right']:
46 | return_string += convert(child_list[i][1], gtd_list)
47 | return return_string
48 |
49 | latex_root_path = '***'
50 | gtd_root_path = '***'
51 | gtd_paths = ['test_caption_14','test_caption_16','test_caption_19']
52 |
53 | for gtd_path in gtd_paths:
54 | gtd_files = os.listdir(gtd_root_path + gtd_path + '/')
55 | f_out = open(latex_root_path + gtd_path + '.txt', 'w')
56 | for process_num, gtd_file in enumerate(gtd_files):
57 | # gtd_file = '510_em_101.gtd'
58 | key = gtd_file[:-4] # remove .gtd
59 | f_out.write(key + '\t')
60 | gtd_list = []
61 | gtd_list.append(['',0,-1,'root'])
62 | with open(gtd_root_path + gtd_path + '/' + gtd_file) as f:
63 | lines = f.readlines()
64 | for line in lines[:-1]:
65 | parts = line.split()
66 | sym = parts[0]
67 | childid = int(parts[1])
68 | parentid = int(parts[3])
69 | relation = parts[4]
70 | gtd_list.append([sym,childid,parentid,relation])
71 | latex_list = convert(1, gtd_list)
72 | if 'illegal' in latex_list:
73 | print (key + ' has error')
74 | latex_string = ' '
75 | else:
76 | latex_string = ' '.join(latex_list)
77 | f_out.write(latex_string + '\n')
78 | # sys.exit()
79 |
80 |         if (process_num+1) % 2000 == 0:
81 | print ('process files', process_num)
82 |
--------------------------------------------------------------------------------
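
A tiny worked example of the conversion, using the convert() function above on a hand-built node list (node 0 is the dummy root that the script prepends):

    gtd_list = [['', 0, -1, 'root'],  # [symbol, childid, parentid, relation]
                ['x', 1, 0, 'Start'],
                ['2', 2, 1, 'Sup'],
                ['', 3, 2, 'End']]
    print(' '.join(convert(1, gtd_list)))  # prints: x ^ { 2 }
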
/codes/latex2gtd.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import sys
5 | import pickle as pkl
6 | import numpy
7 |
8 | latex_root_path = ''
9 | gtd_root_path = ''
10 | latex_files = ['test_caption_16.txt','test_caption_19.txt','valid_data_v1.txt','train_data_v1.txt','test_caption_14.txt']
11 | for latexF in latex_files:
12 | latex_file = latex_root_path + latexF
13 | gtd_path = gtd_root_path + latexF[:-4] + '/'
14 | if not os.path.exists(gtd_path):
15 | os.mkdir(gtd_path)
16 |
17 | with open(latex_file) as f:
18 | lines = f.readlines()
19 | for process_num, line in enumerate(lines):
20 |         if '\\sqrt [' in line:
21 | continue
22 | parts = line.split()
23 | if len(parts) < 2:
24 | print ('error: invalid latex caption ...', line)
25 | continue
26 | key = parts[0]
27 | f_out = open(gtd_path + key + '.gtd', 'w')
28 | raw_cap = parts[1:]
29 | cap = []
30 | for w in raw_cap:
31 |             if w not in ['\\limits']:
32 | cap.append(w)
33 | gtd_stack = []
34 | idx = 0
35 | outidx = 1
36 | error_flag = False
37 | while idx < len(cap):
38 | if idx == 0:
39 | if cap[0] in ['{','}']:
40 |                     print ('error: {} should NOT appear at START')
41 | print (line.strip())
42 | sys.exit()
43 | string = cap[0] + '\t' + str(outidx) + '\t\t0\tStart'
44 | f_out.write(string + '\n')
45 | idx += 1
46 | outidx += 1
47 | else:
48 | if cap[idx] == '{':
49 | if cap[idx-1] == '{':
50 | print ('error: double { appears')
51 | print (line.strip())
52 | sys.exit()
53 | elif cap[idx-1] == '}':
54 | if gtd_stack[-1][0] != '\\frac':
55 |                             print ('error: } { does not follow frac ...', key)
56 | f_out.close()
57 | os.system('rm ' + gtd_path + key + '.gtd')
58 | error_flag = True
59 | break
60 | else:
61 | gtd_stack[-1][2] = 'Below'
62 | idx += 1
63 | else:
64 | if cap[idx-1] == '\\frac':
65 | gtd_stack.append([cap[idx-1],str(outidx-1),'Above'])
66 | idx += 1
67 | elif cap[idx-1] == '\\sqrt':
68 | gtd_stack.append([cap[idx-1],str(outidx-1),'Inside'])
69 | idx += 1
70 | elif cap[idx-1] == '_':
71 | if cap[idx-2] in ['_','^','\\frac','\\sqrt']:
72 |                                 print ('error: ^ or _ follows a wrong math symbol')
73 | print (line.strip())
74 | sys.exit()
75 | elif cap[idx-2] in ['\\sum','\\int','\\lim']:
76 | gtd_stack.append([cap[idx-2],str(outidx-1),'Below'])
77 | idx += 1
78 | elif cap[idx-2] == '}':
79 | if gtd_stack[-1][0] in ['\\sum','\\int','\\lim']:
80 | gtd_stack[-1][2] = 'Below'
81 | else:
82 | gtd_stack[-1][2] = 'Sub'
83 | idx += 1
84 | else:
85 | gtd_stack.append([cap[idx-2],str(outidx-1),'Sub'])
86 | idx += 1
87 | elif cap[idx-1] == '^':
88 | if cap[idx-2] in ['_','^','\\frac','\\sqrt']:
89 |                                 print ('error: ^ or _ follows a wrong math symbol')
90 | print (line.strip())
91 | sys.exit()
92 | elif cap[idx-2] in ['\\sum','\\int','\\lim']:
93 | gtd_stack.append([cap[idx-2],str(outidx-1),'Above'])
94 | idx += 1
95 | elif cap[idx-2] == '}':
96 | if gtd_stack[-1][0] in ['\\sum','\\int','\\lim']:
97 | gtd_stack[-1][2] = 'Above'
98 | else:
99 | gtd_stack[-1][2] = 'Sup'
100 | idx += 1
101 | else:
102 | gtd_stack.append([cap[idx-2],str(outidx-1),'Sup'])
103 | idx += 1
104 | else:
105 |                             print ('error: { follows an unknown math symbol ...', key)
106 | f_out.close()
107 | os.system('rm ' + gtd_path + key + '.gtd')
108 | error_flag = True
109 | break
110 | elif cap[idx] == '}':
111 | if cap[idx-1] == '}':
112 | del(gtd_stack[-1])
113 | idx += 1
114 | elif cap[idx] in ['_','^']:
115 | if idx == len(cap)-1:
116 |                         print ('error: ^ _ appears at end ...', key)
117 | f_out.close()
118 | os.system('rm ' + gtd_path + key + '.gtd')
119 | error_flag = True
120 | break
121 | if cap[idx+1] != '{':
122 | print ('error: ^ _ not follows { ...', key)
123 | f_out.close()
124 | os.system('rm ' + gtd_path + key + '.gtd')
125 | error_flag = True
126 | break
127 | else:
128 | idx += 1
129 |             elif cap[idx] in ['\\limits']:
130 |                 print ('error: \\limits happens')
131 | print (line.strip())
132 | sys.exit()
133 | else:
134 | if cap[idx-1] == '{':
135 | string = cap[idx] + '\t' + str(outidx) + '\t' + gtd_stack[-1][0] + '\t' + gtd_stack[-1][1] + '\t' + gtd_stack[-1][2]
136 | f_out.write(string + '\n')
137 | outidx += 1
138 | idx += 1
139 | elif cap[idx-1] == '}':
140 | string = cap[idx] + '\t' + str(outidx) + '\t' + gtd_stack[-1][0] + '\t' + gtd_stack[-1][1] + '\tRight'
141 | f_out.write(string + '\n')
142 | outidx += 1
143 | idx += 1
144 | del(gtd_stack[-1])
145 | else:
146 | parts = string.split('\t')
147 | string = cap[idx] + '\t' + str(outidx) + '\t' + parts[0] + '\t' + parts[1] + '\tRight'
148 | f_out.write(string + '\n')
149 | outidx += 1
150 | idx += 1
151 | if not error_flag:
152 | parts = string.split('\t')
153 | string = '\t' + str(outidx) + '\t' + parts[0] + '\t' + parts[1] + '\tEnd'
154 | f_out.write(string + '\n')
155 | f_out.close()
156 |
157 |         if (process_num+1) % 1000 == 0:
158 | print ('process files', process_num)
159 |
--------------------------------------------------------------------------------
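
For reference, here is the GTD file this script writes for the caption "x ^ { 2 }" (tab-separated columns: symbol, output index, parent symbol, parent index, relation; the last line, with an empty symbol field, marks End):

    x	1		0	Start
    2	2	x	1	Sup
    	3	2	2	End
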
/codes/prepare_label.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import os
4 | import sys
5 | import pickle as pkl
6 | import numpy
7 |
8 |
9 |
10 | def gen_gtd_label():
11 |
12 | bfs1_path = ''
13 | gtd_root_path = ''
14 | gtd_paths = ['']
15 |
16 | for gtd_path in gtd_paths:
17 | outpkl_label_file = bfs1_path + gtd_path + '_label_gtd.pkl'
18 | out_label_fp = open(outpkl_label_file, 'wb')
19 | label_lines = {}
20 | process_num = 0
21 |
22 | file_list = os.listdir(gtd_root_path + gtd_path)
23 | for file_name in file_list:
24 | key = file_name[:-4] # remove suffix .gtd
25 | if key in ['fa66375ede8be1c192a1acc2bc62b575.jpg']:
26 | continue
27 | with open(gtd_root_path + gtd_path + '/' + file_name) as f:
28 | lines = f.readlines()
29 | label_strs = []
30 | for line in lines:
31 | parts = line.strip().split('\t')
32 | if len(parts) == 5:
33 | sym = parts[0]
34 | align = parts[1]
35 | related_sym = parts[2]
36 | realign = parts[3]
37 | relation = parts[4]
38 | string = sym + '\t' + align + '\t' + related_sym + '\t' + realign + '\t' + relation
39 | label_strs.append(string)
40 | else:
41 | print ('illegal line', key)
42 | sys.exit()
43 | label_lines[key] = label_strs
44 |
45 | process_num = process_num + 1
46 |             if process_num % 2000 == 0:
47 | print ('process files', process_num)
48 |
49 | print ('process files number ', process_num)
50 |
51 | pkl.dump(label_lines, out_label_fp)
52 | print ('save file done')
53 | out_label_fp.close()
54 |
55 |
56 | def gen_gtd_align():
57 |
58 | bfs1_path = ''
59 | gtd_root_path = ''
60 | gtd_paths = ['']
61 |
62 | for gtd_path in gtd_paths:
63 | outpkl_label_file = bfs1_path + gtd_path + '_label_align_gtd.pkl'
64 | out_label_fp = open(outpkl_label_file, 'wb')
65 | label_aligns = {}
66 | process_num = 0
67 |
68 | file_list = os.listdir(gtd_root_path + gtd_path)
69 | for file_name in file_list:
70 | key = file_name[:-4] # remove suffix .gtd
71 | if key in ['fa66375ede8be1c192a1acc2bc62b575.jpg']:
72 | continue
73 | with open(gtd_root_path + gtd_path + '/' + file_name) as f:
74 | lines = f.readlines()
75 | wordNum = len(lines)
76 | align = numpy.zeros([wordNum, wordNum], dtype='int8')
77 | wordindex = -1
78 |
79 | for line in lines:
80 | wordindex += 1
81 | parts = line.strip().split('\t')
82 | if len(parts) == 5:
83 | realign = parts[3]
84 | realign_index = int(realign)
85 | align[realign_index,wordindex] = 1
86 | else:
87 | print ('illegal line', key)
88 | sys.exit()
89 | label_aligns[key] = align
90 |
91 | process_num = process_num + 1
92 |             if process_num % 2000 == 0:
93 | print ('process files', process_num)
94 |
95 | print ('process files number ', process_num)
96 |
97 | pkl.dump(label_aligns, out_label_fp)
98 | print ('save file done')
99 | out_label_fp.close()
100 |
101 | if __name__ == '__main__':
102 | gen_gtd_label()
103 | gen_gtd_align()
--------------------------------------------------------------------------------
/codes/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # use CUDA
4 | export CUDA_VISIBLE_DEVICES=0
5 | # source ~/.bashrc
6 | python -u train_wap.py
7 |
--------------------------------------------------------------------------------
/codes/train_wap.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 |
4 | import numpy as np
5 | import random
6 | import torch
7 | from torch import optim, nn
8 | from utils import BatchBucket, load_dict, prepare_data, gen_sample, weight_init, compute_wer, compute_sacc
9 | from encoder_decoder import Encoder_Decoder
10 | from data_iterator import dataIterator
11 |
12 | # whether use multi-GPUs
13 | multi_gpu_flag = False
14 | # whether init params
15 | init_param_flag = True
16 | # whether reload params
17 | reload_flag = False
18 |
19 | # load configurations
20 | # root_paths
21 | bfs2_path = ''
22 | work_path = ''
23 |
24 | model_idx = 7
25 | # paths
26 | dictionaries = [bfs2_path + '106_dictionary.txt', bfs2_path + '7relation_dictionary.txt']
27 | datasets = [bfs2_path + 'jiaming-train-py3.pkl', bfs2_path + 'train_data_v1_label_gtd.pkl', bfs2_path + 'train_data_v1_label_align_gtd.pkl']
28 | valid_datasets = [bfs2_path + 'jiaming-valid-py3.pkl', bfs2_path + 'valid_data_v1_label_gtd.pkl', bfs2_path + 'valid_data_v1_label_align_gtd.pkl']
29 | valid_output = [work_path+'results'+str(model_idx)+'/symbol_relation/', work_path+'results'+str(model_idx)+'/memory_alpha/']
30 | valid_result = [work_path+'results'+str(model_idx)+'/valid.cer', work_path+'results'+str(model_idx)+'/valid.exprate']
31 | saveto = work_path+'models'+str(model_idx)+'/WAP_params.pkl'
32 | last_saveto = work_path+'models'+str(model_idx)+'/WAP_params_last.pkl'
33 |
34 | # training settings
35 | if multi_gpu_flag:
36 | batch_Imagesize = 500000
37 | valid_batch_Imagesize = 500000
38 | batch_size = 24
39 | valid_batch_size = 24
40 | else:
41 | batch_Imagesize = 500000
42 | valid_batch_Imagesize = 500000
43 | batch_size = 8
44 | valid_batch_size = 8
45 | maxImagesize = 500000
46 | maxlen = 200
47 | max_epochs = 5000
48 | lrate = 1.0
49 | my_eps = 1e-6
50 | decay_c = 1e-4
51 | clip_c = 100.
52 |
53 | # early stop
54 | estop = False
55 | halfLrFlag = 0
56 | bad_counter = 0
57 | patience = 15
58 | validStart = 10
59 | finish_after = 10000000
60 |
61 | # model architecture
62 | params = {}
63 | params['n'] = 256
64 | params['m'] = 256
65 | params['dim_attention'] = 512
66 | params['D'] = 684
67 | params['K'] = 106
68 |
69 | params['Kre'] = 7
70 | params['mre'] = 256
71 | params['maxlen'] = maxlen
72 |
73 | params['growthRate'] = 24
74 | params['reduction'] = 0.5
75 | params['bottleneck'] = True
76 | params['use_dropout'] = True
77 | params['input_channels'] = 1
78 |
79 | params['ly_lambda'] = 1.
80 | params['ry_lambda'] = 0.1
81 | params['re_lambda'] = 1.
82 | params['rpos_lambda'] = 1.
83 | params['KL_lambda'] = 0.1
84 |
85 | # load dictionary
86 | worddicts = load_dict(dictionaries[0])
87 | print ('total chars',len(worddicts))
88 | worddicts_r = [None] * len(worddicts)
89 | for kk, vv in worddicts.items():
90 | worddicts_r[vv] = kk
91 |
92 | reworddicts = load_dict(dictionaries[1])
93 | print ('total relations',len(reworddicts))
94 | reworddicts_r = [None] * len(reworddicts)
95 | for kk, vv in reworddicts.items():
96 | reworddicts_r[vv] = kk
97 |
98 | train,train_uid_list = dataIterator(datasets[0], datasets[1], datasets[2], worddicts, reworddicts,
99 | batch_size=batch_size, batch_Imagesize=batch_Imagesize,maxlen=maxlen,maxImagesize=maxImagesize)
100 | valid,valid_uid_list = dataIterator(valid_datasets[0], valid_datasets[1], valid_datasets[2], worddicts, reworddicts,
101 | batch_size=valid_batch_size, batch_Imagesize=valid_batch_Imagesize,maxlen=maxlen,maxImagesize=maxImagesize)
102 | # display
103 | uidx = 0 # count batch
104 | lpred_loss_s = 0. # count loss
105 | rpred_loss_s = 0.
106 | repred_loss_s = 0.
107 | mem_loss_s = 0.
108 | KL_loss_s = 0.
109 | loss_s = 0.
110 | ud_s = 0 # time for training an epoch
111 | validFreq = -1
112 | saveFreq = -1
113 | sampleFreq = -1
114 | dispFreq = 100
115 | if validFreq == -1:
116 | validFreq = len(train)
117 | if saveFreq == -1:
118 | saveFreq = len(train)
119 | if sampleFreq == -1:
120 | sampleFreq = len(train)
121 |
122 | # initialize model
123 | WAP_model = Encoder_Decoder(params)
124 | if init_param_flag:
125 | WAP_model.apply(weight_init)
126 | if multi_gpu_flag:
127 | WAP_model = nn.DataParallel(WAP_model, device_ids=[0, 1, 2, 3])
128 | if reload_flag:
129 | WAP_model.load_state_dict(torch.load(saveto,map_location=lambda storage,loc:storage))
130 | WAP_model.cuda()
131 |
132 | # print model's parameters
133 | model_params = WAP_model.named_parameters()
134 | for k, v in model_params:
135 | print(k)
136 |
137 | # loss function
138 | # criterion = torch.nn.CrossEntropyLoss(reduce=False)
139 | # optimizer
140 | optimizer = optim.Adadelta(WAP_model.parameters(), lr=lrate, eps=my_eps, weight_decay=decay_c)
141 |
142 | print('Optimization')
143 |
144 | # statistics
145 | history_errs = []
146 |
147 | for eidx in range(max_epochs):
148 | n_samples = 0
149 | ud_epoch = time.time()
150 | random.shuffle(train)
151 | for x, ly, ry, re, ma, lp, rp in train:
152 | WAP_model.train()
153 | ud_start = time.time()
154 | n_samples += len(x)
155 | uidx += 1
156 | x, x_mask, ly, ly_mask, ry, ry_mask, re, re_mask, ma, ma_mask, lp, rp = \
157 | prepare_data(params, x, ly, ry, re, ma, lp, rp)
158 |
159 | x = torch.from_numpy(x).cuda() # (batch,1,H,W)
160 | x_mask = torch.from_numpy(x_mask).cuda() # (batch,H,W)
161 | ly = torch.from_numpy(ly).cuda() # (seqs_y,batch)
162 | ly_mask = torch.from_numpy(ly_mask).cuda() # (seqs_y,batch)
163 | ry = torch.from_numpy(ry).cuda() # (seqs_y,batch)
164 | ry_mask = torch.from_numpy(ry_mask).cuda() # (seqs_y,batch)
165 | re = torch.from_numpy(re).cuda() # (seqs_y,batch)
166 | re_mask = torch.from_numpy(re_mask).cuda() # (seqs_y,batch)
167 | ma = torch.from_numpy(ma).cuda() # (batch,seqs_y,seqs_y)
168 | ma_mask = torch.from_numpy(ma_mask).cuda() # (batch,seqs_y,seqs_y)
169 | lp = torch.from_numpy(lp).cuda() # (seqs_y,batch)
170 | rp = torch.from_numpy(rp).cuda() # (seqs_y,batch)
171 |
172 | # permute for multi-GPU training
173 | # ly = ly.permute(1, 0)
174 | # ly_mask = ly_mask.permute(1, 0)
175 | # ry = ry.permute(1, 0)
176 | # ry_mask = ry_mask.permute(1, 0)
177 | # lp = lp.permute(1, 0)
178 | # rp = rp.permute(1, 0)
179 |
180 | # forward
181 | loss, lpred_loss, rpred_loss, repred_loss, mem_loss, KL_loss = \
182 | WAP_model(params, x, x_mask, ly, ly_mask, ry, ry_mask, re, re_mask, ma, ma_mask, lp, rp)
183 |
184 |         # accumulate losses for the running display averages
185 | lpred_loss_s += lpred_loss.item()
186 | rpred_loss_s += rpred_loss.item()
187 | repred_loss_s += repred_loss.item()
188 | mem_loss_s += mem_loss.item()
189 | KL_loss_s += KL_loss.item()
190 | loss_s += loss.item()
191 |
192 | # backward
193 | optimizer.zero_grad()
194 | loss.backward()
195 | if clip_c > 0.:
196 | torch.nn.utils.clip_grad_norm_(WAP_model.parameters(), clip_c)
197 |
198 | # update
199 | optimizer.step()
200 |
201 | ud = time.time() - ud_start
202 | ud_s += ud
203 |
204 | # display
205 | if np.mod(uidx, dispFreq) == 0:
206 | ud_s /= 60.
207 | loss_s /= dispFreq
208 | lpred_loss_s /= dispFreq
209 | rpred_loss_s /= dispFreq
210 | repred_loss_s /= dispFreq
211 | mem_loss_s /= dispFreq
212 | KL_loss_s /= dispFreq
213 |                     print ('Epoch', eidx, ' Update', uidx, ' Cost_lpred %.7f, Cost_rpred %.7f, Cost_re %.7f, Cost_matt %.7f, Cost_kl %.7f' % \
214 |                         (float(lpred_loss_s),float(rpred_loss_s),float(repred_loss_s),float(mem_loss_s),float(KL_loss_s)), \
215 |                         ' UD %.3f' % ud_s, ' lrate', lrate, ' eps', my_eps, ' bad_counter', bad_counter)
216 | ud_s = 0
217 | loss_s = 0.
218 | lpred_loss_s = 0.
219 | rpred_loss_s = 0.
220 | repred_loss_s = 0.
221 | mem_loss_s = 0.
222 | KL_loss_s = 0.
223 |
224 | # validation
225 | if np.mod(uidx, sampleFreq) == 0 and eidx >= validStart:
226 | print('begin sampling')
227 | ud_epoch_train = (time.time() - ud_epoch) / 60.
228 | print('epoch training cost time ... ', ud_epoch_train)
229 | WAP_model.eval()
230 | valid_out_path = valid_output[0]
231 | valid_malpha_path = valid_output[1]
232 | if not os.path.exists(valid_out_path):
233 | os.mkdir(valid_out_path)
234 | if not os.path.exists(valid_malpha_path):
235 | os.mkdir(valid_malpha_path)
236 | rec_mat = {}
237 | label_mat = {}
238 | rec_re_mat = {}
239 | label_re_mat = {}
240 | rec_ridx_mat = {}
241 | label_ridx_mat = {}
242 | with torch.no_grad():
243 | valid_count_idx = 0
244 | for x, ly, ry, re, ma, lp, rp in valid:
245 | for xx, lyy, ree, rpp in zip(x, ly, re, rp):
246 | xx_pad = xx.astype(np.float32) / 255.
247 | xx_pad = torch.from_numpy(xx_pad[None, :, :, :]).cuda() # (1,1,H,W)
248 | score, sample, malpha_list, relation_sample = \
249 | gen_sample(WAP_model, xx_pad, params, multi_gpu_flag, k=3, maxlen=maxlen, rpos_beam=3)
250 |
251 | key = valid_uid_list[valid_count_idx]
252 | rec_mat[key] = []
253 | label_mat[key] = lyy
254 | rec_re_mat[key] = []
255 | label_re_mat[key] = ree
256 | rec_ridx_mat[key] = []
257 | label_ridx_mat[key] = rpp
258 | if len(score) == 0:
259 | rec_mat[key].append(0)
260 | rec_re_mat[key].append(0) # End
261 | rec_ridx_mat[key].append(0)
262 | else:
263 | score = score / np.array([len(s) for s in sample])
264 | min_score_index = score.argmin()
265 | ss = sample[min_score_index]
266 | rs = relation_sample[min_score_index]
267 | mali = malpha_list[min_score_index]
268 | for i, [vv, rv] in enumerate(zip(ss, rs)):
269 | if vv == 0:
270 | rec_mat[key].append(vv)
271 | rec_re_mat[key].append(0) # End
272 | break
273 | else:
274 | if i == 0:
275 | rec_mat[key].append(vv)
276 | rec_re_mat[key].append(6) # Start
277 | else:
278 | rec_mat[key].append(vv)
279 | rec_re_mat[key].append(rv)
280 | ma_idx_list = np.array(mali).astype(np.int64)
281 | ma_idx_list[-1] = int(len(ma_idx_list)-1)
282 | rec_ridx_mat[key] = ma_idx_list
283 | valid_count_idx=valid_count_idx+1
284 |
285 | print('valid set decode done')
286 | ud_epoch = (time.time() - ud_epoch) / 60.
287 | print('epoch cost time ... ', ud_epoch)
288 |
289 | if np.mod(uidx, saveFreq) == 0:
290 | print('Saving latest model params ... ')
291 |             torch.save(WAP_model.module.state_dict() if multi_gpu_flag else WAP_model.state_dict(), last_saveto) # unwrap DataParallel so the keys match single-GPU loading
292 |
293 | # calculate wer and expRate
294 | if np.mod(uidx, validFreq) == 0 and eidx >= validStart:
295 | valid_cer_out = compute_wer(rec_mat, label_mat)
296 | valid_cer = 100. * valid_cer_out[0]
297 | valid_recer_out = compute_wer(rec_re_mat, label_re_mat)
298 | valid_recer = 100. * valid_recer_out[0]
299 | valid_ridxcer_out = compute_wer(rec_ridx_mat, label_ridx_mat)
300 | valid_ridxcer = 100. * valid_ridxcer_out[0]
301 | valid_exprate = compute_sacc(rec_mat, label_mat, rec_ridx_mat, label_ridx_mat, rec_re_mat, label_re_mat, worddicts_r, reworddicts_r)
302 | valid_exprate = 100. * valid_exprate
303 | valid_err=valid_cer+valid_ridxcer
304 | history_errs.append(valid_err)
305 |
306 |             # the first validation or a better model
307 |             if len(history_errs) == 1 or valid_err <= np.array(history_errs).min():
308 | bad_counter = 0
309 | print('Saving best model params ... ')
310 | if multi_gpu_flag:
311 | torch.save(WAP_model.module.state_dict(), saveto)
312 | else:
313 | torch.save(WAP_model.state_dict(), saveto)
314 |
315 | # worse model
316 |             if len(history_errs) > 1 and valid_err > np.array(history_errs).min():
317 | bad_counter += 1
318 | if bad_counter > patience:
319 | if halfLrFlag == 2:
320 | print('Early Stop!')
321 | estop = True
322 | break
323 | else:
324 | print('Lr decay and retrain!')
325 | bad_counter = 0
326 | lrate = lrate / 10.
327 | params['KL_lambda'] = params['KL_lambda'] * 0.5
328 | for param_group in optimizer.param_groups:
329 | param_group['lr'] = lrate
330 | halfLrFlag += 1
331 | print ('Valid CER: %.2f%%, relation_CER: %.2f%%, rpos_CER: %.2f%%, ExpRate: %.2f%%' % (valid_cer,valid_recer,valid_ridxcer,valid_exprate))
332 |         # finish after this many updates
333 | if uidx >= finish_after:
334 | print('Finishing after %d iterations!' % uidx)
335 | estop = True
336 | break
337 |
338 | print('Seen %d samples' % n_samples)
339 |
340 | # early stop
341 | if estop:
342 | break
343 |
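344 | # --- Editorial sketch, not part of the original file ---
345 | # The forward pass above returns a total loss plus five component losses; the
346 | # total is presumably their lambda-weighted sum, computed inside Encoder_Decoder.
347 | # A minimal sketch of that weighting, under that assumption ('combine_losses'
348 | # is a hypothetical name, not this repo's API):
349 | def combine_losses(params, lpred_loss, rpred_loss, repred_loss, mem_loss, KL_loss):
350 |     return (params['ly_lambda'] * lpred_loss      # ly symbol prediction
351 |             + params['ry_lambda'] * rpred_loss    # ry symbol prediction
352 |             + params['re_lambda'] * repred_loss   # relation prediction
353 |             + params['rpos_lambda'] * mem_loss    # parent-position (memory attention)
354 |             + params['KL_lambda'] * KL_loss)      # KL regularization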
--------------------------------------------------------------------------------
/codes/translate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import numpy as np
4 | import os
5 | import re
6 | import time
7 | import sys
8 |
9 | import torch
10 |
11 | from data_iterator import dataIterator, dataIterator_test
12 | from encoder_decoder import Encoder_Decoder
13 |
14 |
15 | # Note:
16 | # 'model' is the Encoder_Decoder (referred to as WAP_model during training)
17 | # x is a single sample, not a batch (i.e. batch_size = 1); its shape must be (1,1,H,W) and it must be a torch tensor on the GPU
18 | # live_k equals k - dead_k, i.e. the number of alive paths in the beam search (at the start of decoding, live_k = 1 and dead_k = 0)
19 |
20 | def gen_sample(model, x, params, gpu_flag, k=1, maxlen=30, rpos_beam=3):
21 |
22 | sample = []
23 | sample_score = []
24 | rpos_sample = []
25 | # rpos_sample_score = []
26 | relation_sample = []
27 |
28 | live_k = 1
29 |     dead_k = 0 # except at init, live_k = k - dead_k
30 |
31 | # current living paths and corresponding scores(-log)
32 | hyp_samples = [[]] * live_k
33 | hyp_scores = np.zeros(live_k).astype(np.float32)
34 | hyp_rpos_samples = [[]] * live_k
35 | hyp_relation_samples = [[]] * live_k
36 | # get init state, (1,n) and encoder output, (1,D,H,W)
37 | next_state, ctx0 = model.f_init(x)
38 | next_h1t = next_state
39 |     # word index -1 is embedded as an all-zero tensor of shape (1,m)
40 | next_lw = -1 * torch.ones(1, dtype=torch.int64).cuda()
41 | next_calpha_past = torch.zeros(1, ctx0.shape[2], ctx0.shape[3]).cuda() # (live_k,H,W)
42 | next_palpha_past = torch.zeros(1, ctx0.shape[2], ctx0.shape[3]).cuda()
43 | nextemb_memory = torch.zeros(params['maxlen'], live_k, params['m']).cuda()
44 | nextePmb_memory = torch.zeros(params['maxlen'], live_k, params['m']).cuda()
45 |
46 | for ii in range(maxlen):
47 | ctxP = ctx0.repeat(live_k, 1, 1, 1) # (live_k,D,H,W)
48 | next_lpos = ii * torch.ones(live_k, dtype=torch.int64).cuda()
49 | next_h01, next_ma, next_ctP, next_pa, next_palpha_past, nextemb_memory, nextePmb_memory = \
50 | model.f_next_parent(params, next_lw, next_lpos, ctxP, next_state, next_h1t, next_palpha_past, nextemb_memory, nextePmb_memory, ii)
51 | next_ma = next_ma.cpu().numpy()
52 | # next_ctP = next_ctP.cpu().numpy()
53 | next_palpha_past = next_palpha_past.cpu().numpy()
54 | nextemb_memory = nextemb_memory.cpu().numpy()
55 | nextePmb_memory = nextePmb_memory.cpu().numpy()
56 |
57 | nextemb_memory = np.transpose(nextemb_memory, (1, 0, 2)) # batch * Matt * dim
58 | nextePmb_memory = np.transpose(nextePmb_memory, (1, 0, 2))
59 |
60 | next_rpos = next_ma.argsort(axis=1)[:,-rpos_beam:] # topK parent index; batch * topK
61 | n_gaps = nextemb_memory.shape[1]
62 | n_batch = nextemb_memory.shape[0]
63 | next_rpos_gap = next_rpos + n_gaps * np.arange(n_batch)[:, None]
64 | next_remb_memory = nextemb_memory.reshape([n_batch*n_gaps, nextemb_memory.shape[-1]])
65 | next_remb = next_remb_memory[next_rpos_gap.flatten()] # [batch*rpos_beam, emb_dim]
66 | rpos_scores = next_ma.flatten()[next_rpos_gap.flatten()] # [batch*rpos_beam,]
67 |
68 | # next_ctPC = next_ctP.repeat(1, 1, rpos_beam)
69 | # next_ctPC = torch.reshape(next_ctPC, (-1, next_ctP.shape[1]))
70 | ctxC = ctx0.repeat(live_k*rpos_beam, 1, 1, 1)
71 | next_ctPC = torch.zeros(next_ctP.shape[0]*rpos_beam, next_ctP.shape[1]).cuda()
72 | next_h01C = torch.zeros(next_h01.shape[0]*rpos_beam, next_h01.shape[1]).cuda()
73 | next_calpha_pastC = torch.zeros(next_calpha_past.shape[0]*rpos_beam, next_calpha_past.shape[1], next_calpha_past.shape[2]).cuda()
74 | for bidx in range(next_calpha_past.shape[0]):
75 | for ridx in range(rpos_beam):
76 | next_ctPC[bidx*rpos_beam+ridx] = next_ctP[bidx]
77 | next_h01C[bidx*rpos_beam+ridx] = next_h01[bidx]
78 | next_calpha_pastC[bidx*rpos_beam+ridx] = next_calpha_past[bidx]
79 | next_remb = torch.from_numpy(next_remb).cuda()
80 |
81 | next_lp, next_rep, next_state, next_h1t, next_ca, next_calpha_past, next_re = \
82 | model.f_next_child(params, next_remb, next_ctPC, ctxC, next_h01C, next_calpha_pastC)
83 |
84 | next_lp = next_lp.cpu().numpy()
85 | next_state = next_state.cpu().numpy()
86 | next_h1t = next_h1t.cpu().numpy()
87 | next_calpha_past = next_calpha_past.cpu().numpy()
88 | next_re = next_re.cpu().numpy()
89 |
90 | hyp_scores = np.tile(hyp_scores[:, None], [1, rpos_beam]).flatten()
91 | cand_scores = hyp_scores[:, None] - np.log(next_lp+1e-10)- np.log(rpos_scores+1e-10)[:,None]
92 | cand_flat = cand_scores.flatten()
93 | ranks_flat = cand_flat.argsort()[:(k-dead_k)]
94 | voc_size = next_lp.shape[1]
95 | trans_indices = ranks_flat // voc_size
96 | trans_indicesP = ranks_flat // (voc_size*rpos_beam)
97 | word_indices = ranks_flat % voc_size
98 | costs = cand_flat[ranks_flat]
99 |
100 | # update paths
101 | new_hyp_samples = []
102 | new_hyp_scores = np.zeros(k-dead_k).astype('float32')
103 | new_hyp_rpos_samples = []
104 | new_hyp_relation_samples = []
105 | new_hyp_states = []
106 | new_hyp_h1ts = []
107 | new_hyp_calpha_past = []
108 | new_hyp_palpha_past = []
109 | new_hyp_emb_memory = []
110 | new_hyp_ePmb_memory = []
111 |
112 | for idx, [ti, wi, tPi] in enumerate(zip(trans_indices, word_indices, trans_indicesP)):
113 | new_hyp_samples.append(hyp_samples[tPi]+[wi])
114 | new_hyp_scores[idx] = copy.copy(costs[idx])
115 | new_hyp_rpos_samples.append(hyp_rpos_samples[tPi]+[next_rpos.flatten()[ti]])
116 | new_hyp_relation_samples.append(hyp_relation_samples[tPi]+[next_re[ti]])
117 | new_hyp_states.append(copy.copy(next_state[ti]))
118 | new_hyp_h1ts.append(copy.copy(next_h1t[ti]))
119 | new_hyp_calpha_past.append(copy.copy(next_calpha_past[ti]))
120 | new_hyp_palpha_past.append(copy.copy(next_palpha_past[tPi]))
121 | new_hyp_emb_memory.append(copy.copy(nextemb_memory[tPi]))
122 | new_hyp_ePmb_memory.append(copy.copy(nextePmb_memory[tPi]))
123 |
124 | # check the finished samples
125 | new_live_k = 0
126 | hyp_samples = []
127 | hyp_scores = []
128 | hyp_rpos_samples = []
129 | hyp_relation_samples = []
130 | hyp_states = []
131 | hyp_h1ts = []
132 | hyp_calpha_past = []
133 | hyp_palpha_past = []
134 | hyp_emb_memory = []
135 | hyp_ePmb_memory = []
136 |
137 | for idx in range(len(new_hyp_samples)):
138 |             if new_hyp_samples[idx][-1] == 0: # index 0 is the end symbol
139 | sample_score.append(new_hyp_scores[idx])
140 | sample.append(new_hyp_samples[idx])
141 | rpos_sample.append(new_hyp_rpos_samples[idx])
142 | relation_sample.append(new_hyp_relation_samples[idx])
143 | dead_k += 1
144 | else:
145 | new_live_k += 1
146 | hyp_scores.append(new_hyp_scores[idx])
147 | hyp_samples.append(new_hyp_samples[idx])
148 | hyp_rpos_samples.append(new_hyp_rpos_samples[idx])
149 | hyp_relation_samples.append(new_hyp_relation_samples[idx])
150 | hyp_states.append(new_hyp_states[idx])
151 | hyp_h1ts.append(new_hyp_h1ts[idx])
152 | hyp_calpha_past.append(new_hyp_calpha_past[idx])
153 | hyp_palpha_past.append(new_hyp_palpha_past[idx])
154 | hyp_emb_memory.append(new_hyp_emb_memory[idx])
155 | hyp_ePmb_memory.append(new_hyp_ePmb_memory[idx])
156 |
157 | hyp_scores = np.array(hyp_scores)
158 | live_k = new_live_k
159 |
160 | # whether finish beam search
161 | if new_live_k < 1:
162 | break
163 | if dead_k >= k:
164 | break
165 |
166 | next_lw = np.array([w[-1] for w in hyp_samples]) # each path's final symbol, (live_k,)
167 | next_state = np.array(hyp_states) # h2t, (live_k,n)
168 | next_h1t = np.array(hyp_h1ts)
169 | next_calpha_past = np.array(hyp_calpha_past) # (live_k,H,W)
170 | next_palpha_past = np.array(hyp_palpha_past)
171 | nextemb_memory = np.array(hyp_emb_memory)
172 | nextemb_memory = np.transpose(nextemb_memory, (1, 0, 2))
173 | nextePmb_memory = np.array(hyp_ePmb_memory)
174 | nextePmb_memory = np.transpose(nextePmb_memory, (1, 0, 2))
175 | next_lw = torch.from_numpy(next_lw).cuda()
176 | next_state = torch.from_numpy(next_state).cuda()
177 | next_h1t = torch.from_numpy(next_h1t).cuda()
178 | next_calpha_past = torch.from_numpy(next_calpha_past).cuda()
179 | next_palpha_past = torch.from_numpy(next_palpha_past).cuda()
180 | nextemb_memory = torch.from_numpy(nextemb_memory).cuda()
181 | nextePmb_memory = torch.from_numpy(nextePmb_memory).cuda()
182 |
183 | return sample_score, sample, rpos_sample, relation_sample
184 |
185 |
186 | def load_dict(dictFile):
187 | fp = open(dictFile)
188 | stuff = fp.readlines()
189 | fp.close()
190 | lexicon = {}
191 | for l in stuff:
192 | w = l.strip().split()
193 | lexicon[w[0]] = int(w[1])
194 |
195 | print('total words/phones', len(lexicon))
196 | return lexicon
197 |
198 |
199 | def main(model_path, dictionary_target, dictionary_retarget, fea, output_path, k=5):
200 | # set parameters
201 | params = {}
202 | params['n'] = 256
203 | params['m'] = 256
204 | params['dim_attention'] = 512
205 | params['D'] = 684
206 | params['K'] = 106
207 | params['growthRate'] = 24
208 | params['reduction'] = 0.5
209 | params['bottleneck'] = True
210 | params['use_dropout'] = True
211 | params['input_channels'] = 1
212 | params['Kre'] = 7
213 | params['mre'] = 256
214 |
215 | maxlen = 300
216 | params['maxlen'] = maxlen
217 |
218 | # load model
219 | model = Encoder_Decoder(params)
220 | model.load_state_dict(torch.load(model_path,map_location=lambda storage,loc:storage))
221 | # enable CUDA
222 | model.cuda()
223 |
224 | # load source dictionary and invert
225 | worddicts = load_dict(dictionary_target)
226 | print ('total chars',len(worddicts))
227 | worddicts_r = [None] * len(worddicts)
228 | for kk, vv in worddicts.items():
229 | worddicts_r[vv] = kk
230 |
231 | reworddicts = load_dict(dictionary_retarget)
232 | print ('total relations',len(reworddicts))
233 | reworddicts_r = [None] * len(reworddicts)
234 | for kk, vv in reworddicts.items():
235 | reworddicts_r[vv] = kk
236 |
237 | valid,valid_uid_list = dataIterator_test(fea, worddicts, reworddicts,
238 | batch_size=8, batch_Imagesize=800000,maxImagesize=800000)
239 |
240 | # change model's mode to eval
241 | model.eval()
242 |
243 | valid_out_path = output_path + 'symbol_relation/'
244 | valid_malpha_path = output_path + 'memory_alpha/'
245 | if not os.path.exists(valid_out_path):
246 | os.mkdir(valid_out_path)
247 | if not os.path.exists(valid_malpha_path):
248 | os.mkdir(valid_malpha_path)
249 | valid_count_idx = 0
250 | print('Decoding ... ')
251 | ud_epoch = time.time()
252 | model.eval()
253 | with torch.no_grad():
254 | for x in valid:
255 |             for xx in x: # xx: one sample (numpy array) from the current batch
256 | print('%d : %s' % (valid_count_idx + 1, valid_uid_list[valid_count_idx]))
257 | xx_pad = np.zeros((xx.shape[0], xx.shape[1], xx.shape[2]), dtype='float32') # (1,height,width)
258 | xx_pad[:, :, :] = xx / 255.
259 | xx_pad = torch.from_numpy(xx_pad[None, :, :, :]).cuda()
260 | score, sample, malpha_list, relation_sample = \
261 | gen_sample(model, xx_pad, params, gpu_flag=False, k=k, maxlen=maxlen)
262 | # sys.exit()
263 | if len(score) != 0:
264 | score = score / np.array([len(s) for s in sample])
265 | # relation_score = relation_score / np.array([len(r) for r in relation_sample])
266 | min_score_index = score.argmin()
267 | ss = sample[min_score_index]
268 | rs = relation_sample[min_score_index]
269 | mali = malpha_list[min_score_index]
270 | fpp_sample = open(valid_out_path+valid_uid_list[valid_count_idx]+'.txt','w')
271 | file_malpha_sample = valid_malpha_path+valid_uid_list[valid_count_idx]+'_malpha.txt'
272 | for i, [vv, rv] in enumerate(zip(ss, rs)):
273 | if vv == 0:
274 | string = worddicts_r[vv] + '\tEnd\n'
275 | fpp_sample.write(string)
276 | break
277 | else:
278 | if i == 0:
279 | string = worddicts_r[vv] + '\tStart\n'
280 | else:
281 | string = worddicts_r[vv] + '\t' + reworddicts_r[rv] + '\n'
282 | fpp_sample.write(string)
283 | np.savetxt(file_malpha_sample, np.array(mali))
284 | fpp_sample.close()
285 | valid_count_idx=valid_count_idx+1
286 | print('test set decode done')
287 | ud_epoch = (time.time() - ud_epoch) / 60.
288 | print('epoch cost time ... ', ud_epoch)
289 |
290 | # valid_result = [result_wer, result_exprate]
291 | # os.system('python compute_sym_re_ridx_cer.py ' + valid_out_path + ' ' + valid_malpha_path + ' ' + label_path + ' ' + valid_result[0])
292 | # fpp=open(valid_result[0])
293 | # lines = fpp.readlines()
294 | # fpp.close()
295 | # part1 = lines[-3].split()
296 | # if part1[0] == 'CER':
297 | # valid_cer=100. * float(part1[1])
298 | # else:
299 | # print ('no CER result')
300 | # part2 = lines[-2].split()
301 | # if part2[0] == 'reCER':
302 | # valid_recer=100. * float(part2[1])
303 | # else:
304 | # print ('no reCER result')
305 | # part3 = lines[-1].split()
306 | # if part3[0] == 'ridxCER':
307 | # valid_ridxcer=100. * float(part3[1])
308 | # else:
309 | # print ('no ridxCER result')
310 | # os.system('python evaluate_ExpRate2.py ' + valid_out_path + ' ' + valid_malpha_path + ' ' + label_path + ' ' + valid_result[1])
311 | # fpp=open(valid_result[1])
312 | # exp_lines = fpp.readlines()
313 | # fpp.close()
314 | # parts = exp_lines[0].split()
315 | # if parts[0] == 'ExpRate':
316 | # valid_exprate = float(parts[1])
317 | # else:
318 | # print ('no ExpRate result')
319 | # print ('ExpRate: %.2f%%' % (valid_exprate))
320 | # print ('Valid CER: %.2f%%, relation_CER: %.2f%%, rpos_CER: %.2f%%, ExpRate: %.2f%%' % (valid_cer,valid_recer,valid_ridxcer,valid_exprate))
321 |
322 |
323 | if __name__ == "__main__":
324 | parser = argparse.ArgumentParser()
325 | parser.add_argument('-k', type=int, default=10)
326 | parser.add_argument('model_path', type=str)
327 | parser.add_argument('dictionary_target', type=str)
328 | parser.add_argument('dictionary_retarget', type=str)
329 | parser.add_argument('fea', type=str)
330 | parser.add_argument('output_path', type=str)
331 |
332 | args = parser.parse_args()
333 |
334 | main(args.model_path, args.dictionary_target, args.dictionary_retarget, args.fea,
335 | args.output_path, k=args.k)
336 |
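337 | # --- Editorial note, not part of the original file ---
338 | # Example invocation; the file names are illustrative placeholders. Note that
339 | # output_path must end with '/', since the script concatenates subdirectory
340 | # names onto it:
341 | #   python translate.py -k 10 model_params.pkl dictionary.txt \
342 | #       relation_dictionary.txt test_features.pkl ./results/
343 | # Results appear under ./results/symbol_relation/ (one 'symbol<TAB>relation'
344 | # line per decoding step) and ./results/memory_alpha/ (parent-position arrays).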
--------------------------------------------------------------------------------
/codes/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import copy
4 | import sys
5 | import pickle as pkl
6 | import torch
7 | from torch import nn
8 |
9 |
10 | class BatchBucket():
11 | def __init__(self, max_h, max_w, max_l, max_img_size, max_batch_size, feature_file, label_file, dictionary,
12 | use_all=True):
13 | self._max_img_size = max_img_size
14 | self._max_batch_size = max_batch_size
15 | self._fea_file = feature_file
16 | self._label_file = label_file
17 | self._dictionary_file = dictionary
18 | self._use_all = use_all
19 | self._dict_load()
20 | self._data_load()
21 | self.keys = self._calc_keys(max_h, max_w, max_l)
22 | self._make_plan()
23 | self._reset()
24 |
25 | def _dict_load(self):
26 | fp = open(self._dictionary_file)
27 | stuff = fp.readlines()
28 | fp.close()
29 | self._lexicon = {}
30 | for l in stuff:
31 | w = l.strip().split()
32 | self._lexicon[w[0]] = int(w[1])
33 |
34 | def _data_load(self):
35 | fp_fea = open(self._fea_file, 'rb')
36 | self._features = pkl.load(fp_fea)
37 | fp_fea.close()
38 | fp_label = open(self._label_file, 'r')
39 | labels = fp_label.readlines()
40 | fp_label.close()
41 | self._targets = {}
42 | for l in labels:
43 | tmp = l.strip().split()
44 | uid = tmp[0]
45 | w_list = []
46 | for w in tmp[1:]:
47 | if self._lexicon.__contains__(w):
48 | w_list.append(self._lexicon[w])
49 | else:
50 | print('a word not in the dictionary !! sentence ', uid, 'word ', w)
51 | sys.exit()
52 | self._targets[uid] = w_list
53 | # (uid, h, w, tgt_len)
54 |         self._data_parser = [(uid, fea.shape[1], fea.shape[2], len(self._targets[uid]))
55 |                              for uid, fea in self._features.items()] # look up labels by uid rather than relying on dict order
56 |
57 | def _calc_keys(self, max_h, max_w, max_l):
58 | mh = mw = ml = 0
59 | for _, h, w, l in self._data_parser:
60 | if h > mh:
61 | mh = h
62 | if w > mw:
63 | mw = w
64 | if l > ml:
65 | ml = l
66 | max_h = min(max_h, mh)
67 | max_w = min(max_w, mw)
68 | max_l = min(max_l, ml)
69 | keys = []
70 | init_h = 64 if 64 < max_h else max_h
71 | init_w = 64 if 64 < max_w else max_w
72 | init_l = 20 if 20 < max_l else max_l
73 | h_step = 64
74 | w_step = 64
75 | l_step = 30
76 | h = init_h
77 | while h <= max_h:
78 | w = init_w
79 | while w <= max_w:
80 | l = init_l
81 | while l <= max_l:
82 | keys.append([h, w, l, h * w * l, 0])
83 | if l < max_l and l + l_step > max_l:
84 | l = max_l
85 | continue
86 | l += l_step
87 | if w < max_w and w + w_step > max_w:
88 | w = max_w
89 | continue
90 | w += w_step
91 | if h < max_h and h + h_step > max_h:
92 | h = max_h
93 | continue
94 | h += h_step
95 | keys = sorted(keys, key=lambda area: area[3])
96 | for _, h, w, l in self._data_parser:
97 | for i in range(len(keys)):
98 | hh, ww, ll, _, _ = keys[i]
99 | if h <= hh and w <= ww and l <= ll:
100 | keys[i][-1] += 1
101 | break
102 | new_keys = []
103 | n_samples = len(self._data_parser)
104 | th = n_samples * 0.01
105 | if self._use_all:
106 | th = 1
107 | num = 0
108 | for key in keys:
109 | hh, ww, ll, _, n = key
110 | num += n
111 | if num >= th:
112 | new_keys.append((hh, ww, ll))
113 | num = 0
114 | return new_keys
115 |
116 | def _make_plan(self):
117 | self._bucket_keys = []
118 | for h, w, l in self.keys:
119 | batch_size = int(self._max_img_size / (h * w))
120 | if batch_size > self._max_batch_size:
121 | batch_size = self._max_batch_size
122 | if batch_size == 0:
123 | continue
124 | self._bucket_keys.append((batch_size, h, w, l))
125 | self._data_buckets = [[] for key in self._bucket_keys]
126 | unuse_num = 0
127 | for item in self._data_parser:
128 | flag = 0
129 | for key, bucket in zip(self._bucket_keys, self._data_buckets):
130 | _, h, w, l = key
131 | if item[1] <= h and item[2] <= w and item[3] <= l:
132 | bucket.append(item)
133 | flag = 1
134 | break
135 | if flag == 0:
136 | unuse_num += 1
137 | print('The number of unused samples: ', unuse_num)
138 | all_sample_num = 0
139 | for key, bucket in zip(self._bucket_keys, self._data_buckets):
140 | sample_num = len(bucket)
141 | all_sample_num += sample_num
142 | print('bucket {}, sample number={}'.format(key, len(bucket)))
143 |         print('Bucketed samples={}, raw samples={}'.format(all_sample_num, len(self._data_parser)))
144 |
145 | def _reset(self):
146 | # shuffle data in each bucket
147 | for bucket in self._data_buckets:
148 | random.shuffle(bucket)
149 | self._batches = []
150 |         for key, bucket in zip(self._bucket_keys, self._data_buckets):
151 | batch_size, _, _, _ = key
152 | bucket_len = len(bucket)
153 | batch_num = (bucket_len + batch_size - 1) // batch_size
154 | for i in range(batch_num):
155 | start = i * batch_size
156 | end = start + batch_size if start + batch_size < bucket_len else bucket_len
157 | if start != end: # remove empty batch
158 | self._batches.append(bucket[start:end])
159 |
160 | def get_batches(self):
161 | batches = []
162 | uid_batches = []
163 | for batch_info in self._batches:
164 | fea_batch = []
165 | label_batch = []
166 | for uid, _, _, _ in batch_info:
167 | feature = self._features[uid]
168 | label = self._targets[uid]
169 | fea_batch.append(feature)
170 | label_batch.append(label)
171 | uid_batches.append(uid)
172 | batches.append((fea_batch, label_batch))
173 | return batches, uid_batches
174 |
175 |
176 | # load dictionary
177 | def load_dict(dictFile):
178 | fp = open(dictFile)
179 | stuff = fp.readlines()
180 | fp.close()
181 | lexicon = {}
182 | for l in stuff:
183 | w = l.strip().split()
184 | lexicon[w[0]] = int(w[1])
185 | print('total words/phones', len(lexicon))
186 | return lexicon
187 |
188 |
189 | # create batch
190 | def prepare_data(params, images_x, seqs_ly, seqs_ry, seqs_re, seqs_ma, seqs_lp, seqs_rp):
191 | heights_x = [s.shape[1] for s in images_x]
192 | widths_x = [s.shape[2] for s in images_x]
193 | lengths_ly = [len(s) for s in seqs_ly]
194 | lengths_ry = [len(s) for s in seqs_ry]
195 |
196 | n_samples = len(heights_x)
197 | max_height_x = np.max(heights_x)
198 | max_width_x = np.max(widths_x)
199 | maxlen_ly = np.max(lengths_ly)
200 | maxlen_ry = np.max(lengths_ry)
201 |
202 | x = np.zeros((n_samples, params['input_channels'], max_height_x, max_width_x)).astype(np.float32)
203 |     ly = np.zeros((maxlen_ly, n_samples)).astype(np.int64) # padding (the end symbol) must be index 0 in the dict
204 | ry = np.zeros((maxlen_ry, n_samples)).astype(np.int64)
205 | re = np.zeros((maxlen_ly, n_samples)).astype(np.int64)
206 | ma = np.zeros((n_samples, maxlen_ly, maxlen_ly)).astype(np.int64)
207 | lp = np.zeros((maxlen_ly, n_samples)).astype(np.int64)
208 | rp = np.zeros((maxlen_ry, n_samples)).astype(np.int64)
209 |
210 | x_mask = np.zeros((n_samples, max_height_x, max_width_x)).astype(np.float32)
211 | ly_mask = np.zeros((maxlen_ly, n_samples)).astype(np.float32)
212 | ry_mask = np.zeros((maxlen_ry, n_samples)).astype(np.float32)
213 | re_mask = np.zeros((maxlen_ly, n_samples)).astype(np.float32)
214 | ma_mask = np.zeros((n_samples, maxlen_ly, maxlen_ly)).astype(np.float32)
215 |
216 | for idx, [s_x, s_ly, s_ry, s_re, s_ma, s_lp, s_rp] in enumerate(zip(images_x, seqs_ly, seqs_ry, seqs_re, seqs_ma, seqs_lp, seqs_rp)):
217 | x[idx, :, :heights_x[idx], :widths_x[idx]] = s_x / 255.
218 | x_mask[idx, :heights_x[idx], :widths_x[idx]] = 1.
219 | ly[:lengths_ly[idx], idx] = s_ly
220 | ly_mask[:lengths_ly[idx], idx] = 1.
221 | ry[:lengths_ry[idx], idx] = s_ry
222 | ry_mask[:lengths_ry[idx], idx] = 1.
223 |         ry_mask[0, idx] = 0. # remove the Start symbol
224 | re[:lengths_ly[idx], idx] = s_re
225 | re_mask[:lengths_ly[idx], idx] = 1.
226 | re_mask[0, idx] = 0. # remove the Start relation
227 | re_mask[lengths_ly[idx]-1, idx] = 0. # remove the End relation
228 | ma[idx, :lengths_ly[idx], :lengths_ly[idx]] = s_ma
229 | for ma_idx in range(lengths_ly[idx]):
230 | ma_mask[idx, :(ma_idx+1), ma_idx] = 1.
231 | lp[:lengths_ly[idx], idx] = s_lp
232 | # lp_mask[:lengths_ly[idx], idx] = 1
233 | rp[:lengths_ry[idx], idx] = s_rp
234 |
235 | return x, x_mask, ly, ly_mask, ry, ry_mask, re, re_mask, ma, ma_mask, lp, rp
236 |
237 | def gen_sample(model, x, params, gpu_flag, k=1, maxlen=30, rpos_beam=3):
238 |
239 | sample = []
240 | sample_score = []
241 | rpos_sample = []
242 | # rpos_sample_score = []
243 | relation_sample = []
244 |
245 | live_k = 1
246 |     dead_k = 0 # except at init, live_k = k - dead_k
247 |
248 | # current living paths and corresponding scores(-log)
249 | hyp_samples = [[]] * live_k
250 | hyp_scores = np.zeros(live_k).astype(np.float32)
251 | hyp_rpos_samples = [[]] * live_k
252 | hyp_relation_samples = [[]] * live_k
253 | # get init state, (1,n) and encoder output, (1,D,H,W)
254 | next_state, ctx0 = model.f_init(x)
255 | next_h1t = next_state
256 |     # word index -1 is embedded as an all-zero tensor of shape (1,m)
257 | next_lw = -1 * torch.ones(1, dtype=torch.int64).cuda()
258 | next_calpha_past = torch.zeros(1, ctx0.shape[2], ctx0.shape[3]).cuda() # (live_k,H,W)
259 | next_palpha_past = torch.zeros(1, ctx0.shape[2], ctx0.shape[3]).cuda()
260 | nextemb_memory = torch.zeros(params['maxlen'], live_k, params['m']).cuda()
261 | nextePmb_memory = torch.zeros(params['maxlen'], live_k, params['m']).cuda()
262 |
263 | for ii in range(maxlen):
264 | ctxP = ctx0.repeat(live_k, 1, 1, 1) # (live_k,D,H,W)
265 | next_lpos = ii * torch.ones(live_k, dtype=torch.int64).cuda()
266 | next_h01, next_ma, next_ctP, next_pa, next_palpha_past, nextemb_memory, nextePmb_memory = \
267 | model.f_next_parent(params, next_lw, next_lpos, ctxP, next_state, next_h1t, next_palpha_past, nextemb_memory, nextePmb_memory, ii)
268 | next_ma = next_ma.cpu().numpy()
269 | # next_ctP = next_ctP.cpu().numpy()
270 | next_palpha_past = next_palpha_past.cpu().numpy()
271 | nextemb_memory = nextemb_memory.cpu().numpy()
272 | nextePmb_memory = nextePmb_memory.cpu().numpy()
273 |
274 | nextemb_memory = np.transpose(nextemb_memory, (1, 0, 2)) # batch * Matt * dim
275 | nextePmb_memory = np.transpose(nextePmb_memory, (1, 0, 2))
276 |
277 | next_rpos = next_ma.argsort(axis=1)[:,-rpos_beam:] # topK parent index; batch * topK
278 | n_gaps = nextemb_memory.shape[1]
279 | n_batch = nextemb_memory.shape[0]
280 | next_rpos_gap = next_rpos + n_gaps * np.arange(n_batch)[:, None]
281 | next_remb_memory = nextemb_memory.reshape([n_batch*n_gaps, nextemb_memory.shape[-1]])
282 | next_remb = next_remb_memory[next_rpos_gap.flatten()] # [batch*rpos_beam, emb_dim]
283 | rpos_scores = next_ma.flatten()[next_rpos_gap.flatten()] # [batch*rpos_beam,]
284 |
285 | # next_ctPC = next_ctP.repeat(1, 1, rpos_beam)
286 | # next_ctPC = torch.reshape(next_ctPC, (-1, next_ctP.shape[1]))
287 | ctxC = ctx0.repeat(live_k*rpos_beam, 1, 1, 1)
288 | next_ctPC = torch.zeros(next_ctP.shape[0]*rpos_beam, next_ctP.shape[1]).cuda()
289 | next_h01C = torch.zeros(next_h01.shape[0]*rpos_beam, next_h01.shape[1]).cuda()
290 | next_calpha_pastC = torch.zeros(next_calpha_past.shape[0]*rpos_beam, next_calpha_past.shape[1], next_calpha_past.shape[2]).cuda()
291 | for bidx in range(next_calpha_past.shape[0]):
292 | for ridx in range(rpos_beam):
293 | next_ctPC[bidx*rpos_beam+ridx] = next_ctP[bidx]
294 | next_h01C[bidx*rpos_beam+ridx] = next_h01[bidx]
295 | next_calpha_pastC[bidx*rpos_beam+ridx] = next_calpha_past[bidx]
296 | next_remb = torch.from_numpy(next_remb).cuda()
297 |
298 | next_lp, next_rep, next_state, next_h1t, next_ca, next_calpha_past, next_re = \
299 | model.f_next_child(params, next_remb, next_ctPC, ctxC, next_h01C, next_calpha_pastC)
300 |
301 | next_lp = next_lp.cpu().numpy()
302 | next_state = next_state.cpu().numpy()
303 | next_h1t = next_h1t.cpu().numpy()
304 | next_calpha_past = next_calpha_past.cpu().numpy()
305 | next_re = next_re.cpu().numpy()
306 |
307 | hyp_scores = np.tile(hyp_scores[:, None], [1, rpos_beam]).flatten()
308 | cand_scores = hyp_scores[:, None] - np.log(next_lp+1e-10)- np.log(rpos_scores+1e-10)[:,None]
309 | cand_flat = cand_scores.flatten()
310 | ranks_flat = cand_flat.argsort()[:(k-dead_k)]
311 | voc_size = next_lp.shape[1]
312 | trans_indices = ranks_flat // voc_size
313 | trans_indicesP = ranks_flat // (voc_size*rpos_beam)
314 | word_indices = ranks_flat % voc_size
315 | costs = cand_flat[ranks_flat]
316 |
317 | # update paths
318 | new_hyp_samples = []
319 | new_hyp_scores = np.zeros(k-dead_k).astype('float32')
320 | new_hyp_rpos_samples = []
321 | new_hyp_relation_samples = []
322 | new_hyp_states = []
323 | new_hyp_h1ts = []
324 | new_hyp_calpha_past = []
325 | new_hyp_palpha_past = []
326 | new_hyp_emb_memory = []
327 | new_hyp_ePmb_memory = []
328 |
329 | for idx, [ti, wi, tPi] in enumerate(zip(trans_indices, word_indices, trans_indicesP)):
330 | new_hyp_samples.append(hyp_samples[tPi]+[wi])
331 | new_hyp_scores[idx] = copy.copy(costs[idx])
332 | new_hyp_rpos_samples.append(hyp_rpos_samples[tPi]+[next_rpos.flatten()[ti]])
333 | new_hyp_relation_samples.append(hyp_relation_samples[tPi]+[next_re[ti]])
334 | new_hyp_states.append(copy.copy(next_state[ti]))
335 | new_hyp_h1ts.append(copy.copy(next_h1t[ti]))
336 | new_hyp_calpha_past.append(copy.copy(next_calpha_past[ti]))
337 | new_hyp_palpha_past.append(copy.copy(next_palpha_past[tPi]))
338 | new_hyp_emb_memory.append(copy.copy(nextemb_memory[tPi]))
339 | new_hyp_ePmb_memory.append(copy.copy(nextePmb_memory[tPi]))
340 |
341 | # check the finished samples
342 | new_live_k = 0
343 | hyp_samples = []
344 | hyp_scores = []
345 | hyp_rpos_samples = []
346 | hyp_relation_samples = []
347 | hyp_states = []
348 | hyp_h1ts = []
349 | hyp_calpha_past = []
350 | hyp_palpha_past = []
351 | hyp_emb_memory = []
352 | hyp_ePmb_memory = []
353 |
354 | for idx in range(len(new_hyp_samples)):
355 |             if new_hyp_samples[idx][-1] == 0: # index 0 is the end symbol
356 | sample_score.append(new_hyp_scores[idx])
357 | sample.append(new_hyp_samples[idx])
358 | rpos_sample.append(new_hyp_rpos_samples[idx])
359 | relation_sample.append(new_hyp_relation_samples[idx])
360 | dead_k += 1
361 | else:
362 | new_live_k += 1
363 | hyp_scores.append(new_hyp_scores[idx])
364 | hyp_samples.append(new_hyp_samples[idx])
365 | hyp_rpos_samples.append(new_hyp_rpos_samples[idx])
366 | hyp_relation_samples.append(new_hyp_relation_samples[idx])
367 | hyp_states.append(new_hyp_states[idx])
368 | hyp_h1ts.append(new_hyp_h1ts[idx])
369 | hyp_calpha_past.append(new_hyp_calpha_past[idx])
370 | hyp_palpha_past.append(new_hyp_palpha_past[idx])
371 | hyp_emb_memory.append(new_hyp_emb_memory[idx])
372 | hyp_ePmb_memory.append(new_hyp_ePmb_memory[idx])
373 |
374 | hyp_scores = np.array(hyp_scores)
375 | live_k = new_live_k
376 |
377 | # whether finish beam search
378 | if new_live_k < 1:
379 | break
380 | if dead_k >= k:
381 | break
382 |
383 | next_lw = np.array([w[-1] for w in hyp_samples]) # each path's final symbol, (live_k,)
384 | next_state = np.array(hyp_states) # h2t, (live_k,n)
385 | next_h1t = np.array(hyp_h1ts)
386 | next_calpha_past = np.array(hyp_calpha_past) # (live_k,H,W)
387 | next_palpha_past = np.array(hyp_palpha_past)
388 | nextemb_memory = np.array(hyp_emb_memory)
389 | nextemb_memory = np.transpose(nextemb_memory, (1, 0, 2))
390 | nextePmb_memory = np.array(hyp_ePmb_memory)
391 | nextePmb_memory = np.transpose(nextePmb_memory, (1, 0, 2))
392 | next_lw = torch.from_numpy(next_lw).cuda()
393 | next_state = torch.from_numpy(next_state).cuda()
394 | next_h1t = torch.from_numpy(next_h1t).cuda()
395 | next_calpha_past = torch.from_numpy(next_calpha_past).cuda()
396 | next_palpha_past = torch.from_numpy(next_palpha_past).cuda()
397 | nextemb_memory = torch.from_numpy(nextemb_memory).cuda()
398 | nextePmb_memory = torch.from_numpy(nextePmb_memory).cuda()
399 |
400 | return sample_score, sample, rpos_sample, relation_sample
401 |
402 |
403 | # init model params
404 | def weight_init(m):
405 | if isinstance(m, nn.Conv2d):
406 | nn.init.xavier_uniform_(m.weight.data)
407 | try:
408 | nn.init.constant_(m.bias.data, 0.)
409 |         except AttributeError: # bias may be None (layer created with bias=False)
410 | pass
411 |
412 | if isinstance(m, nn.Linear):
413 | nn.init.xavier_uniform_(m.weight.data)
414 | try:
415 | nn.init.constant_(m.bias.data, 0.)
416 |         except AttributeError: # bias may be None (layer created with bias=False)
417 | pass
418 |
419 | # compute metric
420 | def cmp_result(rec,label):
421 | dist_mat = np.zeros((len(label)+1, len(rec)+1),dtype='int32')
422 | dist_mat[0,:] = range(len(rec) + 1)
423 | dist_mat[:,0] = range(len(label) + 1)
424 | for i in range(1, len(label) + 1):
425 | for j in range(1, len(rec) + 1):
426 | hit_score = dist_mat[i-1, j-1] + (label[i-1] != rec[j-1])
427 | ins_score = dist_mat[i,j-1] + 1
428 | del_score = dist_mat[i-1, j] + 1
429 | dist_mat[i,j] = min(hit_score, ins_score, del_score)
430 |
431 | dist = dist_mat[len(label), len(rec)]
432 | return dist, len(label)
433 |
434 | def compute_wer(rec_mat, label_mat):
435 | total_dist = 0
436 | total_label = 0
437 | total_line = 0
438 | total_line_rec = 0
439 | for key_rec in rec_mat:
440 | label = label_mat[key_rec]
441 | rec = rec_mat[key_rec]
442 | # label = list(map(int,label))
443 | # rec = list(map(int,rec))
444 | dist, llen = cmp_result(rec, label)
445 | total_dist += dist
446 | total_label += llen
447 | total_line += 1
448 | if dist == 0:
449 | total_line_rec += 1
450 | wer = float(total_dist)/total_label
451 | sacc = float(total_line_rec)/total_line
452 | return wer, sacc
453 |
454 | def cmp_sacc_result(rec_list,label_list,rec_ridx_list,label_ridx_list,rec_re_list,label_re_list,chdict,redict):
455 | rec = True
456 | out_sym_pdict = {}
457 | label_sym_pdict = {}
458 | out_sym_pdict['0'] = ''
459 | label_sym_pdict['0'] = ''
460 | for idx, sym in enumerate(rec_list):
461 | out_sym_pdict[str(idx+1)] = chdict[sym]
462 | for idx, sym in enumerate(label_list):
463 | label_sym_pdict[str(idx+1)] = chdict[sym]
464 |
465 | if len(rec_list) != len(label_list):
466 | rec = False
467 | else:
468 | for idx in range(len(rec_list)):
469 | out_sym = chdict[rec_list[idx]]
470 | label_sym = chdict[label_list[idx]]
471 | out_repos = int(rec_ridx_list[idx])
472 | label_repos = int(label_ridx_list[idx])
473 | out_re = redict[rec_re_list[idx]]
474 | label_re = redict[label_re_list[idx]]
475 |                 if str(out_repos) in out_sym_pdict: # dict keys are strings, so cast the int index
476 |                     out_resym_s = out_sym_pdict[str(out_repos)]
477 |                 else:
478 |                     out_resym_s = 'unknown'
479 |                 if str(label_repos) in label_sym_pdict:
480 |                     label_resym_s = label_sym_pdict[str(label_repos)]
481 |                 else:
482 |                     label_resym_s = 'unknown'
483 |
484 | # post-processing only for math recognition
485 |             if (out_resym_s == r'\lim' and label_resym_s == r'\lim') or \
486 |                (out_resym_s == r'\int' and label_resym_s == r'\int') or \
487 |                (out_resym_s == r'\sum' and label_resym_s == r'\sum'):
488 | if out_re == 'Above':
489 | out_re = 'Sup'
490 | if out_re == 'Below':
491 | out_re = 'Sub'
492 | if label_re == 'Above':
493 | label_re = 'Sup'
494 | if label_re == 'Below':
495 | label_re = 'Sub'
496 |
497 | # if out_sym != label_sym or out_pos != label_pos or out_repos != label_repos or out_re != label_re:
498 | # if out_sym != label_sym or out_repos != label_repos:
499 | if out_sym != label_sym or out_repos != label_repos or out_re != label_re:
500 | rec = False
501 | break
502 | return rec
503 |
504 | def compute_sacc(rec_mat, label_mat, rec_ridx_mat, label_ridx_mat, rec_re_mat, label_re_mat, chdict, redict):
505 | total_num = len(rec_mat)
506 | correct_num = 0
507 | for key_rec in rec_mat:
508 | rec_list = rec_mat[key_rec]
509 | label_list = label_mat[key_rec]
510 | rec_ridx_list = rec_ridx_mat[key_rec]
511 | label_ridx_list = label_ridx_mat[key_rec]
512 | rec_re_list = rec_re_mat[key_rec]
513 | label_re_list = label_re_mat[key_rec]
514 | rec_result = cmp_sacc_result(rec_list,label_list,rec_ridx_list,label_ridx_list,rec_re_list,label_re_list,chdict,redict)
515 | if rec_result:
516 | correct_num += 1
517 | correct_rate = 1. * correct_num / total_num
518 | return correct_rate
519 |
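520 | # --- Editorial sketch, not part of the original file ---
521 | # Tiny self-check of the edit-distance metrics above; the token ids are made up
522 | # for illustration and do not come from any real dictionary.
523 | if __name__ == '__main__':
524 |     toy_rec = {'sample_1': [3, 5, 0], 'sample_2': [7, 0]}
525 |     toy_label = {'sample_1': [3, 4, 0], 'sample_2': [7, 0]}
526 |     toy_wer, toy_sacc = compute_wer(toy_rec, toy_label)
527 |     # one substitution over five label tokens -> WER = 0.2; one of the two
528 |     # sequences matches exactly -> SACC = 0.5
529 |     print('WER %.3f, SACC %.3f' % (toy_wer, toy_sacc))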
--------------------------------------------------------------------------------
/paper/TD_camera_v1.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JianshuZhang/TreeDecoder/e73da41ba234d01467d23b9bf0f36e1079e96c64/paper/TD_camera_v1.pdf
--------------------------------------------------------------------------------