├── .DS_Store ├── README.md ├── codes ├── data_iterator.py ├── decoder.py ├── encoder.py ├── encoder_decoder.py ├── gtd2latex.py ├── latex2gtd.py ├── prepare_label.py ├── train.sh ├── train_wap.py ├── translate.py └── utils.py └── paper └── TD_camera_v1.pdf /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianshuZhang/TreeDecoder/e73da41ba234d01467d23b9bf0f36e1079e96c64/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TreeDecoder 2 | 3 | The source code has been released; we will make it clearer for those who are not familiar with deep learning and encoder-decoder models:
4 | 5 | The data will be released after it is prepared:
6 | 7 | * **Tree Decoder**: A Tree-Structured Decoder for Image-to-Markup Generation
# README.md (continued)
#
# ## Citation
# If you find Tree Decoder useful in your research, please consider citing:
#
#     @inproceedings{zhang2020treedecoder,
#       title={A Tree-Structured Decoder for Image-to-Markup Generation},
#       author={Zhang, Jianshu and Du, Jun and Yang, Yongxin and Song, Yi-Zhe and Wei, Si and Dai, Lirong},
#       booktitle={ICML},
#       pages={In Press},
#       year={2020}
#     }
#
# --------------------------------------------------------------------------------
# /codes/data_iterator.py:
# --------------------------------------------------------------------------------
import numpy
import sys

import pickle as pkl
import gzip


def fopen(filename, mode='r'):
    """Open *filename*, transparently using gzip when the name ends with ``.gz``."""
    if filename.endswith('.gz'):
        return gzip.open(filename, mode)
    return open(filename, mode)


def _load_pickle(filename):
    """Unpickle and return the object stored in *filename* (``.gz`` allowed).

    The file is closed even if unpickling fails.
    NOTE(review): ``pickle.load`` can execute arbitrary code; only use this on
    trusted, locally produced data files.
    """
    with fopen(filename, 'rb') as fp:
        return pkl.load(fp)


def dataIterator(feature_file, label_file, align_file, dictionary, redictionary,
                 batch_size, batch_Imagesize, maxlen, maxImagesize):
    """Load training features/labels/alignments and group them into batches.

    Every value of the label pickle is a list of lines; each line is split on
    tabs and its first five fields are read as
    ``lchar, lpos, rchar, rpos, relation``. ``lchar``/``rchar`` are mapped
    through *dictionary*, ``lpos``/``rpos`` are parsed as ints, and
    ``relation`` is mapped through *redictionary* -- except on the last line
    of each formula, whose relation is not looked up and is stored as 0.

    Samples are sorted by image area ``fea.shape[1] * fea.shape[2]`` (feature
    arrays are presumably (channels, H, W) -- confirm against the feature
    extractor). Before a sample is added, the current batch is closed if it
    already holds *batch_size* samples or if ``largest_area * (count + 1)``
    exceeds *batch_Imagesize*. Samples with more than *maxlen* label lines
    or an area above *maxImagesize* are skipped. Feature uids that have no
    labels, and label uids that have no features, are ignored.

    Args:
        feature_file, label_file, align_file: pickle paths (``.gz`` allowed)
            holding dicts keyed by uid.
        dictionary: symbol -> int mapping.
        redictionary: relation -> int mapping.
        batch_size: maximum number of samples per batch.
        batch_Imagesize: budget for ``largest_area * samples_in_batch``.
        maxlen: maximum number of label lines per sample.
        maxImagesize: maximum image area per sample.

    Returns:
        ``(batches, uidList)``. Each batch is a tuple of seven parallel lists
        ``(features, ltargets, rtargets, relations, aligns, lpositions,
        rpositions)``; ``uidList`` holds the accepted uids in batch order.
        Empty batches are never produced (so no data -> ``[]``).

    Raises:
        SystemExit: with status 1 when a symbol or relation is missing from
            its dictionary (the formula uid is printed first).
    """
    features = _load_pickle(feature_file)
    labels = _load_pickle(label_file)
    aligns = _load_pickle(align_file)

    ltargets = {}
    rtargets = {}
    relations = {}
    lpositions = {}
    rpositions = {}

    # map word to int with dictionary
    for uid, label_lines in labels.items():
        lchar_list = []
        rchar_list = []
        relation_list = []
        lpos_list = []
        rpos_list = []
        for line_idx, line in enumerate(label_lines):
            parts = line.strip().split('\t')
            lchar = parts[0]
            lpos = parts[1]
            rchar = parts[2]
            rpos = parts[3]
            relation = parts[4]
            if lchar not in dictionary:
                print('a symbol not in the dictionary !! formula', uid, 'symbol', lchar)
                sys.exit(1)
            lchar_list.append(dictionary[lchar])
            if rchar not in dictionary:
                print('a symbol not in the dictionary !! formula', uid, 'symbol', rchar)
                sys.exit(1)
            rchar_list.append(dictionary[rchar])

            lpos_list.append(int(lpos))
            rpos_list.append(int(rpos))

            if line_idx != len(label_lines) - 1:
                if relation not in redictionary:
                    print('a relation not in the redictionary !! formula', uid, 'relation', relation)
                    sys.exit(1)
                relation_list.append(redictionary[relation])
            else:
                relation_list.append(0)  # whatever which one to replace End relation
        ltargets[uid] = lchar_list
        rtargets[uid] = rchar_list
        relations[uid] = relation_list
        lpositions[uid] = lpos_list
        rpositions[uid] = rpos_list

    # (uid, image area) pairs in ascending area order, so batches group similar sizes.
    imageSize = sorted(
        ((uid, fea.shape[1] * fea.shape[2]) for uid, fea in features.items() if uid in ltargets),
        key=lambda d: d[1])

    # Field order: features, ltargets, rtargets, relations, aligns, lpositions, rpositions.
    n_fields = 7
    batch = [[] for _ in range(n_fields)]
    totals = [[] for _ in range(n_fields)]
    uidList = []

    biggest_image_size = 0
    i = 0  # number of samples in the current batch
    for uid, size in imageSize:
        if size > biggest_image_size:
            biggest_image_size = size
        sample = (features[uid], ltargets[uid], rtargets[uid], relations[uid],
                  aligns[uid], lpositions[uid], rpositions[uid])
        batch_image_size = biggest_image_size * (i + 1)
        if len(ltargets[uid]) > maxlen:
            print('this sentence length bigger than', maxlen, 'ignore')
        elif size > maxImagesize:
            print('this image size bigger than', maxImagesize, 'ignore')
        else:
            uidList.append(uid)
            if batch_image_size > batch_Imagesize or i == batch_size:  # a batch is full
                if i > 0:  # a single oversized sample must not produce an empty batch
                    for total, field in zip(totals, batch):
                        total.append(field)
                    batch = [[] for _ in range(n_fields)]
                i = 0
                biggest_image_size = size
            for field, value in zip(batch, sample):
                field.append(value)
            i += 1

    # last batch
    if i > 0:
        for total, field in zip(totals, batch):
            total.append(field)

    print('total ', len(totals[0]), 'batch data loaded')

    return list(zip(*totals)), uidList


def dataIterator_test(feature_file, dictionary, redictionary, batch_size,
                      batch_Imagesize, maxImagesize, skip_uids=('34_em_225',)):
    """Load test features and group them into batches (no labels needed).

    Batching follows the same rules as :func:`dataIterator`: ascending image
    area, at most *batch_size* samples, and ``largest_area * count`` bounded
    by *batch_Imagesize*. Images whose area exceeds *maxImagesize*, and uids
    listed in *skip_uids*, are skipped. *dictionary* and *redictionary* are
    accepted for interface compatibility but are not used.

    Returns:
        ``(feature_total, uidList)``: a list of non-empty batches (each a list
        of feature arrays) and the accepted uids in batch order.
    """
    features = _load_pickle(feature_file)

    # (uid, image area) pairs in ascending area order.
    imageSize = sorted(
        ((uid, fea.shape[1] * fea.shape[2]) for uid, fea in features.items()),
        key=lambda d: d[1])

    feature_total = []
    feature_batch = []
    uidList = []

    biggest_image_size = 0
    i = 0  # number of samples in the current batch
    for uid, size in imageSize:
        if size > biggest_image_size:
            biggest_image_size = size
        fea = features[uid]
        batch_image_size = biggest_image_size * (i + 1)
        if size > maxImagesize:
            print('this image size bigger than', maxImagesize, 'ignore')
        elif uid in skip_uids:
            print('this image ignore', uid)
        else:
            uidList.append(uid)
            if batch_image_size > batch_Imagesize or i == batch_size:  # a batch is full
                if feature_batch:  # never emit an empty batch
                    feature_total.append(feature_batch)
                feature_batch = []
                i = 0
                biggest_image_size = size
            feature_batch.append(fea)
            i += 1

    # last batch
    if feature_batch:
        feature_total.append(feature_batch)

    print('total ', len(feature_total), 'batch data loaded')

    return feature_total, uidList


# --------------------------------------------------------------------------------
# /codes/decoder.py:
# --------------------------------------------------------------------------------
import torch
import torch.nn as nn


