├── README.md
├── RNN.py
├── bi_seq_lm_input.py
├── config.py
├── masked_cross_entropy.py
├── policy_config.py
└── reinforce.py

/README.md:
--------------------------------------------------------------------------------
## Code for Sentence Compression

We describe the steps in detail. If any point is unclear, please contact the authors at code4conference@gmail.com .

### Step 1: Training a bi-directional neural language model
For example, "BOS This is an anonymous Github EOS" is used to predict "This is an anonymous Github" (see RNN.py).

### Step 2: Pre-training a policy network
Pre-train a sequence-labeling neural network as the policy network, using labels produced by the unsupervised method (Integer Linear Programming).

### Step 3: Reinforcement Learning
Start from the pre-trained policy instead of a random policy, and use the pre-trained language model as the reward to fine-tune the policy network.
--------------------------------------------------------------------------------
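A toy illustration of the Step 1 objective (an editorial sketch, not a file from the repository): for every inner token, the forward direction of the language model reads the left context, the backward direction reads the right context, and together they predict the token in between.

    # Editorial sketch: the training pairs implied by the README example.
    tokens = ["BOS", "This", "is", "an", "anonymous", "Github", "EOS"]
    targets = tokens[1:-1]                    # the words to be predicted
    for i, target in enumerate(targets, 1):
        left_context = tokens[:i]             # read by the forward RNN
        right_context = tokens[i + 1:]        # read by the backward RNN
        print("%s <- %s ... %s" % (target, " ".join(left_context), " ".join(right_context)))
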
/RNN.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

import config
conf = config.config()

USE_CUDA = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor

def prepare_sequence(word_seq_id, dep_seq_id, pos_seq_id, target_seq_id):
    # Sort the batch by length (longest first), pad every sequence to max_l,
    # and build a mask of valid positions.
    seq_lens = [len(seq_id) for seq_id in word_seq_id]
    max_l = max(seq_lens)
    indexs = np.argsort(seq_lens)[::-1].tolist()

    word_seq_id = np.array(word_seq_id)[indexs]
    dep_seq_id = np.array(dep_seq_id)[indexs]
    pos_seq_id = np.array(pos_seq_id)[indexs]
    target_seq_id = np.array(target_seq_id)[indexs]

    word_seq_id1 = []
    dep_seq_id1 = []
    pos_seq_id1 = []
    target_seq_id1 = []
    mask = []
    for w_seq, d_seq, p_seq, t_seq in zip(word_seq_id, dep_seq_id, pos_seq_id, target_seq_id):
        if len(w_seq) != len(d_seq) or len(d_seq) != len(p_seq) or len(p_seq) != len(w_seq):
            print "inconsistent lengths for w_seq, d_seq and p_seq"

        word_seq_id1.append(w_seq + [0] * (max_l - len(w_seq)))
        dep_seq_id1.append(d_seq + [0] * (max_l - len(d_seq)))
        pos_seq_id1.append(p_seq + [0] * (max_l - len(p_seq)))
        if conf.num_directions == 2:
            # Bidirectional LM: BOS/EOS are context only, so targets and mask
            # cover the max_l - 2 inner positions.
            target_seq_id1.append(t_seq + [0] * (max_l - len(t_seq) - 2))
            mask.append([1] * (len(w_seq) - 2) + [0] * (max_l - len(w_seq)))
        else:
            target_seq_id1.append(t_seq + [0] * (max_l - len(t_seq)))
            mask.append([1] * len(w_seq) + [0] * (max_l - len(w_seq)))

    return Variable(LongTensor(word_seq_id1)), \
           Variable(LongTensor(dep_seq_id1)), \
           Variable(LongTensor(pos_seq_id1)), \
           Variable(LongTensor(target_seq_id1)), \
           seq_lens, \
           indexs, \
           np.array(mask), \
           max_l


class vanilla_RNN(nn.Module):
    def __init__(self, freq_vocab, word2id, dep2id, pos2id):
        super(vanilla_RNN, self).__init__()
        self.word_embeddings = nn.Embedding(len(word2id) + 1, conf.emb_dim)
        self.dep_embeddings = nn.Embedding(len(dep2id) + 1, conf.dep_dim)
        self.pos_embeddings = nn.Embedding(len(pos2id) + 1, conf.pos_dim)
        self.init_all_embeddings()

        self.word_embeddings = self.word_embeddings.cpu()
        self.dep_embeddings = self.dep_embeddings.cpu()
        self.pos_embeddings = self.pos_embeddings.cpu()

        # initialize RNN
        self.rnn = nn.RNN(conf.emb_dim + conf.dep_dim + conf.pos_dim,
                          conf.hidden_dim, conf.num_layers,
                          batch_first=True, bidirectional=True if conf.num_directions == 2 else False)
        self.params_init(self.rnn.named_parameters())

        # initialize linear
        self.linear = nn.Linear(conf.hidden_dim * conf.num_directions, conf.vocab_size + 1)
        self.params_init(self.linear.named_parameters())

        self.freq_vocab = freq_vocab
        self.word2id = word2id
        self.dep2id = dep2id
        self.pos2id = pos2id

    def init_hidden(self, batch_size):
        h0 = Variable(torch.zeros(conf.num_layers * conf.num_directions, batch_size, conf.hidden_dim))
        return h0.cuda() if USE_CUDA else h0

    def init_all_embeddings(self):
        self.word_embeddings.weight = nn.init.xavier_uniform(self.word_embeddings.weight)
        self.dep_embeddings.weight = nn.init.xavier_uniform(self.dep_embeddings.weight)
        self.pos_embeddings.weight = nn.init.xavier_uniform(self.pos_embeddings.weight)

    def params_init(self, params):
        for name, param in params:
            if len(param.data.shape) == 2:
                print(name)
                nn.init.kaiming_normal(param, a=0, mode='fan_in')
            if len(param.data.shape) == 1:
                nn.init.normal(param)

    def forward(self, word_seq_id, dep_seq_id, pos_seq_id, target_seq_ids, h0, is_training=False):
        word_padded_ids, dep_padded_ids, pos_padded_ids, target_padded_ids, seq_lens, indexs, mask, max_l = \
            prepare_sequence(word_seq_id, dep_seq_id, pos_seq_id, target_seq_ids)

        word_vecs = self.word_embeddings(word_padded_ids)
        dep_vecs = self.dep_embeddings(dep_padded_ids)
        pos_vecs = self.pos_embeddings(pos_padded_ids)
        input_x = torch.cat((word_vecs, dep_vecs, pos_vecs), 2)

        '''
        input_seq_packed = torch.nn.utils.rnn.pack_padded_sequence(input_x, seq_lens, batch_first=True)
        out_pack, hx = self.rnn(input_seq_packed, self.hidden)
        out, seq_lens = torch.nn.utils.rnn.pad_packed_sequence(out_pack, batch_first=True)
        '''
        out, hx = self.rnn(input_x, h0)
        #out = out.contiguous().view(-1, conf.hidden_dim*2)
        #mask = mask.view(-1)
        if conf.num_directions == 2:
            # Position t of out_cat pairs the forward state after reading tokens
            # 0..t with the backward state after reading tokens t+2..end, so the
            # linear layer predicts token t+1 from both of its contexts.
            forward_out, backward_out = out[:, :-2, :conf.hidden_dim], out[:, 2:, conf.hidden_dim:]
            out_cat = torch.cat((forward_out, backward_out), dim=-1)

        logits = self.linear(out_cat if conf.num_directions == 2 else out)
        probs = 0
        if is_training == False:
            probs = F.softmax(logits, dim=2)
            #pred = probs.data.cpu().numpy().argmax(2)
        return logits, probs, word_padded_ids, target_padded_ids, indexs, mask
--------------------------------------------------------------------------------
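The forward/backward slicing in vanilla_RNN.forward is the heart of the bidirectional language model; the following standalone sketch (editorial, not a repository file) just prints the index bookkeeping it implements:

    # For a padded length max_l, position t of the concatenated state combines
    # the forward state at t (left context: tokens 0..t) with the backward
    # state at t+2 (right context: tokens t+2..max_l-1) to predict token t+1.
    max_l = 7
    for t in range(max_l - 2):
        print("predict token %d from forward state %d and backward state %d" % (t + 1, t, t + 2))
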
/bi_seq_lm_input.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
torch.manual_seed(1337)

import config
conf = config.config()
import collections as col

USE_CUDA = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor

# we use https://spacy.io/models/ to parse sentences
f_parsing = open('./your_own_parsing_ouput.txt', "r")
toy_parsing = []
temp = []
for line in f_parsing.readlines():
    if line.split() == []:
        toy_parsing.append(temp)
        temp = []
    else:
        temp.append(line.split())

f_parsing.close()


pos_seq = []
dep_seq = []
child = []
father = []
ori_seq = []
lower_seq = []
for sent in toy_parsing:
    child_dict = {}
    temp_father = []
    for k in range(len(sent) + 1):
        child_dict[k] = []
        temp_father.append(k)

    temp_pos = ['BOS']
    temp_dep = ['BOS']
    temp_ori_seq = ['BOS']
    temp_lower_seq = ['BOS']
    for token in sent:
        temp_ori_seq.append(token[2])
        temp_lower_seq.append(token[2].lower())
        child_dict[int(token[1])].append(int(token[0]))
        temp_father[int(token[0])] = int(token[1])
        temp_pos.append(token[5])
        temp_dep.append(token[4])

    # the final token of each parsed sentence is replaced by the EOS marker
    temp_pos[-1] = 'EOS'
    temp_dep[-1] = 'EOS'
    temp_ori_seq[-1] = 'EOS'
    temp_lower_seq[-1] = 'EOS'

    ori_seq.append(temp_ori_seq)
    lower_seq.append(temp_lower_seq)

    father.append(temp_father)
    child.append(child_dict)
    pos_seq.append(temp_pos)
    dep_seq.append(temp_dep)

#%%
"""
inputs are word_seq, dep_seq, pos_seq
"""
for w, d, p in zip(lower_seq, dep_seq, pos_seq):
    if len(w) != len(d) or len(w) != len(p) or len(d) != len(p):
        print "inconsistent sequence lengths!"


def mapping(_list):
    _2id = {}
    for i, item in enumerate(_list):
        _2id[item] = i + 1

    return _2id


def prepare_data2id(word_seq, dep_seq, pos_seq):
    word_vocab = set()
    dep_vocab = set()
    pos_vocab = set()
    bag_of_words = []

    for w_seq_i, d_seq_i, p_seq_i in zip(word_seq, dep_seq, pos_seq):
        for w_i, d_i, p_i in zip(w_seq_i, d_seq_i, p_seq_i):
            bag_of_words.append(w_i)
            word_vocab.add(w_i)
            dep_vocab.add(d_i)
            pos_vocab.add(p_i)

    freq_words = col.Counter(bag_of_words).items()
    freq_words = sorted(freq_words, key=lambda s: s[-1], reverse=True)
    assert len(freq_words) > conf.vocab_size
    freq_vocab = []
    for w in freq_words[:conf.vocab_size]:
        freq_vocab.append(w[0])

    vocab_set = set(freq_vocab)
    dep_vocab = list(dep_vocab)
    pos_vocab = list(pos_vocab)

    word2id = mapping(freq_vocab)
    dep2id = mapping(dep_vocab)
    pos2id = mapping(pos_vocab)

    word_seq_id = []
    dep_seq_id = []
    pos_seq_id = []
    target_seq_id = []
    for w_seq_i, d_seq_i, p_seq_i in zip(word_seq, dep_seq, pos_seq):
        temp_w = []
        temp_dep = []
        temp_pos = []
        for w_i, d_i, p_i in zip(w_seq_i, d_seq_i, p_seq_i):
            if w_i in vocab_set:
                temp_w.append(word2id[w_i])
            else:
                temp_w.append(len(freq_vocab))   # out-of-vocabulary id

            temp_dep.append(dep2id[d_i])
            temp_pos.append(pos2id[p_i])

        word_seq_id.append(temp_w)
        dep_seq_id.append(temp_dep)
        pos_seq_id.append(temp_pos)
        target_seq_id.append(temp_w[1:-1])   # inner tokens are the LM targets

    return word_seq_id, target_seq_id, dep_seq_id, pos_seq_id, \
           word2id, dep2id, pos2id, freq_vocab


word_seq_id, target_seq_id, dep_seq_id, pos_seq_id, \
    word2id, dep2id, pos2id, freq_vocab = prepare_data2id(lower_seq, dep_seq, pos_seq)

#%%

import RNN, masked_cross_entropy
lm_model = RNN.vanilla_RNN(freq_vocab, word2id, dep2id, pos2id)
if USE_CUDA:
    lm_model = lm_model.cuda()

optimizer = optim.Adam(lm_model.parameters(), lr=conf.lr)

#%%
train_word_seq_id, train_target_seq_id, train_dep_seq_id, train_pos_seq_id = \
    word_seq_id[2000:], target_seq_id[2000:], dep_seq_id[2000:], pos_seq_id[2000:]

val_word_seq_id, val_target_seq_id, val_dep_seq_id, val_pos_seq_id = \
    word_seq_id[1000:2000], target_seq_id[1000:2000], dep_seq_id[1000:2000], pos_seq_id[1000:2000]

test_word_seq_id, test_target_seq_id, test_dep_seq_id, test_pos_seq_id = \
    word_seq_id[:1000], target_seq_id[:1000], dep_seq_id[:1000], pos_seq_id[:1000]


#%%
bz = conf.batch_size
for epoch in range(100):
    losses = []

    train_data = zip(train_word_seq_id, train_target_seq_id, train_dep_seq_id, train_pos_seq_id)
    np.random.shuffle(train_data)
    train_word_seq_id, train_target_seq_id, train_dep_seq_id, train_pos_seq_id = zip(*train_data)

    nb = len(train_word_seq_id) / bz
    for b_i in range(nb):
        h0 = lm_model.init_hidden(bz)
        lm_model.zero_grad()
        logits, probs, word_padded_ids, target_padded_ids, indexs, mask = \
            lm_model(train_word_seq_id[b_i*bz:(b_i+1)*bz],
                     train_dep_seq_id[b_i*bz:(b_i+1)*bz],
                     train_pos_seq_id[b_i*bz:(b_i+1)*bz],
                     train_target_seq_id[b_i*bz:(b_i+1)*bz],
                     h0,
                     is_training=True)

        seq_lens = Variable(torch.sum(LongTensor(mask), 1))
        loss = masked_cross_entropy.compute_loss(logits, target_padded_ids, seq_lens)

        loss.backward()
        #torch.nn.utils.clip_grad_norm(lm_model.parameters(), 0.5)  # gradient clipping
        optimizer.step()
        losses.append(loss.data[0])

    print("epoch %d, mean training loss %.4f" % (epoch, np.mean(losses)))
--------------------------------------------------------------------------------
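The script above carves out validation and test splits but only ever trains. As a hedged sketch (not present in the repository), the validation split could be scored with the same masked loss, for example once per epoch:

    def evaluate_perplexity(words, deps, poss, targets):
        # Corpus-level perplexity of lm_model on a held-out split, reusing
        # bz, Variable, LongTensor and masked_cross_entropy defined above.
        total_loss, n_batches = 0.0, 0
        for b_i in range(len(words) / bz):      # Python 2 integer division
            h0 = lm_model.init_hidden(bz)
            logits, _, _, target_padded_ids, _, mask = \
                lm_model(words[b_i*bz:(b_i+1)*bz],
                         deps[b_i*bz:(b_i+1)*bz],
                         poss[b_i*bz:(b_i+1)*bz],
                         targets[b_i*bz:(b_i+1)*bz],
                         h0, is_training=True)
            seq_lens = Variable(torch.sum(LongTensor(mask), 1))
            loss = masked_cross_entropy.compute_loss(logits, target_padded_ids, seq_lens)
            total_loss += loss.data[0]
            n_batches += 1
        return np.exp(total_loss / max(n_batches, 1))

    # e.g. after each training epoch:
    # print("val ppl: %.2f" % evaluate_perplexity(val_word_seq_id, val_dep_seq_id,
    #                                             val_pos_seq_id, val_target_seq_id))
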
/config.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

class config(object):

    emb_dim = 128
    dep_dim = 128
    pos_dim = 128
    vocab_size = 50000
    hidden_dim = 512
    num_layers = 1
    num_directions = 2

    batch_size = 50
    lr = 0.00025
    decay_rate = 0.99
--------------------------------------------------------------------------------
/masked_cross_entropy.py:
--------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
import torch.nn.functional as F
torch.manual_seed(1)


def _sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.range(0, max_len - 1).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_range_expand = Variable(seq_range_expand)

    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()

    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand


def compute_loss(logits, target, length):
    """
    Args:
        logits: A Variable containing a FloatTensor of size
            (batch, max_len, num_classes) which contains the
            unnormalized probability for each class.
        target: A Variable containing a LongTensor of size
            (batch, max_len) which contains the index of the true
            class for each corresponding step.
        length: A Variable containing a LongTensor of size (batch,)
            which contains the length of each sequence in the batch.

    Returns:
        loss: An average loss value masked by the length.
    """

    # logits_flat: (batch * max_len, num_classes)
    logits_flat = logits.view(-1, logits.size(-1))
    # log_probs_flat: (batch * max_len, num_classes)
    log_probs_flat = F.log_softmax(logits_flat, dim=-1)
    # target_flat: (batch * max_len, 1)
    target_flat = target.view(-1, 1)
    # losses_flat: (batch * max_len, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    # losses: (batch, max_len)
    losses = losses_flat.view(*target.size())
    # mask: (batch, max_len)
    mask = _sequence_mask(sequence_length=length, max_len=target.size(1))
    losses = losses * mask.float()
    loss = losses.sum() / length.float().sum()
    return loss
--------------------------------------------------------------------------------
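A quick usage illustration (editorial, not part of the repository): a batch of two sequences padded to length 3, where the padded positions are excluded from the average by the mask.

    import torch
    from torch.autograd import Variable
    from masked_cross_entropy import compute_loss

    logits = Variable(torch.randn(2, 3, 5))                      # (batch, max_len, num_classes)
    target = Variable(torch.LongTensor([[1, 2, 0], [3, 0, 0]]))  # zeros are padding
    length = Variable(torch.LongTensor([2, 1]))                  # true length of each sequence
    print(compute_loss(logits, target, length))                  # average masked negative log-likelihood
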
/policy_config.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

class config(object):

    emb_dim = 128
    dep_dim = 128
    pos_dim = 128
    vocab_size = 50000
    hidden_dim = 512
    num_layers = 1
    num_directions = 2

    batch_size = 5
    lr = 0.00001
    decay_rate = 0.99
--------------------------------------------------------------------------------
/reinforce.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import torch
from torch.autograd import Variable
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical

import policy_config
conf = policy_config.config()

USE_CUDA = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor


def reward_length(x):
    # x is the fraction of words kept: the reward is maximal (1.0) at x = 0.5
    # and vanishes at x = 0 (delete everything) or x = 1 (keep everything).
    return 16 * x**2 * (1 - x)**2

def reward_grammar(word_seq_id, dep_seq_id, pos_seq_id, language_model, pred_actions, mask, indexs):
    # Rebuild the compressed sentences from the sampled keep/drop actions and
    # score them with the pre-trained language model (perplexity-based reward).
    new_word_seq_id = np.array(word_seq_id)[indexs].tolist()
    new_dep_seq_id = np.array(dep_seq_id)[indexs].tolist()
    new_pos_seq_id = np.array(pos_seq_id)[indexs].tolist()

    word2id = language_model.word2id
    dep2id = language_model.dep2id
    pos2id = language_model.pos2id

    new_word_seq_id1 = []
    new_dep_seq_id1 = []
    new_pos_seq_id1 = []
    for i in range(len(pred_actions)):
        temp_word = [word2id['BOS']]
        temp_dep = [dep2id['BOS']]
        temp_pos = [pos2id['BOS']]
        for j in range(sum(mask[i])):
            if pred_actions[i][j] == 1:
                temp_word.append(new_word_seq_id[i][j])
                temp_dep.append(new_dep_seq_id[i][j])
                temp_pos.append(new_pos_seq_id[i][j])

        temp_word.append(word2id['EOS'])
        temp_dep.append(dep2id['EOS'])
        temp_pos.append(pos2id['EOS'])

        new_word_seq_id1.append(temp_word)
        new_dep_seq_id1.append(temp_dep)
        new_pos_seq_id1.append(temp_pos)

    new_target_seq_id1 = []
    for sent in new_word_seq_id1:
        new_target_seq_id1.append(sent[1:-1])

    h0 = language_model.init_hidden(conf.batch_size)

    _, probs, _, target_padded_ids, indexs, mask = \
        language_model(new_word_seq_id1, new_dep_seq_id1, new_pos_seq_id1, new_target_seq_id1, h0, False)

    # restore the original batch order (the language model sorts by length internally)
    ori_index = np.argsort(indexs).tolist()
    probs = probs.cpu().data.numpy()[ori_index]
    mask = mask[ori_index]
    target_padded_ids = target_padded_ids.cpu().data.numpy()[ori_index]

    ppl = []
    for i in range(len(mask)):
        l = int(sum(mask[i]))
        temp = []
        for probability, id_ in zip(probs[i][:l], target_padded_ids[i][:l]):
            temp.append(np.log(probability[id_]))

        ppl.append(np.exp(-sum(temp) / float(len(temp))))

    rewards = FloatTensor(ppl)
    rewards = (rewards - rewards.mean()) / (rewards.std() + np.finfo(np.float32).eps)
    return rewards

def select_action(word_seq_id, dep_seq_id, pos_seq_id, policy, language_model, optimizer):
    h0 = policy.init_hidden(conf.batch_size)
    _, probs, indexs, mask, pred = policy(word_seq_id, dep_seq_id, pos_seq_id, h0, True)

    # sample keep/drop actions for every valid position of every sentence
    pred_actions = []
    for prob, index, mask_i in zip(probs, indexs, mask):
        length = sum(mask_i).astype('int64')
        m = Categorical(prob[:length])
        actions = m.sample()
        pred_actions.append(actions.data.cpu().numpy().tolist() + [-1] * (mask.shape[1] - length))
        policy.saved_log_probs.append(m.log_prob(actions))
    policy.pred_actions = pred_actions

    # length reward: assumes the argument of reward_length is the fraction of
    # words kept per sentence (the original referenced an undefined `model`)
    keep_ratios = FloatTensor([sum(a[:int(sum(m_i))]) / float(sum(m_i))
                               for a, m_i in zip(pred_actions, mask)])
    policy.length_r = reward_length(keep_ratios)
    policy.grammar_r = reward_grammar(word_seq_id, dep_seq_id, pos_seq_id, language_model,
                                      pred_actions, mask, indexs)

    batch_loss = 0
    for i in range(mask.shape[0]):
        batch_loss += finish_episode(policy, i)  # /float(length)
    batch_loss = batch_loss / conf.batch_size

    optimizer.zero_grad()
    batch_loss.backward()
    optimizer.step()

    policy.saved_log_probs = []
    policy.length_r = []
    policy.grammar_r = []

    return batch_loss

def finish_episode(model, index):
    # REINFORCE loss for one sentence: sum over steps of
    # -log pi(action) * (grammar reward + length reward)
    model_loss = []
    for log_prob in model.saved_log_probs[index]:
        model_loss.append(-log_prob * (model.grammar_r[index] + model.length_r[index]))

    model_loss = torch.cat(model_loss).sum()
    return model_loss
--------------------------------------------------------------------------------
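For intuition about the two reward terms, here is a standalone illustration (editorial, not a file from the repository): the length term is largest when roughly half of the words are kept, and the perplexity-based grammar term is standardised across the batch, exactly as in reward_grammar above.

    import numpy as np

    def reward_length(x):                        # same formula as in reinforce.py
        return 16 * x**2 * (1 - x)**2

    for ratio in [0.1, 0.25, 0.5, 0.75, 0.9]:    # fraction of words kept
        print("keep %.2f of the words -> length reward %.4f" % (ratio, reward_length(ratio)))

    # dummy per-sentence perplexities, standardised as in reward_grammar
    ppl = np.array([12.0, 45.0, 30.0, 8.0, 22.0], dtype=np.float32)
    rewards = (ppl - ppl.mean()) / (ppl.std() + np.finfo(np.float32).eps)
    print(rewards)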