# two layers of GRU
class Gru_cond_layer(nn.Module):
    """Recurrent core of the tree decoder.

    One step (see ``_step_slice``) chains:

    1. a GRU over the parent-branch input (``*0`` layers) -> ``h0``;
    2. coverage attention for the parent symbol -> context ``ctP``;
    3. a GRU fed with ``ctP`` (``*1`` layers) -> ``h01``;
    4. the first GRU layer over the child-branch input -> ``h1``;
    5. coverage attention for the child symbol -> context ``ctC``;
    6. the second GRU layer fed with ``ctC`` (``*2`` layers) -> ``h2``.

    ``params`` keys read: ``'m'`` (input embedding size), ``'n'`` (hidden
    size), ``'D'`` (encoder channels) and ``'dim_attention'``.
    """

    def __init__(self, params):
        super(Gru_cond_layer, self).__init__()
        # GRU over the parent-branch input
        self.fc_Wyz0 = nn.Linear(params['m'], params['n'])
        self.fc_Wyr0 = nn.Linear(params['m'], params['n'])
        self.fc_Wyh0 = nn.Linear(params['m'], params['n'])
        self.fc_Uhz0 = nn.Linear(params['n'], params['n'], bias=False)
        self.fc_Uhr0 = nn.Linear(params['n'], params['n'], bias=False)
        self.fc_Uhh0 = nn.Linear(params['n'], params['n'], bias=False)

        # attention for parent symbol
        self.conv_UaP = nn.Conv2d(params['D'], params['dim_attention'], kernel_size=1)
        self.fc_WaP = nn.Linear(params['n'], params['dim_attention'], bias=False)
        self.conv_QP = nn.Conv2d(1, 512, kernel_size=11, bias=False, padding=5)
        self.fc_UfP = nn.Linear(512, params['dim_attention'])
        self.fc_vaP = nn.Linear(params['dim_attention'], 1)

        # GRU fed with the parent context ctP
        self.fc_Wyz1 = nn.Linear(params['D'], params['n'])
        self.fc_Wyr1 = nn.Linear(params['D'], params['n'])
        self.fc_Wyh1 = nn.Linear(params['D'], params['n'])
        self.fc_Uhz1 = nn.Linear(params['n'], params['n'], bias=False)
        self.fc_Uhr1 = nn.Linear(params['n'], params['n'], bias=False)
        self.fc_Uhh1 = nn.Linear(params['n'], params['n'], bias=False)

        # the first GRU layer
        self.fc_Wyz = nn.Linear(params['m'], params['n'])
        self.fc_Wyr = nn.Linear(params['m'], params['n'])
        self.fc_Wyh = nn.Linear(params['m'], params['n'])

        self.fc_Uhz = nn.Linear(params['n'], params['n'], bias=False)
        self.fc_Uhr = nn.Linear(params['n'], params['n'], bias=False)
        self.fc_Uhh = nn.Linear(params['n'], params['n'], bias=False)

        # attention for child symbol
        self.conv_UaC = nn.Conv2d(params['D'], params['dim_attention'], kernel_size=1)
        self.fc_WaC = nn.Linear(params['n'], params['dim_attention'], bias=False)
        self.conv_QC = nn.Conv2d(1, 512, kernel_size=11, bias=False, padding=5)
        self.fc_UfC = nn.Linear(512, params['dim_attention'])
        self.fc_vaC = nn.Linear(params['dim_attention'], 1)

        # the second GRU layer
        self.fc_Wcz = nn.Linear(params['D'], params['n'], bias=False)
        self.fc_Wcr = nn.Linear(params['D'], params['n'], bias=False)
        self.fc_Wch = nn.Linear(params['D'], params['n'], bias=False)

        self.fc_Uhz2 = nn.Linear(params['n'], params['n'])
        self.fc_Uhr2 = nn.Linear(params['n'], params['n'])
        self.fc_Uhh2 = nn.Linear(params['n'], params['n'])

    def forward(self, params, lembedding, rembedding, ly_mask=None,
                context=None, context_mask=None, init_state=None):
        """Run the decoder over a whole (teacher-forced) sequence.

        Args:
            params: hyper-parameter dict; ``'n'`` and ``'D'`` are read.
            lembedding, rembedding: parent-/child-branch inputs, step on dim 0
                and batch on dim 1.
            ly_mask: per-step mask, ``ly_mask[i]`` of shape (batch,); defaults
                to all ones.
            context: encoder features (batch, D, H, W).
            context_mask: (batch, H, W) validity mask, or None.
            init_state: initial hidden state (batch, n).

        Returns:
            Per-step stacks ``(h2ts, h1ts, h01ts, ctCs, ctPs, cts, calphas,
            calpha_pasts, palphas, palpha_pasts)``. All are allocated on
            ``context``'s device.
        """
        n_steps = lembedding.shape[0]  # seqs_y
        n_samples = lembedding.shape[1]  # batch
        height, width = context.shape[2], context.shape[3]
        device = context.device

        def zeros(*shape):
            return torch.zeros(*shape, device=device)

        if ly_mask is None:
            ly_mask = torch.ones(n_steps, n_samples, device=device)

        pctx_ = self.conv_UaC(context).permute(2, 3, 0, 1)  # (H,W,batch,dim_attention)
        repctx_ = self.conv_UaP(context).permute(2, 3, 0, 1)  # (H,W,batch,dim_attention)
        state_below_lz = self.fc_Wyz0(lembedding)
        state_below_lr = self.fc_Wyr0(lembedding)
        state_below_lh = self.fc_Wyh0(lembedding)
        state_below_z = self.fc_Wyz(rembedding)
        state_below_r = self.fc_Wyr(rembedding)
        state_below_h = self.fc_Wyh(rembedding)

        calpha_past = zeros(n_samples, height, width)  # (batch,H,W)
        palpha_past = zeros(n_samples, height, width)
        h2t = init_state
        h2ts = zeros(n_steps, n_samples, params['n'])
        h1ts = zeros(n_steps, n_samples, params['n'])
        h01ts = zeros(n_steps, n_samples, params['n'])
        ctCs = zeros(n_steps, n_samples, params['D'])
        ctPs = zeros(n_steps, n_samples, params['D'])
        cts = zeros(n_steps, n_samples, 2 * params['D'])
        calphas = zeros(n_steps, n_samples, height, width)
        calpha_pasts = zeros(n_steps, n_samples, height, width)
        palphas = zeros(n_steps, n_samples, height, width)
        palpha_pasts = zeros(n_steps, n_samples, height, width)

        for i in range(n_steps):
            h2t, h1t, h01t, ctC, ctP, ct, calpha, calpha_past, palpha, palpha_past = self._step_slice(
                ly_mask[i], context_mask, h2t, palpha_past, calpha_past,
                pctx_, repctx_, context, state_below_lz[i], state_below_lr[i], state_below_lh[i],
                state_below_z[i], state_below_r[i], state_below_h[i])

            h2ts[i] = h2t  # (seqs_y,batch,n)
            h1ts[i] = h1t
            h01ts[i] = h01t
            ctCs[i] = ctC
            ctPs[i] = ctP
            cts[i] = ct  # (seqs_y,batch,2D)
            calphas[i] = calpha  # (seqs_y,batch,H,W)
            calpha_pasts[i] = calpha_past
            palphas[i] = palpha
            palpha_pasts[i] = palpha_past
        return h2ts, h1ts, h01ts, ctCs, ctPs, cts, calphas, calpha_pasts, palphas, palpha_pasts

    def parent_forward(self, params, lembedding, ly_mask=None, context=None, context_mask=None,
                       init_state=None, palpha_past=None):
        """Run only the parent half of one step (used step by step at inference).

        Returns ``(h01, ctP, palpha, palpha_past)`` from ``_step_slice_parent``.
        *ly_mask* defaults to ones of length ``lembedding.shape[0]``.
        """
        state_below_lz = self.fc_Wyz0(lembedding)
        state_below_lr = self.fc_Wyr0(lembedding)
        state_below_lh = self.fc_Wyh0(lembedding)
        repctx_ = self.conv_UaP(context).permute(2, 3, 0, 1)  # (H,W,batch,dim_attention)
        if ly_mask is None:
            ly_mask = torch.ones(lembedding.shape[0], device=lembedding.device)
        h0s, ctPs, palphas, palpha_pasts = self._step_slice_parent(
            ly_mask, context_mask, init_state, palpha_past,
            repctx_, context, state_below_lz, state_below_lr, state_below_lh)
        return h0s, ctPs, palphas, palpha_pasts

    def child_forward(self, params, rembedding, ctxP, ly_mask=None,
                      context=None, context_mask=None, init_state=None, calpha_past=None):
        """Run only the child half of one step, given the parent context *ctxP*.

        Returns ``(h2, h1, ctC, ctP, ct, calpha, calpha_past)`` from
        ``_step_slice_child``. *ly_mask* defaults to ones of length
        ``rembedding.shape[0]``.
        """
        pctx_ = self.conv_UaC(context).permute(2, 3, 0, 1)  # (H,W,batch,dim_attention)

        state_below_z = self.fc_Wyz(rembedding)
        state_below_r = self.fc_Wyr(rembedding)
        state_below_h = self.fc_Wyh(rembedding)

        if ly_mask is None:
            ly_mask = torch.ones(rembedding.shape[0], device=rembedding.device)
        h2ts, h1ts, ctCs, ctPs, cts, calphas, calpha_pasts = self._step_slice_child(
            ly_mask, context_mask, init_state, calpha_past,
            pctx_, context, state_below_z, state_below_r, state_below_h, ctxP)
        return h2ts, h1ts, ctCs, ctPs, cts, calphas, calpha_pasts

    @staticmethod
    def _gru(h_prev, x_z, x_r, x_h, fc_Uz, fc_Ur, fc_Uh, mask):
        """One masked GRU update.

        *x_z/x_r/x_h* are the already-projected inputs; where ``mask`` is 0
        the previous state is carried over unchanged.
        """
        z = torch.sigmoid(fc_Uz(h_prev) + x_z)  # (batch,n)
        r = torch.sigmoid(fc_Ur(h_prev) + x_r)  # (batch,n)
        h_p = torch.tanh(fc_Uh(h_prev) * r + x_h)  # (batch,n)
        h = z * h_prev + (1. - z) * h_p  # (batch,n)
        return mask[:, None] * h + (1. - mask)[:, None] * h_prev

    @staticmethod
    def _coverage_attention(proj_ctx, query, alpha_past_, conv_Q, fc_Uf, fc_va, ctx_mask, cc_):
        """Attention over the encoder map conditioned on accumulated past attention.

        Args:
            proj_ctx: projected context (H, W, batch, dim_attention).
            query: projected decoder state (batch, dim_attention).
            alpha_past_: attention accumulated so far (batch, H, W).
            conv_Q, fc_Uf, fc_va: coverage conv, coverage projection, scorer.
            ctx_mask: (batch, H, W) mask or None.
            cc_: encoder features (batch, D, H, W).

        Returns:
            ``(alpha, alpha_past, ctx)``: normalized weights (batch, H, W),
            updated accumulator (batch, H, W), and context vector (batch, D).
            A fully masked map yields all-zero weights instead of NaN.
        """
        cover = conv_Q(alpha_past_[:, None, :, :]).permute(2, 3, 0, 1)  # (H,W,batch,512)
        score = torch.tanh(proj_ctx + query[None, None, :, :] + fc_Uf(cover))  # (H,W,batch,dim_attention)
        alpha = fc_va(score)  # (H,W,batch,1)
        alpha = alpha - alpha.max()  # shift for a numerically safe exp
        alpha = alpha.view(alpha.shape[0], alpha.shape[1], alpha.shape[2])  # (H,W,batch)
        alpha = torch.exp(alpha)
        if ctx_mask is not None:
            alpha = alpha * ctx_mask.permute(1, 2, 0)
        # epsilon in the denominator: keeps a fully masked map at 0 rather than 0/0
        alpha = alpha / (alpha.sum(1).sum(0)[None, None, :] + 1e-10)  # (H,W,batch)
        alpha = alpha.permute(2, 0, 1)  # (batch,H,W)
        alpha_past = alpha_past_ + alpha
        ctx = (cc_ * alpha[:, None, :, :]).sum(3).sum(2)  # (batch,D)
        return alpha, alpha_past, ctx

    # one step of two GRU layers
    def _step_slice(self, ly_mask, ctx_mask, h_, palpha_past_, calpha_past_,
                    pctx_, repctx_, cc_, state_below_lz, state_below_lr, state_below_lh,
                    state_below_z, state_below_r, state_below_h):
        """Full decoder step: the parent half followed by the child half.

        Returns ``(h2, h1, h01, ctC, ctP, ct, calpha, calpha_past, palpha,
        palpha_past)`` with attention maps shaped (batch, H, W).
        """
        h01, ctP, palpha, palpha_past = self._step_slice_parent(
            ly_mask, ctx_mask, h_, palpha_past_, repctx_, cc_,
            state_below_lz, state_below_lr, state_below_lh)
        h2, h1, ctC, ctP, ct, calpha, calpha_past = self._step_slice_child(
            ly_mask, ctx_mask, h01, calpha_past_, pctx_, cc_,
            state_below_z, state_below_r, state_below_h, ctP)
        return h2, h1, h01, ctC, ctP, ct, calpha, calpha_past, palpha, palpha_past

    def _step_slice_parent(self, ly_mask, ctx_mask, h_, palpha_past_, repctx_, cc_,
                           state_below_lz, state_below_lr, state_below_lh):
        """Parent half of a step. Returns ``(h01, ctP, palpha, palpha_past)``."""
        h0 = self._gru(h_, state_below_lz, state_below_lr, state_below_lh,
                       self.fc_Uhz0, self.fc_Uhr0, self.fc_Uhh0, ly_mask)
        # attention for parent symbol
        palpha, palpha_past, ctP = self._coverage_attention(
            repctx_, self.fc_WaP(h0), palpha_past_,
            self.conv_QP, self.fc_UfP, self.fc_vaP, ctx_mask, cc_)
        h01 = self._gru(h0, self.fc_Wyz1(ctP), self.fc_Wyr1(ctP), self.fc_Wyh1(ctP),
                        self.fc_Uhz1, self.fc_Uhr1, self.fc_Uhh1, ly_mask)
        return h01, ctP, palpha, palpha_past

    def _step_slice_child(self, ly_mask, ctx_mask, h_, calpha_past_,
                          pctx_, cc_, state_below_z, state_below_r, state_below_h, ctP):
        """Child half of a step. Returns ``(h2, h1, ctC, ctP, ct, calpha, calpha_past)``."""
        # the first GRU layer
        h1 = self._gru(h_, state_below_z, state_below_r, state_below_h,
                       self.fc_Uhz, self.fc_Uhr, self.fc_Uhh, ly_mask)
        # attention for child symbol
        calpha, calpha_past, ctC = self._coverage_attention(
            pctx_, self.fc_WaC(h1), calpha_past_,
            self.conv_QC, self.fc_UfC, self.fc_vaC, ctx_mask, cc_)
        ct = torch.cat((ctC, ctP), 1)  # (batch,2D)
        # the second GRU layer
        h2 = self._gru(h1, self.fc_Wcz(ctC), self.fc_Wcr(ctC), self.fc_Wch(ctC),
                       self.fc_Uhz2, self.fc_Uhr2, self.fc_Uhh2, ly_mask)
        return h2, h1, ctC, ctP, ct, calpha, calpha_past


# calculate probabilities
class Gru_prob(nn.Module):
    """Output layer producing ``cprob``, ``pprob`` and ``reprob`` scores.

    ``params`` keys read: ``'D'``, ``'n'``, ``'m'`` (must be even for the
    maxout), ``'K'`` (symbol classes), ``'mre'`` and ``'Kre'`` (relation
    classes).
    """

    def __init__(self, params):
        super(Gru_prob, self).__init__()
        self.fc_WctC = nn.Linear(params['D'], params['m'])
        self.fc_WhtC = nn.Linear(params['n'], params['m'])
        self.fc_WytC = nn.Linear(params['m'], params['m'])
        self.dropout = nn.Dropout(p=0.2)
        self.fc_W0C = nn.Linear(int(params['m'] / 2), params['K'])
        self.fc_W0P = nn.Linear(int(params['m'] / 2), params['K'])
        self.fc_WctRe = nn.Linear(2 * params['D'], params['mre'])
        self.fc_W0Re = nn.Linear(int(params['mre']), params['Kre'])

    @staticmethod
    def _maxout(logit):
        """Max over adjacent pairs of the last dim: (seqs_y,batch,m) -> (seqs_y,batch,m/2)."""
        shape = logit.shape
        return logit.view(shape[0], shape[1], int(shape[2] / 2), 2).max(3)[0]

    def forward(self, ctCs, ctPs, cts, htCs, prevC, use_dropout):
        """Compute the three score tensors.

        Args:
            ctCs, ctPs: child / parent context vectors (seqs_y, batch, D).
            cts: concatenated contexts (seqs_y, batch, 2D).
            htCs: decoder states (seqs_y, batch, n).
            prevC: input embeddings (seqs_y, batch, m).
            use_dropout: apply dropout before each output projection.

        Returns:
            ``(cprob, pprob, reprob)`` with last dims ``K``, ``K`` and ``Kre``.
        """
        clogit = self.fc_WctC(ctCs) + self.fc_WhtC(htCs) + self.fc_WytC(prevC)  # (seqs_y,batch,m)
        clogit = self._maxout(clogit)  # (seqs_y,batch,m/2)
        if use_dropout:
            clogit = self.dropout(clogit)
        cprob = self.fc_W0C(clogit)  # (seqs_y,batch,K)

        # NOTE(review): the parent branch reuses fc_WctC (a dedicated fc_WctP was
        # commented out in earlier code) -- kept as-is for checkpoint compatibility.
        plogit = self._maxout(self.fc_WctC(ctPs))  # (seqs_y,batch,m/2)
        if use_dropout:
            plogit = self.dropout(plogit)
        pprob = self.fc_W0P(plogit)  # (seqs_y,batch,K)

        relogit = self.fc_WctRe(cts)
        if use_dropout:
            relogit = self.dropout(relogit)
        reprob = self.fc_W0Re(relogit)  # (seqs_y,batch,Kre)

        return cprob, pprob, reprob
-------------------------------------------------------------------------------- /codes/encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | # DenseNet-B 8 | class Bottleneck(nn.Module): 9 | def __init__(self, nChannels, growthRate, use_dropout): 10 | super(Bottleneck, self).__init__() 11 | interChannels = 4 * growthRate 12 | self.bn1 = nn.BatchNorm2d(interChannels) 13 | self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(growthRate) 15 | self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, padding=1, bias=False) 16 | self.use_dropout = use_dropout 17 | self.dropout = nn.Dropout(p=0.2) 18 | 19 | def forward(self, x): 20 | out = F.relu(self.bn1(self.conv1(x)), inplace=True) 21 | if self.use_dropout: 22 | out = self.dropout(out) 23 | out = F.relu(self.bn2(self.conv2(out)), inplace=True) 24 | if self.use_dropout: 25 | out = self.dropout(out) 26 | out = torch.cat((x, out), 1) 27 | return out 28 | 29 | 30 | # single layer 31 | class SingleLayer(nn.Module): 32 | def __init__(self, nChannels, growthRate, use_dropout): 33 | super(SingleLayer, self).__init__() 34 | self.bn1 = nn.BatchNorm2d(nChannels) 35 | self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, padding=1, bias=False) 36 | self.use_dropout = use_dropout 37 | self.dropout = nn.Dropout(p=0.2) 38 | 39 | def forward(self, x): 40 | out = self.conv1(F.relu(x, inplace=True)) 41 | if self.use_dropout: 42 | out = self.dropout(out) 43 | out = torch.cat((x, out), 1) 44 | return out 45 | 46 | 47 | # transition layer 48 | class Transition(nn.Module): 49 | def __init__(self, nChannels, nOutChannels, use_dropout): 50 | super(Transition, self).__init__() 51 | self.bn1 = nn.BatchNorm2d(nOutChannels) 52 | self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=1, bias=False) 53 | self.use_dropout = 
use_dropout 54 | self.dropout = nn.Dropout(p=0.2) 55 | 56 | def forward(self, x): 57 | out = F.relu(self.bn1(self.conv1(x)), inplace=True) 58 | if self.use_dropout: 59 | out = self.dropout(out) 60 | out = F.avg_pool2d(out, 2, ceil_mode=True) 61 | return out 62 | 63 | 64 | class DenseNet(nn.Module): 65 | def __init__(self, growthRate, reduction, bottleneck, use_dropout): 66 | super(DenseNet, self).__init__() 67 | nDenseBlocks = 16 68 | nChannels = 2 * growthRate 69 | self.conv1 = nn.Conv2d(1, nChannels, kernel_size=7, padding=3, stride=2, bias=False) 70 | self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout) 71 | nChannels += nDenseBlocks * growthRate 72 | nOutChannels = int(math.floor(nChannels * reduction)) 73 | self.trans1 = Transition(nChannels, nOutChannels, use_dropout) 74 | 75 | nChannels = nOutChannels 76 | self.dense2 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout) 77 | nChannels += nDenseBlocks * growthRate 78 | nOutChannels = int(math.floor(nChannels * reduction)) 79 | self.trans2 = Transition(nChannels, nOutChannels, use_dropout) 80 | 81 | nChannels = nOutChannels 82 | self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout) 83 | 84 | def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck, use_dropout): 85 | layers = [] 86 | for i in range(int(nDenseBlocks)): 87 | if bottleneck: 88 | layers.append(Bottleneck(nChannels, growthRate, use_dropout)) 89 | else: 90 | layers.append(SingleLayer(nChannels, growthRate, use_dropout)) 91 | nChannels += growthRate 92 | return nn.Sequential(*layers) 93 | 94 | def forward(self, x, x_mask): 95 | out = self.conv1(x) 96 | out_mask = x_mask[:, 0::2, 0::2] 97 | out = F.relu(out, inplace=True) 98 | out = F.max_pool2d(out, 2, ceil_mode=True) 99 | out_mask = out_mask[:, 0::2, 0::2] 100 | out = self.dense1(out) 101 | out = self.trans1(out) 102 | out_mask = out_mask[:, 0::2, 0::2] 103 | out = self.dense2(out) 
104 | out = self.trans2(out) 105 | out_mask = out_mask[:, 0::2, 0::2] 106 | out = self.dense3(out) 107 | return out, out_mask 108 | -------------------------------------------------------------------------------- /codes/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from encoder import DenseNet 5 | from decoder import Gru_cond_layer, Gru_prob 6 | import math 7 | 8 | # create gru init state 9 | class FcLayer(nn.Module): 10 | def __init__(self, nin, nout): 11 | super(FcLayer, self).__init__() 12 | self.fc = nn.Linear(nin, nout) 13 | 14 | def forward(self, x): 15 | out = torch.tanh(self.fc(x)) 16 | return out 17 | 18 | 19 | # Embedding 20 | class My_Embedding(nn.Module): 21 | def __init__(self, params): 22 | super(My_Embedding, self).__init__() 23 | self.embedding = nn.Embedding(params['K'], params['m']) 24 | self.pos_embedding = torch.zeros(params['maxlen'], params['m']).cuda() 25 | nin = params['maxlen'] 26 | nout = params['m'] 27 | d_model = nout 28 | for pos in range(nin): 29 | for i in range(nout//2): 30 | self.pos_embedding[pos, 2*i] = math.sin(1.*pos/(10000**(2.*i/d_model))) 31 | self.pos_embedding[pos, 2*i+1] = math.cos(1.*pos/(10000**(2.*i/d_model))) 32 | def forward(self, params, ly, lp, ry): 33 | if ly.sum() < 0.: # 34 | lemb = torch.zeros(1, params['m']).cuda() # (1,m) 35 | else: 36 | lemb = self.embedding(ly) # (seqs_y,batch,m) | (batch,m) 37 | if len(lemb.shape) == 3: # only for training stage 38 | lemb_shifted = torch.zeros([lemb.shape[0], lemb.shape[1], params['m']], dtype=torch.float32).cuda() 39 | lemb_shifted[1:] = lemb[:-1] 40 | lemb = lemb_shifted 41 | 42 | if lp.sum() < 1.: # pos=0 43 | Pemb = torch.zeros(1, params['m']).cuda() # (1,m) 44 | else: 45 | Pemb = self.pos_embedding[lp] # (seqs_y,batch,m) | (batch,m) 46 | if len(Pemb.shape) == 3: # only for training stage 47 | Pemb_shifted = 
torch.zeros([Pemb.shape[0], Pemb.shape[1], params['m']], dtype=torch.float32).cuda() 48 | Pemb_shifted[1:] = Pemb[:-1] 49 | Pemb = Pemb_shifted 50 | 51 | if ry.sum() < 0.: # 52 | remb = torch.zeros(1, params['m']).cuda() # (1,m) 53 | else: 54 | remb = self.embedding(ry) # (seqs_y,batch,m) | (batch,m) 55 | if len(remb.shape) == 3: # only for training stage 56 | remb_shifted = torch.zeros([remb.shape[0], remb.shape[1], params['m']], dtype=torch.float32).cuda() 57 | remb_shifted[1:] = remb[1:] 58 | remb = remb_shifted 59 | return lemb, Pemb, remb 60 | def word_emb(self, params, y): 61 | if y.sum() < 0.: # 62 | emb = torch.zeros(1, params['m']).cuda() # (1,m) 63 | else: 64 | emb = self.embedding(y) # (seqs_y,batch,m) | (batch,m) 65 | return emb 66 | def pos_emb(self, params, p): 67 | if p.sum() < 1.: # 68 | Pemb = torch.zeros(1, params['m']).cuda() # (1,m) 69 | else: 70 | Pemb = self.pos_embedding[p] # (seqs_y,batch,m) | (batch,m) 71 | return Pemb 72 | 73 | class Encoder_Decoder(nn.Module): 74 | def __init__(self, params): 75 | super(Encoder_Decoder, self).__init__() 76 | self.encoder = DenseNet(growthRate=params['growthRate'], reduction=params['reduction'], 77 | bottleneck=params['bottleneck'], use_dropout=params['use_dropout']) 78 | self.init_GRU_model = FcLayer(params['D'], params['n']) 79 | self.emb_model = My_Embedding(params) 80 | self.gru_model = Gru_cond_layer(params) 81 | self.gru_prob_model = Gru_prob(params) 82 | self.fc_Uamem = nn.Linear(params['n'], params['dim_attention']) 83 | self.fc_Wamem = nn.Linear(params['n'], params['dim_attention'], bias=False) 84 | # self.conv_Qmem = nn.Conv2d(1, 512, kernel_size=(3,1), bias=False, padding=(1,0)) 85 | # self.fc_Ufmem = nn.Linear(512, params['dim_attention']) 86 | self.fc_vamem = nn.Linear(params['dim_attention'], 1) 87 | self.criterion = torch.nn.CrossEntropyLoss(reduce=False) 88 | 89 | def forward(self, params, x, x_mask, ly, ly_mask, ry, ry_mask, re, re_mask, ma, ma_mask, lp, rp, one_step=False): 90 | # recover 
permute 91 | # ly = ly.permute(1, 0) 92 | # ly_mask = ly_mask.permute(1, 0) 93 | # ly = ly.permute(1, 0) 94 | # ly_mask = ly_mask.permute(1, 0) 95 | 96 | ma = ma.permute(2, 1, 0) # SeqY * Matt * batch 97 | ma_mask = ma_mask.permute(2, 1, 0) 98 | 99 | ctx, ctx_mask = self.encoder(x, x_mask) 100 | 101 | # init state 102 | ctx_mean = (ctx * ctx_mask[:, None, :, :]).sum(3).sum(2) / ctx_mask.sum(2).sum(1)[:, None] # (batch,D) 103 | init_state = self.init_GRU_model(ctx_mean) # (batch,n) 104 | 105 | # two GRU layers 106 | lemb, Pemb, remb = self.emb_model(params, ly, lp, ry) # (seqs_y,batch,m) 107 | # h2ts: (seqs_y,batch,n), cts: (seqs_y,batch,D), alphas: (seqs_y,batch,H,W) 108 | h2ts, h1ts, h01ts, ctCs, ctPs, cts, calphas, calpha_pasts, palphas, palpha_pasts= \ 109 | self.gru_model(params, lemb, remb, ly_mask, ctx, ctx_mask, init_state=init_state) 110 | 111 | word_pos_memory = torch.cat((init_state[None, :, :], h2ts[:-1]), 0) 112 | # word_pos_memory = lemb + Pemb 113 | mempctx_ = self.fc_Uamem(word_pos_memory) 114 | memquery = self.fc_Wamem(h01ts) 115 | memattention_score = torch.tanh(mempctx_[None, :, :, :] + memquery[:, None, :, :]) 116 | memalpha = self.fc_vamem(memattention_score) 117 | memalpha = memalpha - memalpha.max() 118 | memalpha = memalpha.view(memalpha.shape[0], memalpha.shape[1], memalpha.shape[2]) # SeqY * Matt * batch 119 | memalpha = torch.exp(memalpha) 120 | memalpha = memalpha * ma_mask # SeqY * Matt * batch 121 | memalpha = memalpha / (memalpha.sum(1)[:, None, :] + 1e-10) 122 | memalphas = memalpha + 1e-10 123 | cost_memalphas = - ma.float() * torch.log(memalphas) * ma_mask 124 | loss_memalphas = cost_memalphas.sum((0,1)) 125 | 126 | # compute KL alpha 127 | calpha_sort_ = torch.cat((torch.zeros(1, calphas.shape[1], calphas.shape[2], calphas.shape[3]).cuda(), calphas), 0) 128 | n_gaps = calpha_sort_.shape[0] 129 | n_batch = calpha_sort_.shape[1] 130 | n_H = calpha_sort_.shape[2] 131 | n_W = calpha_sort_.shape[3] 132 | rp = rp.permute(1,0) # batch * 
SeqY 133 | rp_shape = rp.shape 134 | rp = rp + n_gaps * torch.arange(n_batch)[:, None].cuda() 135 | 136 | calpha_sort = calpha_sort_.permute(1,0,2,3) 137 | calpha_sort = torch.reshape(calpha_sort, (calpha_sort.shape[0]*calpha_sort.shape[1],calpha_sort.shape[2],calpha_sort.shape[3])) 138 | calpha_sort = calpha_sort[rp.flatten()] 139 | calpha_sort = torch.reshape(calpha_sort, (rp_shape[0], rp_shape[1], n_H, n_W)) 140 | calpha_sort = calpha_sort.permute(1,0,2,3) 141 | 142 | calpha_sort = calpha_sort + 1e-10 143 | palphas = palphas + 1e-10 144 | cost_KL_alpha = calpha_sort * (torch.log(calpha_sort)-torch.log(palphas)) * ctx_mask[None, :, :, :] 145 | loss_KL = cost_KL_alpha.sum((0,2,3)) 146 | 147 | cscores, pscores, rescores = self.gru_prob_model(ctCs, ctPs, cts, h2ts, remb, use_dropout=params['use_dropout']) # (seqs_y x batch,K) 148 | 149 | cscores = cscores.contiguous() 150 | cscores = cscores.view(-1, cscores.shape[2]) 151 | ly = ly.contiguous() 152 | lpred_loss = self.criterion(cscores, ly.view(-1)) # (seqs_y x batch,) 153 | lpred_loss = lpred_loss.view(ly.shape[0], ly.shape[1]) # (seqs_y,batch) 154 | lpred_loss = (lpred_loss * ly_mask).sum(0) / (ly_mask.sum(0)+1e-10) 155 | lpred_loss = lpred_loss.mean() 156 | 157 | pscores = pscores.contiguous() 158 | pscores = pscores.view(-1, pscores.shape[2]) 159 | ry = ry.contiguous() 160 | rpred_loss = self.criterion(pscores, ry.view(-1)) 161 | rpred_loss = rpred_loss.view(ry.shape[0], ry.shape[1]) 162 | rpred_loss = (rpred_loss * ry_mask).sum(0) / (ry_mask.sum(0)+1e-10) 163 | rpred_loss = rpred_loss.mean() 164 | 165 | rescores = rescores.contiguous() 166 | rescores = rescores.view(-1, rescores.shape[2]) 167 | re = re.contiguous() 168 | repred_loss = self.criterion(rescores, re.view(-1)) 169 | repred_loss = repred_loss.view(re.shape[0], re.shape[1]) 170 | repred_loss = (repred_loss * re_mask).sum(0) / (re_mask.sum(0)+1e-10) 171 | repred_loss = repred_loss.mean() 172 | 173 | mem_loss = loss_memalphas / (ly_mask.sum(0)+1e-10) 
174 | mem_loss = mem_loss.mean() 175 | 176 | KL_loss = loss_KL / (ly_mask.sum(0)+1e-10) 177 | KL_loss = KL_loss.mean() 178 | 179 | loss = params['ly_lambda']*lpred_loss + params['ry_lambda']*rpred_loss + \ 180 | params['re_lambda']*repred_loss + params['rpos_lambda']*mem_loss + params['KL_lambda']*KL_loss 181 | 182 | return loss, lpred_loss, rpred_loss, repred_loss, mem_loss, KL_loss 183 | 184 | # decoding: encoder part 185 | def f_init(self, x, x_mask=None): 186 | if x_mask is None: # x_mask is actually no use here 187 | shape = x.shape 188 | x_mask = torch.ones(shape).cuda() 189 | ctx, _ctx_mask = self.encoder(x, x_mask) 190 | ctx_mean = ctx.mean(dim=3).mean(dim=2) 191 | init_state = self.init_GRU_model(ctx_mean) # (1,n) 192 | return init_state, ctx 193 | 194 | def f_next_parent(self, params, ly, lp, ctx, init_state, h1t, palpha_past, nextemb_memory, nextePmb_memory, initIdx): 195 | emb = self.emb_model.word_emb(params, ly) 196 | # Pemb = self.emb_model.pos_emb(params, lp) 197 | nextemb_memory[initIdx, :, :] = emb 198 | # ePmb_memory_ = emb + Pemb 199 | nextePmb_memory[initIdx, :, :] = init_state 200 | 201 | h01, ctP, palpha, next_palpha_past = self.gru_model.parent_forward(params, emb, context=ctx, init_state=init_state, palpha_past=palpha_past) 202 | 203 | mempctx_ = self.fc_Uamem(nextePmb_memory) 204 | memquery = self.fc_Wamem(h01) 205 | memattention_score = torch.tanh(mempctx_ + memquery[None, :, :]) 206 | memalpha = self.fc_vamem(memattention_score) 207 | memalpha = memalpha - memalpha.max() 208 | memalpha = memalpha.view(memalpha.shape[0], memalpha.shape[1]) # Matt * batch 209 | memalpha = torch.exp(memalpha) 210 | mem_mask = torch.zeros(nextePmb_memory.shape[0], nextePmb_memory.shape[1]).cuda() 211 | mem_mask[:(initIdx+1), :] = 1 212 | memalpha = memalpha * mem_mask # Matt * batch 213 | memalpha = memalpha / (memalpha.sum(0) + 1e-10) 214 | 215 | Pmemalpha = memalpha.view(-1, memalpha.shape[1]) 216 | Pmemalpha = Pmemalpha.permute(1, 0) # batch * Matt 217 | 
return h01, Pmemalpha, ctP, palpha, next_palpha_past, nextemb_memory, nextePmb_memory 218 | 219 | # decoding: decoder part 220 | def f_next_child(self, params, remb, ctP, ctx, init_state, calpha_past): 221 | 222 | next_state, h1t, ctC, ctP, ct, calpha, next_calpha_past = \ 223 | self.gru_model.child_forward(params, remb, ctP, context=ctx, init_state=init_state, calpha_past=calpha_past) 224 | 225 | # reshape to suit GRU step code 226 | h2te = next_state.view(1, next_state.shape[0], next_state.shape[1]) 227 | ctC = ctC.view(1, ctC.shape[0], ctC.shape[1]) 228 | ctP = ctP.view(1, ctP.shape[0], ctP.shape[1]) 229 | ct = ct.view(1, ct.shape[0], ct.shape[1]) 230 | 231 | # calculate probabilities 232 | cscores, pscores, rescores = self.gru_prob_model(ctC, ctP, ct, h2te, remb, use_dropout=params['use_dropout']) 233 | cscores = cscores.view(-1, cscores.shape[2]) 234 | next_lprobs = F.softmax(cscores, dim=1) 235 | rescores = rescores.view(-1, rescores.shape[2]) 236 | next_reprobs = F.softmax(rescores, dim=1) 237 | next_re = torch.argmax(next_reprobs, dim=1) 238 | 239 | return next_lprobs, next_reprobs, next_state, h1t, calpha, next_calpha_past, next_re 240 | -------------------------------------------------------------------------------- /codes/gtd2latex.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | import pickle as pkl 6 | import numpy 7 | 8 | 9 | def convert(nodeid, gtd_list): 10 | isparent = False 11 | child_list = [] 12 | for i in range(len(gtd_list)): 13 | if gtd_list[i][2] == nodeid: 14 | isparent = True 15 | child_list.append([gtd_list[i][0],gtd_list[i][1],gtd_list[i][3]]) 16 | if not isparent: 17 | return [gtd_list[nodeid][0]] 18 | else: 19 | if gtd_list[nodeid][0] == '\\frac': 20 | return_string = [gtd_list[nodeid][0]] 21 | for i in range(len(child_list)): 22 | if child_list[i][2] == 'Above': 23 | return_string += ['{'] + convert(child_list[i][1], gtd_list) + ['}'] 24 
| for i in range(len(child_list)): 25 | if child_list[i][2] == 'Below': 26 | return_string += ['{'] + convert(child_list[i][1], gtd_list) + ['}'] 27 | for i in range(len(child_list)): 28 | if child_list[i][2] == 'Right': 29 | return_string += convert(child_list[i][1], gtd_list) 30 | for i in range(len(child_list)): 31 | if child_list[i][2] not in ['Right','Above','Below']: 32 | return_string += ['illegal'] 33 | else: 34 | return_string = [gtd_list[nodeid][0]] 35 | for i in range(len(child_list)): 36 | if child_list[i][2] == 'Inside': 37 | return_string += ['{'] + convert(child_list[i][1], gtd_list) + ['}'] 38 | for i in range(len(child_list)): 39 | if child_list[i][2] in ['Sub','Below']: 40 | return_string += ['_','{'] + convert(child_list[i][1], gtd_list) + ['}'] 41 | for i in range(len(child_list)): 42 | if child_list[i][2] in ['Sup','Above']: 43 | return_string += ['^','{'] + convert(child_list[i][1], gtd_list) + ['}'] 44 | for i in range(len(child_list)): 45 | if child_list[i][2] in ['Right']: 46 | return_string += convert(child_list[i][1], gtd_list) 47 | return return_string 48 | 49 | latex_root_path = '***' 50 | gtd_root_path = '***' 51 | gtd_paths = ['test_caption_14','test_caption_16','test_caption_19'] 52 | 53 | for gtd_path in gtd_paths: 54 | gtd_files = os.listdir(gtd_root_path + gtd_path + '/') 55 | f_out = open(latex_root_path + gtd_path + '.txt', 'w') 56 | for process_num, gtd_file in enumerate(gtd_files): 57 | # gtd_file = '510_em_101.gtd' 58 | key = gtd_file[:-4] # remove .gtd 59 | f_out.write(key + '\t') 60 | gtd_list = [] 61 | gtd_list.append(['',0,-1,'root']) 62 | with open(gtd_root_path + gtd_path + '/' + gtd_file) as f: 63 | lines = f.readlines() 64 | for line in lines[:-1]: 65 | parts = line.split() 66 | sym = parts[0] 67 | childid = int(parts[1]) 68 | parentid = int(parts[3]) 69 | relation = parts[4] 70 | gtd_list.append([sym,childid,parentid,relation]) 71 | latex_list = convert(1, gtd_list) 72 | if 'illegal' in latex_list: 73 | print (key + 
' has error') 74 | latex_string = ' ' 75 | else: 76 | latex_string = ' '.join(latex_list) 77 | f_out.write(latex_string + '\n') 78 | # sys.exit() 79 | 80 | if (process_num+1) // 2000 == (process_num+1) * 1.0 / 2000: 81 | print ('process files', process_num) 82 | -------------------------------------------------------------------------------- /codes/latex2gtd.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | import pickle as pkl 6 | import numpy 7 | 8 | latex_root_path = '' 9 | gtd_root_path = '' 10 | latex_files = ['test_caption_16.txt','test_caption_19.txt','valid_data_v1.txt','train_data_v1.txt','test_caption_14.txt'] 11 | for latexF in latex_files: 12 | latex_file = latex_root_path + latexF 13 | gtd_path = gtd_root_path + latexF[:-4] + '/' 14 | if not os.path.exists(gtd_path): 15 | os.mkdir(gtd_path) 16 | 17 | with open(latex_file) as f: 18 | lines = f.readlines() 19 | for process_num, line in enumerate(lines): 20 | if '\sqrt [' in line: 21 | continue 22 | parts = line.split() 23 | if len(parts) < 2: 24 | print ('error: invalid latex caption ...', line) 25 | continue 26 | key = parts[0] 27 | f_out = open(gtd_path + key + '.gtd', 'w') 28 | raw_cap = parts[1:] 29 | cap = [] 30 | for w in raw_cap: 31 | if w not in ['\limits']: 32 | cap.append(w) 33 | gtd_stack = [] 34 | idx = 0 35 | outidx = 1 36 | error_flag = False 37 | while idx < len(cap): 38 | if idx == 0: 39 | if cap[0] in ['{','}']: 40 | print ('error: {} should NOT appears at START') 41 | print (line.strip()) 42 | sys.exit() 43 | string = cap[0] + '\t' + str(outidx) + '\t\t0\tStart' 44 | f_out.write(string + '\n') 45 | idx += 1 46 | outidx += 1 47 | else: 48 | if cap[idx] == '{': 49 | if cap[idx-1] == '{': 50 | print ('error: double { appears') 51 | print (line.strip()) 52 | sys.exit() 53 | elif cap[idx-1] == '}': 54 | if gtd_stack[-1][0] != '\\frac': 55 | print ('error: } { not follows frac ...', key) 56 | 
f_out.close() 57 | os.system('rm ' + gtd_path + key + '.gtd') 58 | error_flag = True 59 | break 60 | else: 61 | gtd_stack[-1][2] = 'Below' 62 | idx += 1 63 | else: 64 | if cap[idx-1] == '\\frac': 65 | gtd_stack.append([cap[idx-1],str(outidx-1),'Above']) 66 | idx += 1 67 | elif cap[idx-1] == '\\sqrt': 68 | gtd_stack.append([cap[idx-1],str(outidx-1),'Inside']) 69 | idx += 1 70 | elif cap[idx-1] == '_': 71 | if cap[idx-2] in ['_','^','\\frac','\\sqrt']: 72 | print ('error: ^ _ follows wrong math symbols') 73 | print (line.strip()) 74 | sys.exit() 75 | elif cap[idx-2] in ['\\sum','\\int','\\lim']: 76 | gtd_stack.append([cap[idx-2],str(outidx-1),'Below']) 77 | idx += 1 78 | elif cap[idx-2] == '}': 79 | if gtd_stack[-1][0] in ['\\sum','\\int','\\lim']: 80 | gtd_stack[-1][2] = 'Below' 81 | else: 82 | gtd_stack[-1][2] = 'Sub' 83 | idx += 1 84 | else: 85 | gtd_stack.append([cap[idx-2],str(outidx-1),'Sub']) 86 | idx += 1 87 | elif cap[idx-1] == '^': 88 | if cap[idx-2] in ['_','^','\\frac','\\sqrt']: 89 | print ('error: ^ _ follows wrong math symbols') 90 | print (line.strip()) 91 | sys.exit() 92 | elif cap[idx-2] in ['\\sum','\\int','\\lim']: 93 | gtd_stack.append([cap[idx-2],str(outidx-1),'Above']) 94 | idx += 1 95 | elif cap[idx-2] == '}': 96 | if gtd_stack[-1][0] in ['\\sum','\\int','\\lim']: 97 | gtd_stack[-1][2] = 'Above' 98 | else: 99 | gtd_stack[-1][2] = 'Sup' 100 | idx += 1 101 | else: 102 | gtd_stack.append([cap[idx-2],str(outidx-1),'Sup']) 103 | idx += 1 104 | else: 105 | print ('error: { follows unknown math symbols ...', key) 106 | f_out.close() 107 | os.system('rm ' + gtd_path + key + '.gtd') 108 | error_flag = True 109 | break 110 | elif cap[idx] == '}': 111 | if cap[idx-1] == '}': 112 | del(gtd_stack[-1]) 113 | idx += 1 114 | elif cap[idx] in ['_','^']: 115 | if idx == len(cap)-1: 116 | print ('error: ^ _ appers at end ...', key) 117 | f_out.close() 118 | os.system('rm ' + gtd_path + key + '.gtd') 119 | error_flag = True 120 | break 121 | if cap[idx+1] != '{': 
122 | print ('error: ^ _ not follows { ...', key) 123 | f_out.close() 124 | os.system('rm ' + gtd_path + key + '.gtd') 125 | error_flag = True 126 | break 127 | else: 128 | idx += 1 129 | elif cap[idx] in ['\limits']: 130 | print ('error: \limits happens') 131 | print (line.strip()) 132 | sys.exit() 133 | else: 134 | if cap[idx-1] == '{': 135 | string = cap[idx] + '\t' + str(outidx) + '\t' + gtd_stack[-1][0] + '\t' + gtd_stack[-1][1] + '\t' + gtd_stack[-1][2] 136 | f_out.write(string + '\n') 137 | outidx += 1 138 | idx += 1 139 | elif cap[idx-1] == '}': 140 | string = cap[idx] + '\t' + str(outidx) + '\t' + gtd_stack[-1][0] + '\t' + gtd_stack[-1][1] + '\tRight' 141 | f_out.write(string + '\n') 142 | outidx += 1 143 | idx += 1 144 | del(gtd_stack[-1]) 145 | else: 146 | parts = string.split('\t') 147 | string = cap[idx] + '\t' + str(outidx) + '\t' + parts[0] + '\t' + parts[1] + '\tRight' 148 | f_out.write(string + '\n') 149 | outidx += 1 150 | idx += 1 151 | if not error_flag: 152 | parts = string.split('\t') 153 | string = '\t' + str(outidx) + '\t' + parts[0] + '\t' + parts[1] + '\tEnd' 154 | f_out.write(string + '\n') 155 | f_out.close() 156 | 157 | if (process_num+1) // 1000 == (process_num+1) * 1.0 / 1000: 158 | print ('process files', process_num) 159 | -------------------------------------------------------------------------------- /codes/prepare_label.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import os 4 | import sys 5 | import pickle as pkl 6 | import numpy 7 | from scipy.misc import imread, imresize, imsave 8 | 9 | 10 | def gen_gtd_label(): 11 | 12 | bfs1_path = '' 13 | gtd_root_path = '' 14 | gtd_paths = [''] 15 | 16 | for gtd_path in gtd_paths: 17 | outpkl_label_file = bfs1_path + gtd_path + '_label_gtd.pkl' 18 | out_label_fp = open(outpkl_label_file, 'wb') 19 | label_lines = {} 20 | process_num = 0 21 | 22 | file_list = os.listdir(gtd_root_path + gtd_path) 23 | for file_name in 
file_list: 24 | key = file_name[:-4] # remove suffix .gtd 25 | if key in ['fa66375ede8be1c192a1acc2bc62b575.jpg']: 26 | continue 27 | with open(gtd_root_path + gtd_path + '/' + file_name) as f: 28 | lines = f.readlines() 29 | label_strs = [] 30 | for line in lines: 31 | parts = line.strip().split('\t') 32 | if len(parts) == 5: 33 | sym = parts[0] 34 | align = parts[1] 35 | related_sym = parts[2] 36 | realign = parts[3] 37 | relation = parts[4] 38 | string = sym + '\t' + align + '\t' + related_sym + '\t' + realign + '\t' + relation 39 | label_strs.append(string) 40 | else: 41 | print ('illegal line', key) 42 | sys.exit() 43 | label_lines[key] = label_strs 44 | 45 | process_num = process_num + 1 46 | if process_num // 2000 == process_num * 1.0 / 2000: 47 | print ('process files', process_num) 48 | 49 | print ('process files number ', process_num) 50 | 51 | pkl.dump(label_lines, out_label_fp) 52 | print ('save file done') 53 | out_label_fp.close() 54 | 55 | 56 | def gen_gtd_align(): 57 | 58 | bfs1_path = '' 59 | gtd_root_path = '' 60 | gtd_paths = [''] 61 | 62 | for gtd_path in gtd_paths: 63 | outpkl_label_file = bfs1_path + gtd_path + '_label_align_gtd.pkl' 64 | out_label_fp = open(outpkl_label_file, 'wb') 65 | label_aligns = {} 66 | process_num = 0 67 | 68 | file_list = os.listdir(gtd_root_path + gtd_path) 69 | for file_name in file_list: 70 | key = file_name[:-4] # remove suffix .gtd 71 | if key in ['fa66375ede8be1c192a1acc2bc62b575.jpg']: 72 | continue 73 | with open(gtd_root_path + gtd_path + '/' + file_name) as f: 74 | lines = f.readlines() 75 | wordNum = len(lines) 76 | align = numpy.zeros([wordNum, wordNum], dtype='int8') 77 | wordindex = -1 78 | 79 | for line in lines: 80 | wordindex += 1 81 | parts = line.strip().split('\t') 82 | if len(parts) == 5: 83 | realign = parts[3] 84 | realign_index = int(realign) 85 | align[realign_index,wordindex] = 1 86 | else: 87 | print ('illegal line', key) 88 | sys.exit() 89 | label_aligns[key] = align 90 | 91 | process_num = 
process_num + 1 92 | if process_num // 2000 == process_num * 1.0 / 2000: 93 | print ('process files', process_num) 94 | 95 | print ('process files number ', process_num) 96 | 97 | pkl.dump(label_aligns, out_label_fp) 98 | print ('save file done') 99 | out_label_fp.close() 100 | 101 | if __name__ == '__main__': 102 | gen_gtd_label() 103 | gen_gtd_align() -------------------------------------------------------------------------------- /codes/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # use CUDA 4 | export CUDA_VISIBLE_DEVICES=0 5 | # source ~/.bashrc 6 | python -u train_wap.py 7 | -------------------------------------------------------------------------------- /codes/train_wap.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import re 4 | import numpy as np 5 | import random 6 | import torch 7 | from torch import optim, nn 8 | from utils import BatchBucket, load_dict, prepare_data, gen_sample, weight_init, compute_wer, compute_sacc 9 | from encoder_decoder import Encoder_Decoder 10 | from data_iterator import dataIterator 11 | 12 | # whether use multi-GPUs 13 | multi_gpu_flag = False 14 | # whether init params 15 | init_param_flag = True 16 | # whether reload params 17 | reload_flag = False 18 | 19 | # load configurations 20 | # root_paths 21 | bfs2_path = '' 22 | work_path = '' 23 | 24 | model_idx = 7 25 | # paths 26 | dictionaries = [bfs2_path + '106_dictionary.txt', bfs2_path + '7relation_dictionary.txt'] 27 | datasets = [bfs2_path + 'jiaming-train-py3.pkl', bfs2_path + 'train_data_v1_label_gtd.pkl', bfs2_path + 'train_data_v1_label_align_gtd.pkl'] 28 | valid_datasets = [bfs2_path + 'jiaming-valid-py3.pkl', bfs2_path + 'valid_data_v1_label_gtd.pkl', bfs2_path + 'valid_data_v1_label_align_gtd.pkl'] 29 | valid_output = [work_path+'results'+str(model_idx)+'/symbol_relation/', work_path+'results'+str(model_idx)+'/memory_alpha/'] 
# Paths of the validation summaries and of the best / latest checkpoints.
valid_result = [work_path + 'results' + str(model_idx) + '/valid.cer',
                work_path + 'results' + str(model_idx) + '/valid.exprate']
saveto = work_path + 'models' + str(model_idx) + '/WAP_params.pkl'
last_saveto = work_path + 'models' + str(model_idx) + '/WAP_params_last.pkl'

# Neither torch.save() nor the os.mkdir() calls in the validation loop create
# missing parent directories, so create 'models<idx>/' and 'results<idx>/'
# up front instead of failing only after the first full training epoch.
os.makedirs(os.path.dirname(saveto) or '.', exist_ok=True)
os.makedirs(os.path.dirname(valid_result[0]) or '.', exist_ok=True)

# training settings
# The per-batch image budgets are the same with one or several GPUs; only the
# number of samples per batch changes.
batch_Imagesize = 500000
valid_batch_Imagesize = 500000
if multi_gpu_flag:
    batch_size = 24
    valid_batch_size = 24
else:
    batch_size = 8
    valid_batch_size = 8
maxImagesize = 500000  # forwarded to dataIterator; presumably a max image area -- TODO confirm
maxlen = 200           # forwarded to dataIterator and used as the decoding length limit
max_epochs = 5000
lrate = 1.0            # Adadelta learning rate (divided by 10 on each decay)
my_eps = 1e-6          # Adadelta epsilon
decay_c = 1e-4         # Adadelta weight decay
clip_c = 100.          # max gradient norm; clipping is skipped when <= 0

# early stop
estop = False
halfLrFlag = 0           # lr decays done so far; once 2, running out of patience stops training
bad_counter = 0          # validations since the last improvement
patience = 15            # bad validations tolerated before decaying lr / stopping
validStart = 10          # first epoch at which validation decoding runs
finish_after = 10000000  # hard cap on the number of updates

# model architecture
params = {
    'n': 256,               # size of the decoder init state, (batch, n)
    'm': 256,               # embedding size
    'dim_attention': 512,   # attention projection size
    'D': 684,               # encoder output channels, ctx_mean is (batch, D)
    'K': 106,               # symbol vocabulary size (106_dictionary.txt)
    'Kre': 7,               # relation vocabulary size (7relation_dictionary.txt)
    'mre': 256,
    'maxlen': maxlen,
    # DenseNet encoder settings
    'growthRate': 24,
    'reduction': 0.5,
    'bottleneck': True,
    'use_dropout': True,
    'input_channels': 1,
    # weights of the individual terms of the training loss
    'ly_lambda': 1.,
    'ry_lambda': 0.1,
    're_lambda': 1.,
    'rpos_lambda': 1.,
}
params['KL_lambda'] = 0.1  # weight of the KL attention term in the loss (halved on each lr decay)

# Load the symbol dictionary (symbol -> index) and build the inverse
# index -> symbol table used to turn predictions back into symbols.
worddicts = load_dict(dictionaries[0])
print('total chars', len(worddicts))
worddicts_r = [None] * len(worddicts)
for sym_name, sym_idx in worddicts.items():
    worddicts_r[sym_idx] = sym_name

# Same for the relation dictionary.
reworddicts = load_dict(dictionaries[1])
print('total relations', len(reworddicts))
reworddicts_r = [None] * len(reworddicts)
for rel_name, rel_idx in reworddicts.items():
    reworddicts_r[rel_idx] = rel_name

# Batched training / validation data, each returned together with its uid list.
train, train_uid_list = dataIterator(
    datasets[0], datasets[1], datasets[2],
    dictionary=worddicts, redictionary=reworddicts,
    batch_size=batch_size, batch_Imagesize=batch_Imagesize,
    maxlen=maxlen, maxImagesize=maxImagesize)
valid, valid_uid_list = dataIterator(
    valid_datasets[0], valid_datasets[1], valid_datasets[2],
    dictionary=worddicts, redictionary=reworddicts,
    batch_size=valid_batch_size, batch_Imagesize=valid_batch_Imagesize,
    maxlen=maxlen, maxImagesize=maxImagesize)

# display: running loss sums, averaged and reset every dispFreq updates
uidx = 0  # number of processed batches
lpred_loss_s = rpred_loss_s = repred_loss_s = 0.
mem_loss_s = KL_loss_s = loss_s = 0.
110 | ud_s = 0 # time for training an epoch 111 | validFreq = -1 112 | saveFreq = -1 113 | sampleFreq = -1 114 | dispFreq = 100 115 | if validFreq == -1: 116 | validFreq = len(train) 117 | if saveFreq == -1: 118 | saveFreq = len(train) 119 | if sampleFreq == -1: 120 | sampleFreq = len(train) 121 | 122 | # initialize model 123 | WAP_model = Encoder_Decoder(params) 124 | if init_param_flag: 125 | WAP_model.apply(weight_init) 126 | if multi_gpu_flag: 127 | WAP_model = nn.DataParallel(WAP_model, device_ids=[0, 1, 2, 3]) 128 | if reload_flag: 129 | WAP_model.load_state_dict(torch.load(saveto,map_location=lambda storage,loc:storage)) 130 | WAP_model.cuda() 131 | 132 | # print model's parameters 133 | model_params = WAP_model.named_parameters() 134 | for k, v in model_params: 135 | print(k) 136 | 137 | # loss function 138 | # criterion = torch.nn.CrossEntropyLoss(reduce=False) 139 | # optimizer 140 | optimizer = optim.Adadelta(WAP_model.parameters(), lr=lrate, eps=my_eps, weight_decay=decay_c) 141 | 142 | print('Optimization') 143 | 144 | # statistics 145 | history_errs = [] 146 | 147 | for eidx in range(max_epochs): 148 | n_samples = 0 149 | ud_epoch = time.time() 150 | random.shuffle(train) 151 | for x, ly, ry, re, ma, lp, rp in train: 152 | WAP_model.train() 153 | ud_start = time.time() 154 | n_samples += len(x) 155 | uidx += 1 156 | x, x_mask, ly, ly_mask, ry, ry_mask, re, re_mask, ma, ma_mask, lp, rp = \ 157 | prepare_data(params, x, ly, ry, re, ma, lp, rp) 158 | 159 | x = torch.from_numpy(x).cuda() # (batch,1,H,W) 160 | x_mask = torch.from_numpy(x_mask).cuda() # (batch,H,W) 161 | ly = torch.from_numpy(ly).cuda() # (seqs_y,batch) 162 | ly_mask = torch.from_numpy(ly_mask).cuda() # (seqs_y,batch) 163 | ry = torch.from_numpy(ry).cuda() # (seqs_y,batch) 164 | ry_mask = torch.from_numpy(ry_mask).cuda() # (seqs_y,batch) 165 | re = torch.from_numpy(re).cuda() # (seqs_y,batch) 166 | re_mask = torch.from_numpy(re_mask).cuda() # (seqs_y,batch) 167 | ma = 
torch.from_numpy(ma).cuda() # (batch,seqs_y,seqs_y) 168 | ma_mask = torch.from_numpy(ma_mask).cuda() # (batch,seqs_y,seqs_y) 169 | lp = torch.from_numpy(lp).cuda() # (seqs_y,batch) 170 | rp = torch.from_numpy(rp).cuda() # (seqs_y,batch) 171 | 172 | # permute for multi-GPU training 173 | # ly = ly.permute(1, 0) 174 | # ly_mask = ly_mask.permute(1, 0) 175 | # ry = ry.permute(1, 0) 176 | # ry_mask = ry_mask.permute(1, 0) 177 | # lp = lp.permute(1, 0) 178 | # rp = rp.permute(1, 0) 179 | 180 | # forward 181 | loss, lpred_loss, rpred_loss, repred_loss, mem_loss, KL_loss = \ 182 | WAP_model(params, x, x_mask, ly, ly_mask, ry, ry_mask, re, re_mask, ma, ma_mask, lp, rp) 183 | 184 | # recover from permute 185 | lpred_loss_s += lpred_loss.item() 186 | rpred_loss_s += rpred_loss.item() 187 | repred_loss_s += repred_loss.item() 188 | mem_loss_s += mem_loss.item() 189 | KL_loss_s += KL_loss.item() 190 | loss_s += loss.item() 191 | 192 | # backward 193 | optimizer.zero_grad() 194 | loss.backward() 195 | if clip_c > 0.: 196 | torch.nn.utils.clip_grad_norm_(WAP_model.parameters(), clip_c) 197 | 198 | # update 199 | optimizer.step() 200 | 201 | ud = time.time() - ud_start 202 | ud_s += ud 203 | 204 | # display 205 | if np.mod(uidx, dispFreq) == 0: 206 | ud_s /= 60. 207 | loss_s /= dispFreq 208 | lpred_loss_s /= dispFreq 209 | rpred_loss_s /= dispFreq 210 | repred_loss_s /= dispFreq 211 | mem_loss_s /= dispFreq 212 | KL_loss_s /= dispFreq 213 | print ('Epoch', eidx, ' Update', uidx, ' Cost_lpred %.7f, Cost_rpred %.7f, Cost_re %.7f, Cost_matt %.7f, Cost_kl %.7f' % \ 214 | (np.float(lpred_loss_s),np.float(rpred_loss_s),np.float(repred_loss_s),np.float(mem_loss_s),np.float(KL_loss_s)), \ 215 | ' UD %.3f' % ud_s, ' lrate', lrate, ' eps', my_eps, ' bad_counter', bad_counter) 216 | ud_s = 0 217 | loss_s = 0. 218 | lpred_loss_s = 0. 219 | rpred_loss_s = 0. 220 | repred_loss_s = 0. 221 | mem_loss_s = 0. 222 | KL_loss_s = 0. 
223 | 224 | # validation 225 | if np.mod(uidx, sampleFreq) == 0 and eidx >= validStart: 226 | print('begin sampling') 227 | ud_epoch_train = (time.time() - ud_epoch) / 60. 228 | print('epoch training cost time ... ', ud_epoch_train) 229 | WAP_model.eval() 230 | valid_out_path = valid_output[0] 231 | valid_malpha_path = valid_output[1] 232 | if not os.path.exists(valid_out_path): 233 | os.mkdir(valid_out_path) 234 | if not os.path.exists(valid_malpha_path): 235 | os.mkdir(valid_malpha_path) 236 | rec_mat = {} 237 | label_mat = {} 238 | rec_re_mat = {} 239 | label_re_mat = {} 240 | rec_ridx_mat = {} 241 | label_ridx_mat = {} 242 | with torch.no_grad(): 243 | valid_count_idx = 0 244 | for x, ly, ry, re, ma, lp, rp in valid: 245 | for xx, lyy, ree, rpp in zip(x, ly, re, rp): 246 | xx_pad = xx.astype(np.float32) / 255. 247 | xx_pad = torch.from_numpy(xx_pad[None, :, :, :]).cuda() # (1,1,H,W) 248 | score, sample, malpha_list, relation_sample = \ 249 | gen_sample(WAP_model, xx_pad, params, multi_gpu_flag, k=3, maxlen=maxlen, rpos_beam=3) 250 | 251 | key = valid_uid_list[valid_count_idx] 252 | rec_mat[key] = [] 253 | label_mat[key] = lyy 254 | rec_re_mat[key] = [] 255 | label_re_mat[key] = ree 256 | rec_ridx_mat[key] = [] 257 | label_ridx_mat[key] = rpp 258 | if len(score) == 0: 259 | rec_mat[key].append(0) 260 | rec_re_mat[key].append(0) # End 261 | rec_ridx_mat[key].append(0) 262 | else: 263 | score = score / np.array([len(s) for s in sample]) 264 | min_score_index = score.argmin() 265 | ss = sample[min_score_index] 266 | rs = relation_sample[min_score_index] 267 | mali = malpha_list[min_score_index] 268 | for i, [vv, rv] in enumerate(zip(ss, rs)): 269 | if vv == 0: 270 | rec_mat[key].append(vv) 271 | rec_re_mat[key].append(0) # End 272 | break 273 | else: 274 | if i == 0: 275 | rec_mat[key].append(vv) 276 | rec_re_mat[key].append(6) # Start 277 | else: 278 | rec_mat[key].append(vv) 279 | rec_re_mat[key].append(rv) 280 | ma_idx_list = np.array(mali).astype(np.int64) 281 
| ma_idx_list[-1] = int(len(ma_idx_list)-1) 282 | rec_ridx_mat[key] = ma_idx_list 283 | valid_count_idx=valid_count_idx+1 284 | 285 | print('valid set decode done') 286 | ud_epoch = (time.time() - ud_epoch) / 60. 287 | print('epoch cost time ... ', ud_epoch) 288 | 289 | if np.mod(uidx, saveFreq) == 0: 290 | print('Saving latest model params ... ') 291 | torch.save(WAP_model.state_dict(), last_saveto) 292 | 293 | # calculate wer and expRate 294 | if np.mod(uidx, validFreq) == 0 and eidx >= validStart: 295 | valid_cer_out = compute_wer(rec_mat, label_mat) 296 | valid_cer = 100. * valid_cer_out[0] 297 | valid_recer_out = compute_wer(rec_re_mat, label_re_mat) 298 | valid_recer = 100. * valid_recer_out[0] 299 | valid_ridxcer_out = compute_wer(rec_ridx_mat, label_ridx_mat) 300 | valid_ridxcer = 100. * valid_ridxcer_out[0] 301 | valid_exprate = compute_sacc(rec_mat, label_mat, rec_ridx_mat, label_ridx_mat, rec_re_mat, label_re_mat, worddicts_r, reworddicts_r) 302 | valid_exprate = 100. * valid_exprate 303 | valid_err=valid_cer+valid_ridxcer 304 | history_errs.append(valid_err) 305 | 306 | # the first time validation or better model 307 | if uidx // validFreq == 0 or valid_err <= np.array(history_errs).min(): 308 | bad_counter = 0 309 | print('Saving best model params ... ') 310 | if multi_gpu_flag: 311 | torch.save(WAP_model.module.state_dict(), saveto) 312 | else: 313 | torch.save(WAP_model.state_dict(), saveto) 314 | 315 | # worse model 316 | if uidx / validFreq != 0 and valid_err > np.array(history_errs).min(): 317 | bad_counter += 1 318 | if bad_counter > patience: 319 | if halfLrFlag == 2: 320 | print('Early Stop!') 321 | estop = True 322 | break 323 | else: 324 | print('Lr decay and retrain!') 325 | bad_counter = 0 326 | lrate = lrate / 10. 
327 | params['KL_lambda'] = params['KL_lambda'] * 0.5 328 | for param_group in optimizer.param_groups: 329 | param_group['lr'] = lrate 330 | halfLrFlag += 1 331 | print ('Valid CER: %.2f%%, relation_CER: %.2f%%, rpos_CER: %.2f%%, ExpRate: %.2f%%' % (valid_cer,valid_recer,valid_ridxcer,valid_exprate)) 332 | # finish after these many updates 333 | if uidx >= finish_after: 334 | print('Finishing after %d iterations!' % uidx) 335 | estop = True 336 | break 337 | 338 | print('Seen %d samples' % n_samples) 339 | 340 | # early stop 341 | if estop: 342 | break 343 | -------------------------------------------------------------------------------- /codes/translate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import numpy as np 4 | import os 5 | import re 6 | import time 7 | import sys 8 | 9 | import torch 10 | 11 | from data_iterator import dataIterator, dataIterator_test 12 | from encoder_decoder import Encoder_Decoder 13 | 14 | 15 | # Note: 16 | # here model means Encoder_Decoder --> WAP_model 17 | # x means a sample not a batch(or batch_size = 1),and x's shape should be (1,1,H,W),type must be Variable 18 | # live_k is just equal to k -dead_k(except the begin of sentence:live_k = 1,dead_k = 0,so use k-dead_k to represent the number of alive paths in beam search) 19 | 20 | def gen_sample(model, x, params, gpu_flag, k=1, maxlen=30, rpos_beam=3): 21 | 22 | sample = [] 23 | sample_score = [] 24 | rpos_sample = [] 25 | # rpos_sample_score = [] 26 | relation_sample = [] 27 | 28 | live_k = 1 29 | dead_k = 0 # except init, live_k = k - dead_k 30 | 31 | # current living paths and corresponding scores(-log) 32 | hyp_samples = [[]] * live_k 33 | hyp_scores = np.zeros(live_k).astype(np.float32) 34 | hyp_rpos_samples = [[]] * live_k 35 | hyp_relation_samples = [[]] * live_k 36 | # get init state, (1,n) and encoder output, (1,D,H,W) 37 | next_state, ctx0 = model.f_init(x) 38 | next_h1t = next_state 39 | # -1 
-> My_embedding -> 0 tensor(1,m) 40 | next_lw = -1 * torch.ones(1, dtype=torch.int64).cuda() 41 | next_calpha_past = torch.zeros(1, ctx0.shape[2], ctx0.shape[3]).cuda() # (live_k,H,W) 42 | next_palpha_past = torch.zeros(1, ctx0.shape[2], ctx0.shape[3]).cuda() 43 | nextemb_memory = torch.zeros(params['maxlen'], live_k, params['m']).cuda() 44 | nextePmb_memory = torch.zeros(params['maxlen'], live_k, params['m']).cuda() 45 | 46 | for ii in range(maxlen): 47 | ctxP = ctx0.repeat(live_k, 1, 1, 1) # (live_k,D,H,W) 48 | next_lpos = ii * torch.ones(live_k, dtype=torch.int64).cuda() 49 | next_h01, next_ma, next_ctP, next_pa, next_palpha_past, nextemb_memory, nextePmb_memory = \ 50 | model.f_next_parent(params, next_lw, next_lpos, ctxP, next_state, next_h1t, next_palpha_past, nextemb_memory, nextePmb_memory, ii) 51 | next_ma = next_ma.cpu().numpy() 52 | # next_ctP = next_ctP.cpu().numpy() 53 | next_palpha_past = next_palpha_past.cpu().numpy() 54 | nextemb_memory = nextemb_memory.cpu().numpy() 55 | nextePmb_memory = nextePmb_memory.cpu().numpy() 56 | 57 | nextemb_memory = np.transpose(nextemb_memory, (1, 0, 2)) # batch * Matt * dim 58 | nextePmb_memory = np.transpose(nextePmb_memory, (1, 0, 2)) 59 | 60 | next_rpos = next_ma.argsort(axis=1)[:,-rpos_beam:] # topK parent index; batch * topK 61 | n_gaps = nextemb_memory.shape[1] 62 | n_batch = nextemb_memory.shape[0] 63 | next_rpos_gap = next_rpos + n_gaps * np.arange(n_batch)[:, None] 64 | next_remb_memory = nextemb_memory.reshape([n_batch*n_gaps, nextemb_memory.shape[-1]]) 65 | next_remb = next_remb_memory[next_rpos_gap.flatten()] # [batch*rpos_beam, emb_dim] 66 | rpos_scores = next_ma.flatten()[next_rpos_gap.flatten()] # [batch*rpos_beam,] 67 | 68 | # next_ctPC = next_ctP.repeat(1, 1, rpos_beam) 69 | # next_ctPC = torch.reshape(next_ctPC, (-1, next_ctP.shape[1])) 70 | ctxC = ctx0.repeat(live_k*rpos_beam, 1, 1, 1) 71 | next_ctPC = torch.zeros(next_ctP.shape[0]*rpos_beam, next_ctP.shape[1]).cuda() 72 | next_h01C = 
torch.zeros(next_h01.shape[0]*rpos_beam, next_h01.shape[1]).cuda() 73 | next_calpha_pastC = torch.zeros(next_calpha_past.shape[0]*rpos_beam, next_calpha_past.shape[1], next_calpha_past.shape[2]).cuda() 74 | for bidx in range(next_calpha_past.shape[0]): 75 | for ridx in range(rpos_beam): 76 | next_ctPC[bidx*rpos_beam+ridx] = next_ctP[bidx] 77 | next_h01C[bidx*rpos_beam+ridx] = next_h01[bidx] 78 | next_calpha_pastC[bidx*rpos_beam+ridx] = next_calpha_past[bidx] 79 | next_remb = torch.from_numpy(next_remb).cuda() 80 | 81 | next_lp, next_rep, next_state, next_h1t, next_ca, next_calpha_past, next_re = \ 82 | model.f_next_child(params, next_remb, next_ctPC, ctxC, next_h01C, next_calpha_pastC) 83 | 84 | next_lp = next_lp.cpu().numpy() 85 | next_state = next_state.cpu().numpy() 86 | next_h1t = next_h1t.cpu().numpy() 87 | next_calpha_past = next_calpha_past.cpu().numpy() 88 | next_re = next_re.cpu().numpy() 89 | 90 | hyp_scores = np.tile(hyp_scores[:, None], [1, rpos_beam]).flatten() 91 | cand_scores = hyp_scores[:, None] - np.log(next_lp+1e-10)- np.log(rpos_scores+1e-10)[:,None] 92 | cand_flat = cand_scores.flatten() 93 | ranks_flat = cand_flat.argsort()[:(k-dead_k)] 94 | voc_size = next_lp.shape[1] 95 | trans_indices = ranks_flat // voc_size 96 | trans_indicesP = ranks_flat // (voc_size*rpos_beam) 97 | word_indices = ranks_flat % voc_size 98 | costs = cand_flat[ranks_flat] 99 | 100 | # update paths 101 | new_hyp_samples = [] 102 | new_hyp_scores = np.zeros(k-dead_k).astype('float32') 103 | new_hyp_rpos_samples = [] 104 | new_hyp_relation_samples = [] 105 | new_hyp_states = [] 106 | new_hyp_h1ts = [] 107 | new_hyp_calpha_past = [] 108 | new_hyp_palpha_past = [] 109 | new_hyp_emb_memory = [] 110 | new_hyp_ePmb_memory = [] 111 | 112 | for idx, [ti, wi, tPi] in enumerate(zip(trans_indices, word_indices, trans_indicesP)): 113 | new_hyp_samples.append(hyp_samples[tPi]+[wi]) 114 | new_hyp_scores[idx] = copy.copy(costs[idx]) 115 | 
new_hyp_rpos_samples.append(hyp_rpos_samples[tPi]+[next_rpos.flatten()[ti]]) 116 | new_hyp_relation_samples.append(hyp_relation_samples[tPi]+[next_re[ti]]) 117 | new_hyp_states.append(copy.copy(next_state[ti])) 118 | new_hyp_h1ts.append(copy.copy(next_h1t[ti])) 119 | new_hyp_calpha_past.append(copy.copy(next_calpha_past[ti])) 120 | new_hyp_palpha_past.append(copy.copy(next_palpha_past[tPi])) 121 | new_hyp_emb_memory.append(copy.copy(nextemb_memory[tPi])) 122 | new_hyp_ePmb_memory.append(copy.copy(nextePmb_memory[tPi])) 123 | 124 | # check the finished samples 125 | new_live_k = 0 126 | hyp_samples = [] 127 | hyp_scores = [] 128 | hyp_rpos_samples = [] 129 | hyp_relation_samples = [] 130 | hyp_states = [] 131 | hyp_h1ts = [] 132 | hyp_calpha_past = [] 133 | hyp_palpha_past = [] 134 | hyp_emb_memory = [] 135 | hyp_ePmb_memory = [] 136 | 137 | for idx in range(len(new_hyp_samples)): 138 | if new_hyp_samples[idx][-1] == 0: # 139 | sample_score.append(new_hyp_scores[idx]) 140 | sample.append(new_hyp_samples[idx]) 141 | rpos_sample.append(new_hyp_rpos_samples[idx]) 142 | relation_sample.append(new_hyp_relation_samples[idx]) 143 | dead_k += 1 144 | else: 145 | new_live_k += 1 146 | hyp_scores.append(new_hyp_scores[idx]) 147 | hyp_samples.append(new_hyp_samples[idx]) 148 | hyp_rpos_samples.append(new_hyp_rpos_samples[idx]) 149 | hyp_relation_samples.append(new_hyp_relation_samples[idx]) 150 | hyp_states.append(new_hyp_states[idx]) 151 | hyp_h1ts.append(new_hyp_h1ts[idx]) 152 | hyp_calpha_past.append(new_hyp_calpha_past[idx]) 153 | hyp_palpha_past.append(new_hyp_palpha_past[idx]) 154 | hyp_emb_memory.append(new_hyp_emb_memory[idx]) 155 | hyp_ePmb_memory.append(new_hyp_ePmb_memory[idx]) 156 | 157 | hyp_scores = np.array(hyp_scores) 158 | live_k = new_live_k 159 | 160 | # whether finish beam search 161 | if new_live_k < 1: 162 | break 163 | if dead_k >= k: 164 | break 165 | 166 | next_lw = np.array([w[-1] for w in hyp_samples]) # each path's final symbol, (live_k,) 167 | 
next_state = np.array(hyp_states) # h2t, (live_k,n) 168 | next_h1t = np.array(hyp_h1ts) 169 | next_calpha_past = np.array(hyp_calpha_past) # (live_k,H,W) 170 | next_palpha_past = np.array(hyp_palpha_past) 171 | nextemb_memory = np.array(hyp_emb_memory) 172 | nextemb_memory = np.transpose(nextemb_memory, (1, 0, 2)) 173 | nextePmb_memory = np.array(hyp_ePmb_memory) 174 | nextePmb_memory = np.transpose(nextePmb_memory, (1, 0, 2)) 175 | next_lw = torch.from_numpy(next_lw).cuda() 176 | next_state = torch.from_numpy(next_state).cuda() 177 | next_h1t = torch.from_numpy(next_h1t).cuda() 178 | next_calpha_past = torch.from_numpy(next_calpha_past).cuda() 179 | next_palpha_past = torch.from_numpy(next_palpha_past).cuda() 180 | nextemb_memory = torch.from_numpy(nextemb_memory).cuda() 181 | nextePmb_memory = torch.from_numpy(nextePmb_memory).cuda() 182 | 183 | return sample_score, sample, rpos_sample, relation_sample 184 | 185 | 186 | def load_dict(dictFile): 187 | fp = open(dictFile) 188 | stuff = fp.readlines() 189 | fp.close() 190 | lexicon = {} 191 | for l in stuff: 192 | w = l.strip().split() 193 | lexicon[w[0]] = int(w[1]) 194 | 195 | print('total words/phones', len(lexicon)) 196 | return lexicon 197 | 198 | 199 | def main(model_path, dictionary_target, dictionary_retarget, fea, output_path, k=5): 200 | # set parameters 201 | params = {} 202 | params['n'] = 256 203 | params['m'] = 256 204 | params['dim_attention'] = 512 205 | params['D'] = 684 206 | params['K'] = 106 207 | params['growthRate'] = 24 208 | params['reduction'] = 0.5 209 | params['bottleneck'] = True 210 | params['use_dropout'] = True 211 | params['input_channels'] = 1 212 | params['Kre'] = 7 213 | params['mre'] = 256 214 | 215 | maxlen = 300 216 | params['maxlen'] = maxlen 217 | 218 | # load model 219 | model = Encoder_Decoder(params) 220 | model.load_state_dict(torch.load(model_path,map_location=lambda storage,loc:storage)) 221 | # enable CUDA 222 | model.cuda() 223 | 224 | # load source dictionary and 
invert 225 | worddicts = load_dict(dictionary_target) 226 | print ('total chars',len(worddicts)) 227 | worddicts_r = [None] * len(worddicts) 228 | for kk, vv in worddicts.items(): 229 | worddicts_r[vv] = kk 230 | 231 | reworddicts = load_dict(dictionary_retarget) 232 | print ('total relations',len(reworddicts)) 233 | reworddicts_r = [None] * len(reworddicts) 234 | for kk, vv in reworddicts.items(): 235 | reworddicts_r[vv] = kk 236 | 237 | valid,valid_uid_list = dataIterator_test(fea, worddicts, reworddicts, 238 | batch_size=8, batch_Imagesize=800000,maxImagesize=800000) 239 | 240 | # change model's mode to eval 241 | model.eval() 242 | 243 | valid_out_path = output_path + 'symbol_relation/' 244 | valid_malpha_path = output_path + 'memory_alpha/' 245 | if not os.path.exists(valid_out_path): 246 | os.mkdir(valid_out_path) 247 | if not os.path.exists(valid_malpha_path): 248 | os.mkdir(valid_malpha_path) 249 | valid_count_idx = 0 250 | print('Decoding ... ') 251 | ud_epoch = time.time() 252 | model.eval() 253 | with torch.no_grad(): 254 | for x in valid: 255 | for xx in x: # xx:当前batch中的一个数据,numpy 256 | print('%d : %s' % (valid_count_idx + 1, valid_uid_list[valid_count_idx])) 257 | xx_pad = np.zeros((xx.shape[0], xx.shape[1], xx.shape[2]), dtype='float32') # (1,height,width) 258 | xx_pad[:, :, :] = xx / 255. 
259 | xx_pad = torch.from_numpy(xx_pad[None, :, :, :]).cuda() 260 | score, sample, malpha_list, relation_sample = \ 261 | gen_sample(model, xx_pad, params, gpu_flag=False, k=k, maxlen=maxlen) 262 | # sys.exit() 263 | if len(score) != 0: 264 | score = score / np.array([len(s) for s in sample]) 265 | # relation_score = relation_score / np.array([len(r) for r in relation_sample]) 266 | min_score_index = score.argmin() 267 | ss = sample[min_score_index] 268 | rs = relation_sample[min_score_index] 269 | mali = malpha_list[min_score_index] 270 | fpp_sample = open(valid_out_path+valid_uid_list[valid_count_idx]+'.txt','w') 271 | file_malpha_sample = valid_malpha_path+valid_uid_list[valid_count_idx]+'_malpha.txt' 272 | for i, [vv, rv] in enumerate(zip(ss, rs)): 273 | if vv == 0: 274 | string = worddicts_r[vv] + '\tEnd\n' 275 | fpp_sample.write(string) 276 | break 277 | else: 278 | if i == 0: 279 | string = worddicts_r[vv] + '\tStart\n' 280 | else: 281 | string = worddicts_r[vv] + '\t' + reworddicts_r[rv] + '\n' 282 | fpp_sample.write(string) 283 | np.savetxt(file_malpha_sample, np.array(mali)) 284 | fpp_sample.close() 285 | valid_count_idx=valid_count_idx+1 286 | print('test set decode done') 287 | ud_epoch = (time.time() - ud_epoch) / 60. 288 | print('epoch cost time ... ', ud_epoch) 289 | 290 | # valid_result = [result_wer, result_exprate] 291 | # os.system('python compute_sym_re_ridx_cer.py ' + valid_out_path + ' ' + valid_malpha_path + ' ' + label_path + ' ' + valid_result[0]) 292 | # fpp=open(valid_result[0]) 293 | # lines = fpp.readlines() 294 | # fpp.close() 295 | # part1 = lines[-3].split() 296 | # if part1[0] == 'CER': 297 | # valid_cer=100. * float(part1[1]) 298 | # else: 299 | # print ('no CER result') 300 | # part2 = lines[-2].split() 301 | # if part2[0] == 'reCER': 302 | # valid_recer=100. * float(part2[1]) 303 | # else: 304 | # print ('no reCER result') 305 | # part3 = lines[-1].split() 306 | # if part3[0] == 'ridxCER': 307 | # valid_ridxcer=100. 
* float(part3[1]) 308 | # else: 309 | # print ('no ridxCER result') 310 | # os.system('python evaluate_ExpRate2.py ' + valid_out_path + ' ' + valid_malpha_path + ' ' + label_path + ' ' + valid_result[1]) 311 | # fpp=open(valid_result[1]) 312 | # exp_lines = fpp.readlines() 313 | # fpp.close() 314 | # parts = exp_lines[0].split() 315 | # if parts[0] == 'ExpRate': 316 | # valid_exprate = float(parts[1]) 317 | # else: 318 | # print ('no ExpRate result') 319 | # print ('ExpRate: %.2f%%' % (valid_exprate)) 320 | # print ('Valid CER: %.2f%%, relation_CER: %.2f%%, rpos_CER: %.2f%%, ExpRate: %.2f%%' % (valid_cer,valid_recer,valid_ridxcer,valid_exprate)) 321 | 322 | 323 | if __name__ == "__main__": 324 | parser = argparse.ArgumentParser() 325 | parser.add_argument('-k', type=int, default=10) 326 | parser.add_argument('model_path', type=str) 327 | parser.add_argument('dictionary_target', type=str) 328 | parser.add_argument('dictionary_retarget', type=str) 329 | parser.add_argument('fea', type=str) 330 | parser.add_argument('output_path', type=str) 331 | 332 | args = parser.parse_args() 333 | 334 | main(args.model_path, args.dictionary_target, args.dictionary_retarget, args.fea, 335 | args.output_path, k=args.k) 336 | -------------------------------------------------------------------------------- /codes/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import copy 4 | import sys 5 | import pickle as pkl 6 | import torch 7 | from torch import nn 8 | 9 | 10 | class BatchBucket(): 11 | def __init__(self, max_h, max_w, max_l, max_img_size, max_batch_size, feature_file, label_file, dictionary, 12 | use_all=True): 13 | self._max_img_size = max_img_size 14 | self._max_batch_size = max_batch_size 15 | self._fea_file = feature_file 16 | self._label_file = label_file 17 | self._dictionary_file = dictionary 18 | self._use_all = use_all 19 | self._dict_load() 20 | self._data_load() 21 | self.keys = 
self._calc_keys(max_h, max_w, max_l) 22 | self._make_plan() 23 | self._reset() 24 | 25 | def _dict_load(self): 26 | fp = open(self._dictionary_file) 27 | stuff = fp.readlines() 28 | fp.close() 29 | self._lexicon = {} 30 | for l in stuff: 31 | w = l.strip().split() 32 | self._lexicon[w[0]] = int(w[1]) 33 | 34 | def _data_load(self): 35 | fp_fea = open(self._fea_file, 'rb') 36 | self._features = pkl.load(fp_fea) 37 | fp_fea.close() 38 | fp_label = open(self._label_file, 'r') 39 | labels = fp_label.readlines() 40 | fp_label.close() 41 | self._targets = {} 42 | for l in labels: 43 | tmp = l.strip().split() 44 | uid = tmp[0] 45 | w_list = [] 46 | for w in tmp[1:]: 47 | if self._lexicon.__contains__(w): 48 | w_list.append(self._lexicon[w]) 49 | else: 50 | print('a word not in the dictionary !! sentence ', uid, 'word ', w) 51 | sys.exit() 52 | self._targets[uid] = w_list 53 | # (uid, h, w, tgt_len) 54 | self._data_parser = [(uid, fea.shape[1], fea.shape[2], len(label)) for (uid, fea), (_, label) in 55 | zip(self._features.items(), self._targets.items())] 56 | 57 | def _calc_keys(self, max_h, max_w, max_l): 58 | mh = mw = ml = 0 59 | for _, h, w, l in self._data_parser: 60 | if h > mh: 61 | mh = h 62 | if w > mw: 63 | mw = w 64 | if l > ml: 65 | ml = l 66 | max_h = min(max_h, mh) 67 | max_w = min(max_w, mw) 68 | max_l = min(max_l, ml) 69 | keys = [] 70 | init_h = 64 if 64 < max_h else max_h 71 | init_w = 64 if 64 < max_w else max_w 72 | init_l = 20 if 20 < max_l else max_l 73 | h_step = 64 74 | w_step = 64 75 | l_step = 30 76 | h = init_h 77 | while h <= max_h: 78 | w = init_w 79 | while w <= max_w: 80 | l = init_l 81 | while l <= max_l: 82 | keys.append([h, w, l, h * w * l, 0]) 83 | if l < max_l and l + l_step > max_l: 84 | l = max_l 85 | continue 86 | l += l_step 87 | if w < max_w and w + w_step > max_w: 88 | w = max_w 89 | continue 90 | w += w_step 91 | if h < max_h and h + h_step > max_h: 92 | h = max_h 93 | continue 94 | h += h_step 95 | keys = sorted(keys, key=lambda 
area: area[3]) 96 | for _, h, w, l in self._data_parser: 97 | for i in range(len(keys)): 98 | hh, ww, ll, _, _ = keys[i] 99 | if h <= hh and w <= ww and l <= ll: 100 | keys[i][-1] += 1 101 | break 102 | new_keys = [] 103 | n_samples = len(self._data_parser) 104 | th = n_samples * 0.01 105 | if self._use_all: 106 | th = 1 107 | num = 0 108 | for key in keys: 109 | hh, ww, ll, _, n = key 110 | num += n 111 | if num >= th: 112 | new_keys.append((hh, ww, ll)) 113 | num = 0 114 | return new_keys 115 | 116 | def _make_plan(self): 117 | self._bucket_keys = [] 118 | for h, w, l in self.keys: 119 | batch_size = int(self._max_img_size / (h * w)) 120 | if batch_size > self._max_batch_size: 121 | batch_size = self._max_batch_size 122 | if batch_size == 0: 123 | continue 124 | self._bucket_keys.append((batch_size, h, w, l)) 125 | self._data_buckets = [[] for key in self._bucket_keys] 126 | unuse_num = 0 127 | for item in self._data_parser: 128 | flag = 0 129 | for key, bucket in zip(self._bucket_keys, self._data_buckets): 130 | _, h, w, l = key 131 | if item[1] <= h and item[2] <= w and item[3] <= l: 132 | bucket.append(item) 133 | flag = 1 134 | break 135 | if flag == 0: 136 | unuse_num += 1 137 | print('The number of unused samples: ', unuse_num) 138 | all_sample_num = 0 139 | for key, bucket in zip(self._bucket_keys, self._data_buckets): 140 | sample_num = len(bucket) 141 | all_sample_num += sample_num 142 | print('bucket {}, sample number={}'.format(key, len(bucket))) 143 | print('All samples number={}, raw samples number={}'.format(all_sample_num, len(self._data_parser))) 144 | 145 | def _reset(self): 146 | # shuffle data in each bucket 147 | for bucket in self._data_buckets: 148 | random.shuffle(bucket) 149 | self._batches = [] 150 | for id, (key, bucket) in enumerate(zip(self._bucket_keys, self._data_buckets)): 151 | batch_size, _, _, _ = key 152 | bucket_len = len(bucket) 153 | batch_num = (bucket_len + batch_size - 1) // batch_size 154 | for i in range(batch_num): 155 
| start = i * batch_size 156 | end = start + batch_size if start + batch_size < bucket_len else bucket_len 157 | if start != end: # remove empty batch 158 | self._batches.append(bucket[start:end]) 159 | 160 | def get_batches(self): 161 | batches = [] 162 | uid_batches = [] 163 | for batch_info in self._batches: 164 | fea_batch = [] 165 | label_batch = [] 166 | for uid, _, _, _ in batch_info: 167 | feature = self._features[uid] 168 | label = self._targets[uid] 169 | fea_batch.append(feature) 170 | label_batch.append(label) 171 | uid_batches.append(uid) 172 | batches.append((fea_batch, label_batch)) 173 | return batches, uid_batches 174 | 175 | 176 | # load dictionary 177 | def load_dict(dictFile): 178 | fp = open(dictFile) 179 | stuff = fp.readlines() 180 | fp.close() 181 | lexicon = {} 182 | for l in stuff: 183 | w = l.strip().split() 184 | lexicon[w[0]] = int(w[1]) 185 | print('total words/phones', len(lexicon)) 186 | return lexicon 187 | 188 | 189 | # create batch 190 | def prepare_data(params, images_x, seqs_ly, seqs_ry, seqs_re, seqs_ma, seqs_lp, seqs_rp): 191 | heights_x = [s.shape[1] for s in images_x] 192 | widths_x = [s.shape[2] for s in images_x] 193 | lengths_ly = [len(s) for s in seqs_ly] 194 | lengths_ry = [len(s) for s in seqs_ry] 195 | 196 | n_samples = len(heights_x) 197 | max_height_x = np.max(heights_x) 198 | max_width_x = np.max(widths_x) 199 | maxlen_ly = np.max(lengths_ly) 200 | maxlen_ry = np.max(lengths_ry) 201 | 202 | x = np.zeros((n_samples, params['input_channels'], max_height_x, max_width_x)).astype(np.float32) 203 | ly = np.zeros((maxlen_ly, n_samples)).astype(np.int64) # must be 0 in the dict 204 | ry = np.zeros((maxlen_ry, n_samples)).astype(np.int64) 205 | re = np.zeros((maxlen_ly, n_samples)).astype(np.int64) 206 | ma = np.zeros((n_samples, maxlen_ly, maxlen_ly)).astype(np.int64) 207 | lp = np.zeros((maxlen_ly, n_samples)).astype(np.int64) 208 | rp = np.zeros((maxlen_ry, n_samples)).astype(np.int64) 209 | 210 | x_mask = 
np.zeros((n_samples, max_height_x, max_width_x)).astype(np.float32) 211 | ly_mask = np.zeros((maxlen_ly, n_samples)).astype(np.float32) 212 | ry_mask = np.zeros((maxlen_ry, n_samples)).astype(np.float32) 213 | re_mask = np.zeros((maxlen_ly, n_samples)).astype(np.float32) 214 | ma_mask = np.zeros((n_samples, maxlen_ly, maxlen_ly)).astype(np.float32) 215 | 216 | for idx, [s_x, s_ly, s_ry, s_re, s_ma, s_lp, s_rp] in enumerate(zip(images_x, seqs_ly, seqs_ry, seqs_re, seqs_ma, seqs_lp, seqs_rp)): 217 | x[idx, :, :heights_x[idx], :widths_x[idx]] = s_x / 255. 218 | x_mask[idx, :heights_x[idx], :widths_x[idx]] = 1. 219 | ly[:lengths_ly[idx], idx] = s_ly 220 | ly_mask[:lengths_ly[idx], idx] = 1. 221 | ry[:lengths_ry[idx], idx] = s_ry 222 | ry_mask[:lengths_ry[idx], idx] = 1. 223 | ry_mask[0, idx] = 0. # remove the 224 | re[:lengths_ly[idx], idx] = s_re 225 | re_mask[:lengths_ly[idx], idx] = 1. 226 | re_mask[0, idx] = 0. # remove the Start relation 227 | re_mask[lengths_ly[idx]-1, idx] = 0. # remove the End relation 228 | ma[idx, :lengths_ly[idx], :lengths_ly[idx]] = s_ma 229 | for ma_idx in range(lengths_ly[idx]): 230 | ma_mask[idx, :(ma_idx+1), ma_idx] = 1. 
231 | lp[:lengths_ly[idx], idx] = s_lp 232 | # lp_mask[:lengths_ly[idx], idx] = 1 233 | rp[:lengths_ry[idx], idx] = s_rp 234 | 235 | return x, x_mask, ly, ly_mask, ry, ry_mask, re, re_mask, ma, ma_mask, lp, rp 236 | 237 | def gen_sample(model, x, params, gpu_flag, k=1, maxlen=30, rpos_beam=3): 238 | 239 | sample = [] 240 | sample_score = [] 241 | rpos_sample = [] 242 | # rpos_sample_score = [] 243 | relation_sample = [] 244 | 245 | live_k = 1 246 | dead_k = 0 # except init, live_k = k - dead_k 247 | 248 | # current living paths and corresponding scores(-log) 249 | hyp_samples = [[]] * live_k 250 | hyp_scores = np.zeros(live_k).astype(np.float32) 251 | hyp_rpos_samples = [[]] * live_k 252 | hyp_relation_samples = [[]] * live_k 253 | # get init state, (1,n) and encoder output, (1,D,H,W) 254 | next_state, ctx0 = model.f_init(x) 255 | next_h1t = next_state 256 | # -1 -> My_embedding -> 0 tensor(1,m) 257 | next_lw = -1 * torch.ones(1, dtype=torch.int64).cuda() 258 | next_calpha_past = torch.zeros(1, ctx0.shape[2], ctx0.shape[3]).cuda() # (live_k,H,W) 259 | next_palpha_past = torch.zeros(1, ctx0.shape[2], ctx0.shape[3]).cuda() 260 | nextemb_memory = torch.zeros(params['maxlen'], live_k, params['m']).cuda() 261 | nextePmb_memory = torch.zeros(params['maxlen'], live_k, params['m']).cuda() 262 | 263 | for ii in range(maxlen): 264 | ctxP = ctx0.repeat(live_k, 1, 1, 1) # (live_k,D,H,W) 265 | next_lpos = ii * torch.ones(live_k, dtype=torch.int64).cuda() 266 | next_h01, next_ma, next_ctP, next_pa, next_palpha_past, nextemb_memory, nextePmb_memory = \ 267 | model.f_next_parent(params, next_lw, next_lpos, ctxP, next_state, next_h1t, next_palpha_past, nextemb_memory, nextePmb_memory, ii) 268 | next_ma = next_ma.cpu().numpy() 269 | # next_ctP = next_ctP.cpu().numpy() 270 | next_palpha_past = next_palpha_past.cpu().numpy() 271 | nextemb_memory = nextemb_memory.cpu().numpy() 272 | nextePmb_memory = nextePmb_memory.cpu().numpy() 273 | 274 | nextemb_memory = 
np.transpose(nextemb_memory, (1, 0, 2)) # batch * Matt * dim 275 | nextePmb_memory = np.transpose(nextePmb_memory, (1, 0, 2)) 276 | 277 | next_rpos = next_ma.argsort(axis=1)[:,-rpos_beam:] # topK parent index; batch * topK 278 | n_gaps = nextemb_memory.shape[1] 279 | n_batch = nextemb_memory.shape[0] 280 | next_rpos_gap = next_rpos + n_gaps * np.arange(n_batch)[:, None] 281 | next_remb_memory = nextemb_memory.reshape([n_batch*n_gaps, nextemb_memory.shape[-1]]) 282 | next_remb = next_remb_memory[next_rpos_gap.flatten()] # [batch*rpos_beam, emb_dim] 283 | rpos_scores = next_ma.flatten()[next_rpos_gap.flatten()] # [batch*rpos_beam,] 284 | 285 | # next_ctPC = next_ctP.repeat(1, 1, rpos_beam) 286 | # next_ctPC = torch.reshape(next_ctPC, (-1, next_ctP.shape[1])) 287 | ctxC = ctx0.repeat(live_k*rpos_beam, 1, 1, 1) 288 | next_ctPC = torch.zeros(next_ctP.shape[0]*rpos_beam, next_ctP.shape[1]).cuda() 289 | next_h01C = torch.zeros(next_h01.shape[0]*rpos_beam, next_h01.shape[1]).cuda() 290 | next_calpha_pastC = torch.zeros(next_calpha_past.shape[0]*rpos_beam, next_calpha_past.shape[1], next_calpha_past.shape[2]).cuda() 291 | for bidx in range(next_calpha_past.shape[0]): 292 | for ridx in range(rpos_beam): 293 | next_ctPC[bidx*rpos_beam+ridx] = next_ctP[bidx] 294 | next_h01C[bidx*rpos_beam+ridx] = next_h01[bidx] 295 | next_calpha_pastC[bidx*rpos_beam+ridx] = next_calpha_past[bidx] 296 | next_remb = torch.from_numpy(next_remb).cuda() 297 | 298 | next_lp, next_rep, next_state, next_h1t, next_ca, next_calpha_past, next_re = \ 299 | model.f_next_child(params, next_remb, next_ctPC, ctxC, next_h01C, next_calpha_pastC) 300 | 301 | next_lp = next_lp.cpu().numpy() 302 | next_state = next_state.cpu().numpy() 303 | next_h1t = next_h1t.cpu().numpy() 304 | next_calpha_past = next_calpha_past.cpu().numpy() 305 | next_re = next_re.cpu().numpy() 306 | 307 | hyp_scores = np.tile(hyp_scores[:, None], [1, rpos_beam]).flatten() 308 | cand_scores = hyp_scores[:, None] - np.log(next_lp+1e-10)- 
np.log(rpos_scores+1e-10)[:,None] 309 | cand_flat = cand_scores.flatten() 310 | ranks_flat = cand_flat.argsort()[:(k-dead_k)] 311 | voc_size = next_lp.shape[1] 312 | trans_indices = ranks_flat // voc_size 313 | trans_indicesP = ranks_flat // (voc_size*rpos_beam) 314 | word_indices = ranks_flat % voc_size 315 | costs = cand_flat[ranks_flat] 316 | 317 | # update paths 318 | new_hyp_samples = [] 319 | new_hyp_scores = np.zeros(k-dead_k).astype('float32') 320 | new_hyp_rpos_samples = [] 321 | new_hyp_relation_samples = [] 322 | new_hyp_states = [] 323 | new_hyp_h1ts = [] 324 | new_hyp_calpha_past = [] 325 | new_hyp_palpha_past = [] 326 | new_hyp_emb_memory = [] 327 | new_hyp_ePmb_memory = [] 328 | 329 | for idx, [ti, wi, tPi] in enumerate(zip(trans_indices, word_indices, trans_indicesP)): 330 | new_hyp_samples.append(hyp_samples[tPi]+[wi]) 331 | new_hyp_scores[idx] = copy.copy(costs[idx]) 332 | new_hyp_rpos_samples.append(hyp_rpos_samples[tPi]+[next_rpos.flatten()[ti]]) 333 | new_hyp_relation_samples.append(hyp_relation_samples[tPi]+[next_re[ti]]) 334 | new_hyp_states.append(copy.copy(next_state[ti])) 335 | new_hyp_h1ts.append(copy.copy(next_h1t[ti])) 336 | new_hyp_calpha_past.append(copy.copy(next_calpha_past[ti])) 337 | new_hyp_palpha_past.append(copy.copy(next_palpha_past[tPi])) 338 | new_hyp_emb_memory.append(copy.copy(nextemb_memory[tPi])) 339 | new_hyp_ePmb_memory.append(copy.copy(nextePmb_memory[tPi])) 340 | 341 | # check the finished samples 342 | new_live_k = 0 343 | hyp_samples = [] 344 | hyp_scores = [] 345 | hyp_rpos_samples = [] 346 | hyp_relation_samples = [] 347 | hyp_states = [] 348 | hyp_h1ts = [] 349 | hyp_calpha_past = [] 350 | hyp_palpha_past = [] 351 | hyp_emb_memory = [] 352 | hyp_ePmb_memory = [] 353 | 354 | for idx in range(len(new_hyp_samples)): 355 | if new_hyp_samples[idx][-1] == 0: # 356 | sample_score.append(new_hyp_scores[idx]) 357 | sample.append(new_hyp_samples[idx]) 358 | rpos_sample.append(new_hyp_rpos_samples[idx]) 359 | 
relation_sample.append(new_hyp_relation_samples[idx]) 360 | dead_k += 1 361 | else: 362 | new_live_k += 1 363 | hyp_scores.append(new_hyp_scores[idx]) 364 | hyp_samples.append(new_hyp_samples[idx]) 365 | hyp_rpos_samples.append(new_hyp_rpos_samples[idx]) 366 | hyp_relation_samples.append(new_hyp_relation_samples[idx]) 367 | hyp_states.append(new_hyp_states[idx]) 368 | hyp_h1ts.append(new_hyp_h1ts[idx]) 369 | hyp_calpha_past.append(new_hyp_calpha_past[idx]) 370 | hyp_palpha_past.append(new_hyp_palpha_past[idx]) 371 | hyp_emb_memory.append(new_hyp_emb_memory[idx]) 372 | hyp_ePmb_memory.append(new_hyp_ePmb_memory[idx]) 373 | 374 | hyp_scores = np.array(hyp_scores) 375 | live_k = new_live_k 376 | 377 | # whether finish beam search 378 | if new_live_k < 1: 379 | break 380 | if dead_k >= k: 381 | break 382 | 383 | next_lw = np.array([w[-1] for w in hyp_samples]) # each path's final symbol, (live_k,) 384 | next_state = np.array(hyp_states) # h2t, (live_k,n) 385 | next_h1t = np.array(hyp_h1ts) 386 | next_calpha_past = np.array(hyp_calpha_past) # (live_k,H,W) 387 | next_palpha_past = np.array(hyp_palpha_past) 388 | nextemb_memory = np.array(hyp_emb_memory) 389 | nextemb_memory = np.transpose(nextemb_memory, (1, 0, 2)) 390 | nextePmb_memory = np.array(hyp_ePmb_memory) 391 | nextePmb_memory = np.transpose(nextePmb_memory, (1, 0, 2)) 392 | next_lw = torch.from_numpy(next_lw).cuda() 393 | next_state = torch.from_numpy(next_state).cuda() 394 | next_h1t = torch.from_numpy(next_h1t).cuda() 395 | next_calpha_past = torch.from_numpy(next_calpha_past).cuda() 396 | next_palpha_past = torch.from_numpy(next_palpha_past).cuda() 397 | nextemb_memory = torch.from_numpy(nextemb_memory).cuda() 398 | nextePmb_memory = torch.from_numpy(nextePmb_memory).cuda() 399 | 400 | return sample_score, sample, rpos_sample, relation_sample 401 | 402 | 403 | # init model params 404 | def weight_init(m): 405 | if isinstance(m, nn.Conv2d): 406 | nn.init.xavier_uniform_(m.weight.data) 407 | try: 408 | 
nn.init.constant_(m.bias.data, 0.) 409 | except: 410 | pass 411 | 412 | if isinstance(m, nn.Linear): 413 | nn.init.xavier_uniform_(m.weight.data) 414 | try: 415 | nn.init.constant_(m.bias.data, 0.) 416 | except: 417 | pass 418 | 419 | # compute metric 420 | def cmp_result(rec,label): 421 | dist_mat = np.zeros((len(label)+1, len(rec)+1),dtype='int32') 422 | dist_mat[0,:] = range(len(rec) + 1) 423 | dist_mat[:,0] = range(len(label) + 1) 424 | for i in range(1, len(label) + 1): 425 | for j in range(1, len(rec) + 1): 426 | hit_score = dist_mat[i-1, j-1] + (label[i-1] != rec[j-1]) 427 | ins_score = dist_mat[i,j-1] + 1 428 | del_score = dist_mat[i-1, j] + 1 429 | dist_mat[i,j] = min(hit_score, ins_score, del_score) 430 | 431 | dist = dist_mat[len(label), len(rec)] 432 | return dist, len(label) 433 | 434 | def compute_wer(rec_mat, label_mat): 435 | total_dist = 0 436 | total_label = 0 437 | total_line = 0 438 | total_line_rec = 0 439 | for key_rec in rec_mat: 440 | label = label_mat[key_rec] 441 | rec = rec_mat[key_rec] 442 | # label = list(map(int,label)) 443 | # rec = list(map(int,rec)) 444 | dist, llen = cmp_result(rec, label) 445 | total_dist += dist 446 | total_label += llen 447 | total_line += 1 448 | if dist == 0: 449 | total_line_rec += 1 450 | wer = float(total_dist)/total_label 451 | sacc = float(total_line_rec)/total_line 452 | return wer, sacc 453 | 454 | def cmp_sacc_result(rec_list,label_list,rec_ridx_list,label_ridx_list,rec_re_list,label_re_list,chdict,redict): 455 | rec = True 456 | out_sym_pdict = {} 457 | label_sym_pdict = {} 458 | out_sym_pdict['0'] = '' 459 | label_sym_pdict['0'] = '' 460 | for idx, sym in enumerate(rec_list): 461 | out_sym_pdict[str(idx+1)] = chdict[sym] 462 | for idx, sym in enumerate(label_list): 463 | label_sym_pdict[str(idx+1)] = chdict[sym] 464 | 465 | if len(rec_list) != len(label_list): 466 | rec = False 467 | else: 468 | for idx in range(len(rec_list)): 469 | out_sym = chdict[rec_list[idx]] 470 | label_sym = 
chdict[label_list[idx]] 471 | out_repos = int(rec_ridx_list[idx]) 472 | label_repos = int(label_ridx_list[idx]) 473 | out_re = redict[rec_re_list[idx]] 474 | label_re = redict[label_re_list[idx]] 475 | if out_repos in out_sym_pdict: 476 | out_resym_s = out_sym_pdict[out_repos] 477 | else: 478 | out_resym_s = 'unknown' 479 | if label_repos in label_sym_pdict: 480 | label_resym_s = label_sym_pdict[label_repos] 481 | else: 482 | label_resym_s = 'unknown' 483 | 484 | # post-processing only for math recognition 485 | if (out_resym_s == '\lim' and label_resym_s == '\lim') or \ 486 | (out_resym_s == '\int' and label_resym_s == '\int') or \ 487 | (out_resym_s == '\sum' and label_resym_s == '\sum'): 488 | if out_re == 'Above': 489 | out_re = 'Sup' 490 | if out_re == 'Below': 491 | out_re = 'Sub' 492 | if label_re == 'Above': 493 | label_re = 'Sup' 494 | if label_re == 'Below': 495 | label_re = 'Sub' 496 | 497 | # if out_sym != label_sym or out_pos != label_pos or out_repos != label_repos or out_re != label_re: 498 | # if out_sym != label_sym or out_repos != label_repos: 499 | if out_sym != label_sym or out_repos != label_repos or out_re != label_re: 500 | rec = False 501 | break 502 | return rec 503 | 504 | def compute_sacc(rec_mat, label_mat, rec_ridx_mat, label_ridx_mat, rec_re_mat, label_re_mat, chdict, redict): 505 | total_num = len(rec_mat) 506 | correct_num = 0 507 | for key_rec in rec_mat: 508 | rec_list = rec_mat[key_rec] 509 | label_list = label_mat[key_rec] 510 | rec_ridx_list = rec_ridx_mat[key_rec] 511 | label_ridx_list = label_ridx_mat[key_rec] 512 | rec_re_list = rec_re_mat[key_rec] 513 | label_re_list = label_re_mat[key_rec] 514 | rec_result = cmp_sacc_result(rec_list,label_list,rec_ridx_list,label_ridx_list,rec_re_list,label_re_list,chdict,redict) 515 | if rec_result: 516 | correct_num += 1 517 | correct_rate = 1. 
* correct_num / total_num 518 | return correct_rate 519 | -------------------------------------------------------------------------------- /paper/TD_camera_v1.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianshuZhang/TreeDecoder/e73da41ba234d01467d23b9bf0f36e1079e96c64/paper/TD_camera_v1.pdf --------------------------------------------------------------------------------