├── README.md ├── chsmm.py ├── data ├── __init__.py ├── e2e_aligned.tar.gz ├── make_e2e_labedata.py ├── make_wikibio_labedata.py ├── utils.py └── wb_aligned.tar.gz ├── infc.py ├── labeled_data.py ├── slides.pdf ├── template_extraction.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # neural-template-gen 2 | 3 | Code for [Learning Neural Templates for Text Generation](https://arxiv.org/abs/1808.10122) (Wiseman, Shieber, Rush; EMNLP 2018) 4 | 5 | For questions/concerns/bugs please feel free to email swiseman[at]ttic.edu. 6 | 7 | **N.B.** This code was tested with python 2.7 and pytorch 0.3.1. 8 | 9 | ## Data and Data Preparation 10 | 11 | The E2E NLG Challenge data is available [here](http://www.macs.hw.ac.uk/InteractionLab/E2E/), and the preprocessed version of the data used for training is at [data/e2e_aligned.tar.gz](https://github.com/harvardnlp/neural-template-gen/blob/master/data/e2e_aligned.tar.gz). This preprocessed data uses the same database record preprocessing scheme applied by Sebastian Gehrmann in his [system](https://github.com/sebastianGehrmann/OpenNMT-py/tree/diverse_ensemble), and also annotates text spans that occur in the corresponding database. Code for annotating the data in this way is at [data/make_e2e_labedata.py](https://github.com/harvardnlp/neural-template-gen/blob/master/data/make_e2e_labedata.py). 12 | 13 | 14 | The WikiBio data is available [here](https://github.com/DavidGrangier/wikipedia-biography-dataset), and the preprocessed version of the target-side data used for training is at [data/wb_aligned.tar.gz](https://github.com/harvardnlp/neural-template-gen/blob/master/data/wb_aligned.tar.gz). This target-side data is again preprocessed to annotate spans appearing in the corresponding database. Code for this annotation is at [data/make_wikibio_labedata.py](https://github.com/harvardnlp/neural-template-gen/blob/master/data/make_wikibio_labedata.py). The source-side data can be downloaded directly from the [WikiBio repo](https://github.com/DavidGrangier/wikipedia-biography-dataset), and we used it unchanged; in particular the `*.box` files become our `src_*.txt` files mentioned below. 15 | 16 | 17 | The code assumes that each dataset lives in a directory containing `src_train.txt`, `train.txt`, `src_valid.txt`, and `valid.txt` files, and that if the files are from the WikiBio dataset the directory name will contain the string `wiki`. 18 | 19 | ## Training 20 | The four trained models mentioned in the paper can be downloaded [here](https://drive.google.com/drive/folders/1iv71Oq7cmXRY6h2jn0QzlYbbr0GwHCfA?usp=sharing). The commands for retraining the models are given below. 21 | 22 | Assuming your E2E data is in `data/labee2e/`, you can train the non-autoregressive model as follows 23 | 24 | ``` 25 | python chsmm.py -data data/labee2e/ -emb_size 300 -hid_size 300 -layers 1 -K 55 -L 4 -log_interval 200 -thresh 9 -emb_drop -bsz 15 -max_seqlen 55 -lr 0.5 -sep_attn -max_pool -unif_lenps -one_rnn -Kmul 5 -mlpinp -onmt_decay -cuda -seed 1818 -save models/chsmm-e2e-300-55-5.pt 26 | ``` 27 | 28 | and the autoregressive model as follows. 
29 | 30 | ``` 31 | python chsmm.py -data data/labee2e/ -emb_size 300 -hid_size 300 -layers 1 -K 55 -L 4 -log_interval 200 -thresh 9 -emb_drop -bsz 15 -max_seqlen 55 -lr 0.5 -sep_attn -max_pool -unif_lenps -one_rnn -Kmul 5 -mlpinp -onmt_decay -cuda -seed 1111 -save models/chsmm-e2e-300-55-5-far.pt -ar_after_decay 32 | ``` 33 | 34 | 35 | Assuming your WikiBio data is in `data/labewiki`, you can train the non-autoregressive model as follows 36 | 37 | ``` 38 | python chsmm.py -data data/labewiki/ -emb_size 300 -hid_size 300 -layers 1 -K 45 -L 4 -log_interval 1000 -thresh 29 -emb_drop -bsz 5 -max_seqlen 55 -lr 0.5 -sep_attn -max_pool -unif_lenps -one_rnn -Kmul 3 -mlpinp -onmt_decay -cuda -save models/chsmm-wiki-300-45-3.pt 39 | ``` 40 | 41 | and the autoregressive model as follows. 42 | 43 | ``` 44 | python chsmm.py -data data/labewiki/ -emb_size 300 -hid_size 300 -layers 1 -K 45 -L 4 -log_interval 1000 -thresh 29 -emb_drop -bsz 5 -max_seqlen 55 -lr 0.5 -sep_attn -max_pool -unif_lenps -one_rnn -Kmul 3 -mlpinp -onmt_decay -cuda -save models/chsmm-wiki-300-45-3-war.pt -ar_after_decay -word_ar 45 | ``` 46 | 47 | The above scripts will also attempt to save to a `models/` directory (which must be created first). Also see [chsmm.py](https://github.com/harvardnlp/neural-template-gen/blob/master/chsmm.py) for additional training and model options. 48 | 49 | **N.B.** training is somewhat sensitive to the random seed, and it may be necessary to try several seeds in order to get the best performance. 50 | 51 | 52 | ## Viterbi Segmentation/Template Extraction 53 | 54 | Once you've trained a model, you can use it to compute the Viterbi segmentation of the training data, which we use to extract templates. A gzipped tarball containing Viterbi segmentations corresponding to the four models above can be downloaded [here](https://drive.google.com/file/d/1ON4ROs_coDNmVt3-JON4wK1Kc_NkIV2M/view?usp=sharing). 55 | 56 | You can rerun the segmentation for the non-autoregressive E2E model as follows 57 | 58 | ``` 59 | python chsmm.py -data data/labee2e/ -emb_size 300 -hid_size 300 -layers 1 -K 55 -L 4 -log_interval 200 -thresh 9 -emb_drop -bsz 16 -max_seqlen 55 -lr 0.5 -sep_attn -max_pool -unif_lenps -one_rnn -Kmul 5 -mlpinp -onmt_decay -cuda -load models/e2e-55-5.pt -label_train | tee segs/seg-e2e-300-55-5.txt 60 | ``` 61 | 62 | and for the autoregressive one as follows. 63 | 64 | ``` 65 | python chsmm.py -data data/labee2e/ -emb_size 300 -hid_size 300 -layers 1 -K 60 -L 4 -log_interval 200 -thresh 9 -emb_drop -bsz 16 -max_seqlen 55 -lr 0.5 -sep_attn -max_pool -unif_lenps -one_rnn -Kmul 1 -mlpinp -onmt_decay -cuda -load models/e2e-60-1-far.pt -label_train -ar_after_decay | tee segs/seg-e2e-300-60-1-far.txt 66 | ``` 67 | 68 | You can rerun the segmentation for the non-autoregressive WikiBio model as follows 69 | 70 | ``` 71 | python chsmm.py -data data/labewiki/ -emb_size 300 -hid_size 300 -layers 1 -K 45 -L 4 -log_interval 200 -thresh 29 -emb_drop -bsz 16 -max_seqlen 55 -lr 0.5 -sep_attn -max_pool -unif_lenps -one_rnn -Kmul 3 -mlpinp -onmt_decay -cuda -load models/wb-45-3.pt -label_train | tee segs/seg-wb-300-45-3.txt 72 | ``` 73 | 74 | and for the autoregressive one as follows. 
75 | 76 | ``` 77 | python chsmm.py -data data/labewiki/ -emb_size 300 -hid_size 300 -layers 1 -K 45 -L 4 -log_interval 200 -thresh 29 -emb_drop -bsz 16 -max_seqlen 55 -lr 0.5 -sep_attn -max_pool -unif_lenps -one_rnn -Kmul 3 -mlpinp -onmt_decay -cuda -load models/wb-45-3-war.pt -label_train | tee segs/seg-wb-300-45-3-war.txt 78 | ``` 79 | 80 | The above commands write the MAP segmentations (as text) to standard output; above, the output is redirected into a `segs/` directory (which must be created first). 81 | 82 | ### Examining and Extracting Templates 83 | The [template_extraction.py](https://github.com/harvardnlp/neural-template-gen/blob/master/template_extraction.py) script can be used to extract templates from the segmentations produced as above, and to examine them. In particular, `extract_from_tagged_data()` returns the most common templates, a mapping from templates to sentences, and a mapping from states to phrases. This script is also used in generation (see below). 84 | 85 | 86 | ## Generation 87 | Once a model has been trained and the MAP segmentations created, we can generate by restricting the search to (for instance) the top 100 extracted templates. 88 | 89 | The following command will generate on the E2E validation set using the autoregressive model: 90 | 91 | ``` 92 | python chsmm.py -data data/labee2e/ -emb_size 300 -hid_size 300 -layers 1 -dropout 0.3 -K 60 -L 4 -log_interval 100 -thresh 9 -lr 0.5 -sep_attn -unif_lenps -emb_drop -mlpinp -onmt_decay -one_rnn -max_pool -gen_from_fi data/labee2e/src_uniq_valid.txt -load models/e2e-60-1-far.pt -tagged_fi segs/seg-e2e-60-1-far.txt -beamsz 5 -ntemplates 100 -gen_wts '1,1' -cuda -min_gen_tokes 0 > gens/gen-e2e-60-1-far.txt 93 | ``` 94 | 95 | The following command will generate on the WikiBio test set using the autoregressive model: 96 | ``` 97 | python chsmm.py -data data/labewiki/ -emb_size 300 -hid_size 300 -layers 1 -K 45 -L 4 -log_interval 1000 -thresh 29 -emb_drop -bsz 5 -max_seqlen 55 -lr 0.5 -sep_attn -max_pool -unif_lenps -one_rnn -Kmul 3 -mlpinp -onmt_decay -cuda -gen_from_fi wikipedia-biography-dataset/test/test.box -load models/wb-45-3-war.pt -tagged_fi segs/seg-wb-300-45-3-war.txt -beamsz 5 -ntemplates 100 -gen_wts '1,1' -cuda -min_gen_tokes 20 > gens/gen-wb-45-3-war.txt 98 | ``` 99 | 100 | Generations from the other models can be obtained analogously, by substituting the correct arguments for `-data` (path to the data directory), `-gen_from_fi` (the source file from which to generate), `-load` (path to the saved model), and `-tagged_fi` (path to the MAP segmentations under the corresponding model). See [chsmm.py](https://github.com/harvardnlp/neural-template-gen/blob/master/chsmm.py) for additional generation options. 101 | 102 | 103 | **N.B.** Each generated line has the format `<generated text>|||<segmentation>`, where the `<segmentation>` part records the template segmentation used in generating. As such, everything starting with '|||' should be stripped off before evaluating the generations.
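For instance, a minimal sketch of that stripping step (the script name `strip_segs.py` is illustrative only and not part of the repo):

```
# strip_segs.py: keep only the generated text, dropping the '|||<segmentation>' suffix
import sys

for line in sys.stdin:
    print(line.split('|||')[0].strip())
```

which can be run as `python strip_segs.py < gens/gen-e2e-60-1-far.txt > gens/gen-e2e-60-1-far.stripped`.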
104 | -------------------------------------------------------------------------------- /chsmm.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import math 4 | import random 5 | import argparse 6 | from collections import defaultdict, Counter 7 | import heapq 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import torch.optim as optim 13 | from torch.autograd import Variable 14 | import labeled_data 15 | from utils import logsumexp1, make_fwd_constr_idxs, make_bwd_constr_idxs, backtrace3, backtrace 16 | from data.utils import get_wikibio_poswrds, get_e2e_poswrds 17 | import infc 18 | 19 | 20 | class HSMM(nn.Module): 21 | """ 22 | standard hsmm 23 | """ 24 | def __init__(self, wordtypes, gentypes, opt): 25 | super(HSMM, self).__init__() 26 | self.K = opt.K 27 | self.Kmul = opt.Kmul 28 | self.L = opt.L 29 | self.A_dim = opt.A_dim 30 | self.unif_lenps = opt.unif_lenps 31 | self.A_from = nn.Parameter(torch.Tensor(opt.K*opt.Kmul, opt.A_dim)) 32 | self.A_to = nn.Parameter(torch.Tensor(opt.A_dim, opt.K*opt.Kmul)) 33 | if self.unif_lenps: 34 | self.len_scores = nn.Parameter(torch.ones(1, opt.L)) 35 | self.len_scores.requires_grad = False 36 | else: 37 | self.len_decoder = nn.Linear(2*opt.A_dim, opt.L) 38 | 39 | self.yes_self_trans = opt.yes_self_trans 40 | if not self.yes_self_trans: 41 | selfmask = torch.Tensor(opt.K*opt.Kmul).fill_(-float("inf")) 42 | self.register_buffer('selfmask', Variable(torch.diag(selfmask), requires_grad=False)) 43 | 44 | self.max_pool = opt.max_pool 45 | self.emb_size, self.layers, self.hid_size = opt.emb_size, opt.layers, opt.hid_size 46 | self.pad_idx = opt.pad_idx 47 | self.lut = nn.Embedding(wordtypes, opt.emb_size, padding_idx=opt.pad_idx) 48 | self.mlpinp = opt.mlpinp 49 | self.word_ar = opt.word_ar 50 | self.ar = False 51 | inp_feats = 4 52 | sz_mult = opt.mlp_sz_mult 53 | if opt.mlpinp: 54 | rnninsz = sz_mult*opt.emb_size 55 | mlpinp_sz = inp_feats*opt.emb_size 56 | self.inpmlp = nn.Sequential(nn.Linear(mlpinp_sz, sz_mult*opt.emb_size), 57 | nn.ReLU()) 58 | else: 59 | rnninsz = inp_feats*opt.emb_size 60 | 61 | self.start_emb = nn.Parameter(torch.Tensor(1, 1, rnninsz)) 62 | self.pad_emb = nn.Parameter(torch.zeros(1, 1, rnninsz)) 63 | 64 | self.one_rnn = opt.one_rnn 65 | if opt.one_rnn: 66 | rnninsz += opt.emb_size 67 | 68 | self.seg_rnns = nn.ModuleList() 69 | if opt.one_rnn: 70 | self.seg_rnns.append(nn.LSTM(rnninsz, opt.hid_size, 71 | opt.layers, dropout=opt.dropout)) 72 | self.state_embs = nn.Parameter(torch.Tensor(opt.K, 1, 1, opt.emb_size)) 73 | else: 74 | for _ in xrange(opt.K): 75 | self.seg_rnns.append(nn.LSTM(rnninsz, opt.hid_size, 76 | opt.layers, dropout=opt.dropout)) 77 | self.ar_rnn = nn.LSTM(opt.emb_size, opt.hid_size, opt.layers, dropout=opt.dropout) 78 | 79 | self.h0_lin = nn.Linear(opt.emb_size, 2*opt.hid_size) 80 | self.state_att_gates = nn.Parameter(torch.Tensor(opt.K, 1, 1, opt.hid_size)) 81 | self.state_att_biases = nn.Parameter(torch.Tensor(opt.K, 1, 1, opt.hid_size)) 82 | 83 | self.sep_attn = opt.sep_attn 84 | if self.sep_attn: 85 | self.state_att2_gates = nn.Parameter(torch.Tensor(opt.K, 1, 1, opt.hid_size)) 86 | self.state_att2_biases = nn.Parameter(torch.Tensor(opt.K, 1, 1, opt.hid_size)) 87 | 88 | out_hid_sz = opt.hid_size + opt.emb_size 89 | self.state_out_gates = nn.Parameter(torch.Tensor(opt.K, 1, 1, out_hid_sz)) 90 | self.state_out_biases = nn.Parameter(torch.Tensor(opt.K, 1, 1, out_hid_sz)) 91 | # add one more output word for 
eop 92 | self.decoder = nn.Linear(out_hid_sz, gentypes+1) 93 | self.eop_idx = gentypes 94 | self.attn_lin1 = nn.Linear(opt.hid_size, opt.emb_size) 95 | self.linear_out = nn.Linear(opt.hid_size + opt.emb_size, opt.hid_size) 96 | 97 | self.drop = nn.Dropout(opt.dropout) 98 | self.emb_drop = opt.emb_drop 99 | self.initrange = opt.initrange 100 | self.lsm = nn.LogSoftmax(dim=1) 101 | self.zeros = torch.Tensor(1, 1).fill_(-float("inf")) if opt.lse_obj else torch.zeros(1, 1) 102 | self.lse_obj = opt.lse_obj 103 | if opt.cuda: 104 | self.zeros = self.zeros.cuda() 105 | 106 | # src encoder stuff 107 | self.src_bias = nn.Parameter(torch.Tensor(1, opt.emb_size)) 108 | self.uniq_bias = nn.Parameter(torch.Tensor(1, opt.emb_size)) 109 | 110 | self.init_lin = nn.Linear(opt.emb_size, opt.K*opt.Kmul) 111 | self.cond_A_dim = opt.cond_A_dim 112 | self.smaller_cond_dim = opt.smaller_cond_dim 113 | if opt.smaller_cond_dim > 0: 114 | self.cond_trans_lin = nn.Sequential( 115 | nn.Linear(opt.emb_size, opt.smaller_cond_dim), 116 | nn.ReLU(), 117 | nn.Linear(opt.smaller_cond_dim, opt.K*opt.Kmul*opt.cond_A_dim*2)) 118 | else: 119 | self.cond_trans_lin = nn.Linear(opt.emb_size, opt.K*opt.Kmul*opt.cond_A_dim*2) 120 | self.init_weights() 121 | 122 | 123 | def init_weights(self): 124 | """ 125 | (re)init weights 126 | """ 127 | initrange = self.initrange 128 | self.lut.weight.data.uniform_(-initrange, initrange) 129 | self.lut.weight.data[self.pad_idx].zero_() 130 | self.lut.weight.data[corpus.dictionary.word2idx[""]].zero_() 131 | self.lut.weight.data[corpus.dictionary.word2idx[""]].zero_() 132 | self.lut.weight.data[corpus.dictionary.word2idx[""]].zero_() 133 | params = [self.src_bias, self.state_out_gates, self.state_att_gates, 134 | self.state_out_biases, self.state_att_biases, self.start_emb, 135 | self.A_from, self.A_to, self.uniq_bias] 136 | if self.sep_attn: 137 | params.extend([self.state_att2_gates, self.state_att2_biases]) 138 | if self.one_rnn: 139 | params.append(self.state_embs) 140 | 141 | for param in params: 142 | param.data.uniform_(-initrange, initrange) 143 | 144 | rnns = [rnn for rnn in self.seg_rnns] 145 | rnns.append(self.ar_rnn) 146 | for rnn in rnns: 147 | for thing in rnn.parameters(): 148 | thing.data.uniform_(-initrange, initrange) 149 | 150 | lins = [self.init_lin, self.decoder, self.attn_lin1, self.linear_out, self.h0_lin] 151 | if self.smaller_cond_dim == 0: 152 | lins.append(self.cond_trans_lin) 153 | else: 154 | lins.extend([self.cond_trans_lin[0], self.cond_trans_lin[2]]) 155 | if not self.unif_lenps: 156 | lins.append(self.len_decoder) 157 | if self.mlpinp: 158 | lins.append(self.inpmlp[0]) 159 | for lin in lins: 160 | lin.weight.data.uniform_(-initrange, initrange) 161 | if lin.bias is not None: 162 | lin.bias.data.zero_() 163 | 164 | 165 | def trans_logprobs(self, uniqenc, seqlen): 166 | """ 167 | args: 168 | uniqenc - bsz x emb_size 169 | returns: 170 | 1 x K tensor and seqlen-1 x bsz x K x K tensor of log probabilities, 171 | where lps[i] is p(q_{i+1} | q_i) 172 | """ 173 | bsz = uniqenc.size(0) 174 | K = self.K*self.Kmul 175 | A_dim = self.A_dim 176 | # bsz x K*A_dim*2 -> bsz x K x A_dim or bsz x K x 2*A_dim 177 | cond_trans_mat = self.cond_trans_lin(uniqenc).view(bsz, K, -1) 178 | # nufrom, nuto each bsz x K x A_dim 179 | A_dim = self.cond_A_dim 180 | nufrom, nuto = cond_trans_mat[:, :, :A_dim], cond_trans_mat[:, :, A_dim:] 181 | A_from, A_to = self.A_from, self.A_to 182 | if self.drop.p > 0: 183 | A_from = self.drop(A_from) 184 | nufrom = self.drop(nufrom) 185 | tscores = 
torch.mm(A_from, A_to) 186 | if not self.yes_self_trans: 187 | tscores = tscores + self.selfmask 188 | trans_lps = tscores.unsqueeze(0).expand(bsz, K, K) 189 | trans_lps = trans_lps + torch.bmm(nufrom, nuto.transpose(1, 2)) 190 | trans_lps = self.lsm(trans_lps.view(-1, K)).view(bsz, K, K) 191 | 192 | init_lps = self.lsm(self.init_lin(uniqenc)) # bsz x K 193 | return init_lps, trans_lps.view(1, bsz, K, K).expand(seqlen-1, bsz, K, K) 194 | 195 | def len_logprobs(self): 196 | """ 197 | returns: 198 | [1xK tensor, 2 x K tensor, .., L-1 x K tensor, L x K tensor] of logprobs 199 | """ 200 | K = self.K*self.Kmul 201 | state_embs = torch.cat([self.A_from, self.A_to.t()], 1) # K x 2*A_dim 202 | if self.unif_lenps: 203 | len_scores = self.len_scores.expand(K, self.L) 204 | else: 205 | len_scores = self.len_decoder(state_embs) # K x L 206 | lplist = [Variable(len_scores.data.new(1, K).zero_())] 207 | for l in xrange(2, self.L+1): 208 | lplist.append(self.lsm(len_scores.narrow(1, 0, l)).t()) 209 | return lplist, len_scores 210 | 211 | 212 | def to_seg_embs(self, xemb): 213 | """ 214 | xemb - bsz x seqlen x emb_size 215 | returns - L+1 x bsz*seqlen x emb_size, 216 | where [1 2 3 4] becomes [ ] 217 | [5 6 7 8] [ 1 2 3 4 5 6 7 8 ] 218 | [ 2 3 4
<p>  6   7   8  <p>] 219 |                                     [ 3   4  <p> <p>  7   8  <p> <p>
] 220 | """ 221 | bsz, seqlen, emb_size = xemb.size() 222 | newx = [self.start_emb.expand(bsz, seqlen, emb_size)] 223 | newx.append(xemb) 224 | for i in xrange(1, self.L): 225 | pad = self.pad_emb.expand(bsz, i, emb_size) 226 | rowi = torch.cat([xemb[:, i:], pad], 1) 227 | newx.append(rowi) 228 | # L+1 x bsz x seqlen x emb_size -> L+1 x bsz*seqlen x emb_size 229 | return torch.stack(newx).view(self.L+1, -1, emb_size) 230 | 231 | 232 | def to_seg_hist(self, states): 233 | """ 234 | states - bsz x seqlen+1 x rnn_size 235 | returns - L+1 x bsz*seqlen x emb_size, 236 | where [ 1 2 3 4] becomes [ 1 2 3 5 6 7 ] 237 | [ 5 6 7 8] [ 1 2 3 4 5 6 7 8 ] 238 | [ 2 3 4
<p>  6   7   8  <p>] 239 |                                         [ 3   4  <p> <p>  7   8  <p> <p>
] 240 | """ 241 | bsz, seqlenp1, rnn_size = states.size() 242 | newh = [states[:, :seqlenp1-1, :]] # [bsz x seqlen x rnn_size] 243 | newh.append(states[:, 1:, :]) 244 | for i in xrange(1, self.L): 245 | pad = self.pad_emb[:, :, :rnn_size].expand(bsz, i, rnn_size) 246 | rowi = torch.cat([states[:, i+1:, :], pad], 1) 247 | newh.append(rowi) 248 | # L+1 x bsz x seqlen x rnn_size -> L+1 x bsz*seqlen x rnn_size 249 | return torch.stack(newh).view(self.L+1, -1, rnn_size) 250 | 251 | 252 | def obs_logprobs(self, x, srcenc, srcfieldenc, fieldmask, combotargs, bsz): 253 | """ 254 | args: 255 | x - seqlen x bsz x max_locs x nfeats 256 | srcenc - bsz x emb_size 257 | srcfieldenc - bsz x nfields x dim 258 | fieldmask - bsz x nfields mask with 0s and -infs where it's a dummy field 259 | combotargs - L x bsz*seqlen x max_locs 260 | returns: 261 | a L x seqlen x bsz x K tensor, where l'th row has prob of sequences of length l+1. 262 | specifically, obs_logprobs[:,t,i,k] gives p(x_t|k), p(x_{t:t+1}|k), ..., p(x_{t:t+l}|k). 263 | the infc code ignores the entries rows corresponding to x_{t:t+m} where t+m > T 264 | """ 265 | seqlen, bsz, maxlocs, nfeats = x.size() 266 | embs = self.lut(x.view(seqlen, -1)) # seqlen x bsz*maxlocs*nfeats x emb_size 267 | 268 | if self.mlpinp: 269 | inpembs = self.inpmlp(embs.view(seqlen, bsz, maxlocs, -1)).mean(2) 270 | else: 271 | inpembs = embs.view(seqlen, bsz, maxlocs, -1).mean(2) # seqlen x bsz x nfeats*emb_size 272 | 273 | if self.emb_drop: 274 | inpembs = self.drop(inpembs) 275 | 276 | if self.ar: 277 | if self.word_ar: 278 | ar_embs = embs.view(seqlen, bsz, maxlocs, nfeats, -1)[:, :, 0, 0] # seqlen x bsz x embsz 279 | else: # ar on fields 280 | ar_embs = embs.view(seqlen, bsz, maxlocs, nfeats, -1)[:, :, :, 1].mean(2) # same 281 | if self.emb_drop: 282 | ar_embs = self.drop(ar_embs) 283 | 284 | # add on initial thing; this is a HACK! 
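# (the "initial thing" is the embedding at index 2 of the lookup table, used as a BOS-like
# first input so the autoregressive RNN has something to condition on at step 0; cf. the
# "hackily use bos_idx" comment in temp_ar_bs below)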
285 | embsz = ar_embs.size(2) 286 | ar_embs = torch.cat([self.lut.weight[2].view(1, 1, embsz).expand(1, bsz, embsz), 287 | ar_embs], 0) # seqlen+1 x bsz x emb_size 288 | ar_states, _ = self.ar_rnn(ar_embs) # seqlen+1 x bsz x rnn_size 289 | 290 | # get L+1 x bsz*seqlen x emb_size segembs 291 | segembs = self.to_seg_embs(inpembs.transpose(0, 1)) 292 | Lp1, bszsl, _ = segembs.size() 293 | if self.ar: 294 | segars = self.to_seg_hist(ar_states.transpose(0, 1)) #L+1 x bsz*seqlen x rnn_size 295 | 296 | bsz, nfields, encdim = srcfieldenc.size() 297 | layers, rnn_size = self.layers, self.hid_size 298 | 299 | # bsz x dim -> bsz x seqlen x dim -> bsz*seqlen x dim -> layers x bsz*seqlen x dim 300 | inits = self.h0_lin(srcenc) # bsz x 2*dim 301 | h0, c0 = inits[:, :rnn_size], inits[:, rnn_size:] # (bsz x dim, bsz x dim) 302 | h0 = F.tanh(h0).unsqueeze(1).expand(bsz, seqlen, rnn_size).contiguous().view( 303 | -1, rnn_size).unsqueeze(0).expand(layers, -1, rnn_size).contiguous() 304 | c0 = c0.unsqueeze(1).expand(bsz, seqlen, rnn_size).contiguous().view( 305 | -1, rnn_size).unsqueeze(0).expand(layers, -1, rnn_size).contiguous() 306 | 307 | # easiest to just loop over K 308 | state_emb_sz = self.state_embs.size(3) 309 | seg_lls = [] 310 | for k in xrange(self.K): 311 | if self.one_rnn: 312 | condembs = torch.cat( 313 | [segembs, self.state_embs[k].expand(Lp1, bszsl, state_emb_sz)], 2) 314 | states, _ = self.seg_rnns[0](condembs, (h0, c0)) # L+1 x bsz*seqlen x rnn_size 315 | else: 316 | states, _ = self.seg_rnns[k](segembs, (h0, c0)) # L+1 x bsz*seqlen x rnn_size 317 | 318 | if self.ar: 319 | states = states + segars # L+1 x bsz*seqlen x rnn_size 320 | 321 | if self.drop.p > 0: 322 | states = self.drop(states) 323 | attnin1 = (states * self.state_att_gates[k].expand_as(states) 324 | + self.state_att_biases[k].expand_as(states)).view( 325 | Lp1, bsz, seqlen, -1) 326 | # L+1 x bsz x seqlen x rnn_size -> bsz x (L+1)seqlen x rnn_size 327 | attnin1 = attnin1.transpose(0, 1).contiguous().view(bsz, Lp1*seqlen, -1) 328 | attnin1 = F.tanh(attnin1) 329 | ascores = torch.bmm(attnin1, srcfieldenc.transpose(1, 2)) # bsz x (L+1)slen x nfield 330 | ascores = ascores + fieldmask.unsqueeze(1).expand_as(ascores) 331 | aprobs = F.softmax(ascores, dim=2) 332 | # bsz x (L+1)seqlen x nfields * bsz x nfields x dim -> bsz x (L+1)seqlen x dim 333 | ctx = torch.bmm(aprobs, srcfieldenc) 334 | # concatenate states and ctx to get L+1 x bsz x seqlen x rnn_size + encdim 335 | cat_ctx = torch.cat([states.view(Lp1, bsz, seqlen, -1), 336 | ctx.view(bsz, Lp1, seqlen, -1).transpose(0, 1)], 3) 337 | out_hid_sz = rnn_size + encdim 338 | cat_ctx = cat_ctx.view(Lp1, -1, out_hid_sz) 339 | # now linear to get L+1 x bsz*seqlen x rnn_size 340 | states_k = F.tanh(cat_ctx * self.state_out_gates[k].expand_as(cat_ctx) 341 | + self.state_out_biases[k].expand_as(cat_ctx)).view( 342 | Lp1, -1, out_hid_sz) 343 | 344 | if self.sep_attn: 345 | attnin2 = (states * self.state_att2_gates[k].expand_as(states) 346 | + self.state_att2_biases[k].expand_as(states)).view( 347 | Lp1, bsz, seqlen, -1) 348 | # L+1 x bsz x seqlen x rnn_size -> bsz x (L+1)seqlen x emb_size 349 | attnin2 = attnin2.transpose(0, 1).contiguous().view(bsz, Lp1*seqlen, -1) 350 | attnin2 = F.tanh(attnin2) 351 | ascores = torch.bmm(attnin2, srcfieldenc.transpose(1, 2)) # bsz x (L+1)slen x nfield 352 | ascores = ascores + fieldmask.unsqueeze(1).expand_as(ascores) 353 | 354 | normfn = F.log_softmax if self.lse_obj else F.softmax 355 | wlps_k = normfn(torch.cat([self.decoder(states_k.view(-1, 
out_hid_sz)), #L+1*bsz*sl x V 356 | ascores.view(bsz, Lp1, seqlen, nfields).transpose( 357 | 0, 1).contiguous().view(-1, nfields)], 1), dim=1) 358 | # concatenate on dummy column for when only a single answer... 359 | wlps_k = torch.cat([wlps_k, Variable(self.zeros.expand(wlps_k.size(0), 1))], 1) 360 | # get scores for predicted next-words (but not for last words in each segment as usual) 361 | psk = wlps_k.narrow(0, 0, self.L*bszsl).gather(1, combotargs.view(self.L*bszsl, -1)) 362 | if self.lse_obj: 363 | lls_k = logsumexp1(psk) 364 | else: 365 | lls_k = psk.sum(1).log() 366 | 367 | # sum up log probs of words in each segment 368 | seglls_k = lls_k.view(self.L, -1).cumsum(0) # L x bsz*seqlen 369 | # need to add end-of-phrase prob too 370 | eop_lps = wlps_k.narrow(0, bszsl, self.L*bszsl)[:, self.eop_idx] # L*bsz*seqlen 371 | if self.lse_obj: 372 | seglls_k = seglls_k + eop_lps.contiguous().view(self.L, -1) 373 | else: 374 | seglls_k = seglls_k + eop_lps.log().view(self.L, -1) 375 | seg_lls.append(seglls_k) 376 | 377 | # K x L x bsz x seqlen -> seqlen x L x bsz x K -> L x seqlen x bsz x K 378 | obslps = torch.stack(seg_lls).view(self.K, self.L, bsz, -1).transpose( 379 | 0, 3).transpose(0, 1) 380 | if self.Kmul > 1: 381 | obslps = obslps.repeat(1, 1, 1, self.Kmul) 382 | return obslps 383 | 384 | 385 | def encode(self, src, avgmask, uniqfields): 386 | """ 387 | args: 388 | src - bsz x nfields x nfeats 389 | avgmask - bsz x nfields, with 0s for pad and 1/tru_nfields for rest 390 | uniqfields - bsz x maxfields 391 | returns bsz x emb_size, bsz x nfields x emb_size 392 | """ 393 | bsz, nfields, nfeats = src.size() 394 | emb_size = self.lut.embedding_dim 395 | # do src stuff that depends on words 396 | embs = self.lut(src.view(-1, nfeats)) # bsz*nfields x nfeats x emb_size 397 | if self.max_pool: 398 | embs = F.relu(embs.sum(1) + self.src_bias.expand(bsz*nfields, emb_size)) 399 | if avgmask is not None: 400 | masked = (embs.view(bsz, nfields, emb_size) 401 | * avgmask.unsqueeze(2).expand(bsz, nfields, emb_size)) 402 | else: 403 | masked = embs.view(bsz, nfields, emb_size) 404 | srcenc = F.max_pool1d(masked.transpose(1, 2), nfields).squeeze(2) # bsz x emb_size 405 | else: 406 | embs = F.tanh(embs.sum(1) + self.src_bias.expand(bsz*nfields, emb_size)) 407 | # average it manually, bleh 408 | if avgmask is not None: 409 | srcenc = (embs.view(bsz, nfields, emb_size) 410 | * avgmask.unsqueeze(2).expand(bsz, nfields, emb_size)).sum(1) 411 | else: 412 | srcenc = embs.view(bsz, nfields, emb_size).mean(1) # bsz x emb_size 413 | 414 | srcfieldenc = embs.view(bsz, nfields, emb_size) 415 | 416 | # do stuff that depends only on uniq fields 417 | uniqenc = self.lut(uniqfields).sum(1) # bsz x nfields x emb_size -> bsz x emb_size 418 | 419 | # add a bias 420 | uniqenc = uniqenc + self.uniq_bias.expand_as(uniqenc) 421 | uniqenc = F.relu(uniqenc) 422 | 423 | return srcenc, srcfieldenc, uniqenc 424 | 425 | def get_next_word_dist(self, hid, k, srcfieldenc): 426 | """ 427 | hid - 1 x bsz x rnn_size 428 | srcfieldenc - 1 x nfields x dim 429 | returns a bsz x nthings dist; not a log dist 430 | """ 431 | bsz = hid.size(1) 432 | _, nfields, rnn_size = srcfieldenc.size() 433 | srcfldenc = srcfieldenc.expand(bsz, nfields, rnn_size) 434 | attnin1 = (hid * self.state_att_gates[k].expand_as(hid) 435 | + self.state_att_biases[k].expand_as(hid)) # 1 x bsz x rnn_size 436 | attnin1 = F.tanh(attnin1) 437 | ascores = torch.bmm(attnin1.transpose(0, 1), srcfldenc.transpose(1, 2)) # bsz x 1 x nfields 438 | aprobs = F.softmax(ascores, 
dim=2) 439 | ctx = torch.bmm(aprobs, srcfldenc) # bsz x 1 x rnn_size 440 | cat_ctx = torch.cat([hid, ctx.transpose(0, 1)], 2) # 1 x bsz x rnn_size 441 | state_k = F.tanh(cat_ctx * self.state_out_gates[k].expand_as(cat_ctx) 442 | + self.state_out_biases[k].expand_as(cat_ctx)) # 1 x bsz x rnn_size 443 | 444 | if self.sep_attn: 445 | attnin2 = (hid * self.state_att2_gates[k].expand_as(hid) 446 | + self.state_att2_biases[k].expand_as(hid)) 447 | attnin2 = F.tanh(attnin2) 448 | ascores = torch.bmm(attnin2.transpose(0, 1), srcfldenc.transpose(1, 2)) # bsz x 1 x nfld 449 | 450 | wlps_k = F.softmax(torch.cat([self.decoder(state_k.squeeze(0)), 451 | ascores.squeeze(1)], 1), dim=1) 452 | return wlps_k.data 453 | 454 | def collapse_word_probs(self, row2tblent, wrd_dist, corpus): 455 | """ 456 | wrd_dist is a K x nwords matrix and it gets modified. 457 | this collapsing only makes sense if src_tbl is the same for every row. 458 | """ 459 | nout_wrds = self.decoder.out_features 460 | i2w, w2i = corpus.dictionary.idx2word, corpus.dictionary.word2idx 461 | # collapse probabilities 462 | first_seen = {} 463 | for i, (field, idx, wrd) in row2tblent.iteritems(): 464 | if field is not None: 465 | if wrd not in first_seen: 466 | first_seen[wrd] = i 467 | # add gen prob if any 468 | if wrd in corpus.genset: 469 | widx = w2i[wrd] 470 | wrd_dist[:, nout_wrds + i].add_(wrd_dist[:, widx]) 471 | wrd_dist[:, widx].zero_() 472 | else: # seen it before, so add its prob 473 | wrd_dist[:, nout_wrds + first_seen[wrd]].add_(wrd_dist[:, nout_wrds + i]) 474 | wrd_dist[:, nout_wrds + i].zero_() 475 | else: # should really have zeroed out before, but this is easier 476 | wrd_dist[:, nout_wrds + i].zero_() 477 | 478 | def temp_bs(self, corpus, ss, start_inp, exh0, exc0, srcfieldenc, 479 | len_lps, row2tblent, row2feats, K, final_state=False): 480 | """ 481 | ss - discrete state index 482 | exh0 - layers x 1 x rnn_size 483 | exc0 - layers x 1 x rnn_size 484 | start_inp - 1 x 1 x emb_size 485 | len_lps - K x L, log normalized 486 | """ 487 | rul_ss = ss % self.K 488 | i2w = corpus.dictionary.idx2word 489 | w2i = corpus.dictionary.word2idx 490 | genset = corpus.genset 491 | unk_idx, eos_idx, pad_idx = w2i[""], w2i[""], w2i[""] 492 | state_emb_sz = self.state_embs.size(3) if self.one_rnn else 0 493 | if self.one_rnn: 494 | cond_start_inp = torch.cat([start_inp, self.state_embs[rul_ss]], 2) # 1 x 1 x cat_size 495 | hid, (hc, cc) = self.seg_rnns[0](cond_start_inp, (exh0, exc0)) 496 | else: 497 | hid, (hc, cc) = self.seg_rnns[rul_ss](start_inp, (exh0, exc0)) 498 | curr_hyps = [(None, None)] 499 | best_wscore, best_lscore = None, None # so we can truly average over words etc later 500 | best_hyp, best_hyp_score = None, -float("inf") 501 | curr_scores = torch.zeros(K, 1) 502 | # N.B. we assume we have a single feature row for each timestep rather than avg 503 | # over them as at training time. probably better, but could conceivably average like 504 | # at training time. 
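# inps holds, for each beam hypothesis, the 4 input features fed back into the segment RNN
# at the next step: word index, field key, position within the field, and the "cheat" feature
# (the same layout as row2feats, which is built in gen_from_srctbl)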
505 | inps = Variable(torch.LongTensor(K, 4), volatile=True) 506 | for ell in xrange(self.L): 507 | wrd_dist = self.get_next_word_dist(hid, rul_ss, srcfieldenc).cpu() # K x nwords 508 | # disallow unks 509 | wrd_dist[:, unk_idx].zero_() 510 | if not final_state: 511 | wrd_dist[:, eos_idx].zero_() 512 | self.collapse_word_probs(row2tblent, wrd_dist, corpus) 513 | wrd_dist.log_() 514 | if ell > 0: # add previous scores 515 | wrd_dist.add_(curr_scores.expand_as(wrd_dist)) 516 | maxprobs, top2k = torch.topk(wrd_dist.view(-1), 2*K) 517 | cols = wrd_dist.size(1) 518 | # we'll break as soon as is at the top of the beam. 519 | # this ignores but whatever 520 | if top2k[0] == eos_idx: 521 | final_hyp = backtrace(curr_hyps[0]) 522 | final_hyp.append(eos_idx) 523 | return final_hyp, maxprobs[0], len_lps[ss][ell] 524 | 525 | new_hyps, anc_hs, anc_cs = [], [], [] 526 | #inps.data.fill_(pad_idx) 527 | inps.data[:, 1].fill_(w2i[""]) 528 | inps.data[:, 2].fill_(w2i[""]) 529 | inps.data[:, 3].fill_(w2i[""]) 530 | for k in xrange(2*K): 531 | anc, wrd = top2k[k] / cols, top2k[k] % cols 532 | # check if any of the maxes are eop 533 | if wrd == self.eop_idx and ell > 0: 534 | # add len score (and avg over num words incl eop i guess) 535 | wlenscore = maxprobs[k]/(ell+1) + len_lps[ss][ell-1] 536 | if wlenscore > best_hyp_score: 537 | best_hyp_score = wlenscore 538 | best_hyp = backtrace(curr_hyps[anc]) 539 | best_wscore, best_lscore = maxprobs[k], len_lps[ss][ell-1] 540 | else: 541 | curr_scores[len(new_hyps)][0] = maxprobs[k] 542 | if wrd >= self.decoder.out_features: # a copy 543 | tblidx = wrd - self.decoder.out_features 544 | inps.data[len(new_hyps)].copy_(row2feats[tblidx]) 545 | else: 546 | inps.data[len(new_hyps)][0] = wrd if i2w[wrd] in genset else unk_idx 547 | new_hyps.append((wrd, curr_hyps[anc])) 548 | anc_hs.append(hc.narrow(1, anc, 1)) # layers x 1 x rnn_size 549 | anc_cs.append(cc.narrow(1, anc, 1)) # layers x 1 x rnn_size 550 | if len(new_hyps) == K: 551 | break 552 | assert len(new_hyps) == K 553 | curr_hyps = new_hyps 554 | if self.lut.weight.data.is_cuda: 555 | inps = inps.cuda() 556 | embs = self.lut(inps).view(1, K, -1) # 1 x K x nfeats*emb_size 557 | if self.mlpinp: 558 | embs = self.inpmlp(embs) # 1 x K x rnninsz 559 | if self.one_rnn: 560 | cond_embs = torch.cat([embs, self.state_embs[rul_ss].expand(1, K, state_emb_sz)], 2) 561 | hid, (hc, cc) = self.seg_rnns[0](cond_embs, (torch.cat(anc_hs, 1), torch.cat(anc_cs, 1))) 562 | else: 563 | hid, (hc, cc) = self.seg_rnns[rul_ss](embs, (torch.cat(anc_hs, 1), torch.cat(anc_cs, 1))) 564 | # hypotheses of length L still need their end probs added 565 | # N.B. if the falls off the beam we could end up with situations 566 | # where we take an L-length phrase w/ a lower score than 1-word followed by eos. 
567 | wrd_dist = self.get_next_word_dist(hid, rul_ss, srcfieldenc).cpu() # K x nwords 568 | wrd_dist.log_() 569 | wrd_dist.add_(curr_scores.expand_as(wrd_dist)) 570 | for k in xrange(K): 571 | wlenscore = wrd_dist[k][self.eop_idx]/(self.L+1) + len_lps[ss][self.L-1] 572 | if wlenscore > best_hyp_score: 573 | best_hyp_score = wlenscore 574 | best_hyp = backtrace(curr_hyps[k]) 575 | best_wscore, best_lscore = wrd_dist[k][self.eop_idx], len_lps[ss][self.L-1] 576 | 577 | return best_hyp, best_wscore, best_lscore 578 | 579 | 580 | def gen_one(self, templt, h0, c0, srcfieldenc, len_lps, row2tblent, row2feats): 581 | """ 582 | src - 1 x nfields x nfeatures 583 | h0 - rnn_size vector 584 | c0 - rnn_size vector 585 | srcfieldenc - 1 x nfields x dim 586 | len_lps - K x L, log normalized 587 | returns a list of phrases 588 | """ 589 | phrases = [] 590 | tote_wscore, tote_lscore, tokes, segs = 0.0, 0.0, 0.0, 0.0 591 | #start_inp = self.lut.weight[start_idx].view(1, 1, -1) 592 | start_inp = self.start_emb 593 | exh0 = h0.view(1, 1, self.hid_size).expand(self.layers, 1, self.hid_size) 594 | exc0 = c0.view(1, 1, self.hid_size).expand(self.layers, 1, self.hid_size) 595 | nout_wrds = self.decoder.out_features 596 | i2w, w2i = corpus.dictionary.idx2word, corpus.dictionary.word2idx 597 | for stidx, k in enumerate(templt): 598 | phrs_idxs, wscore, lscore = self.temp_bs(corpus, k, start_inp, exh0, exc0, 599 | srcfieldenc, len_lps, row2tblent, row2feats, 600 | args.beamsz, final_state=(stidx == len(templt)-1)) 601 | phrs = [] 602 | for ii in xrange(len(phrs_idxs)): 603 | if phrs_idxs[ii] < nout_wrds: 604 | phrs.append(i2w[phrs_idxs[ii]]) 605 | else: 606 | tblidx = phrs_idxs[ii] - nout_wrds 607 | _, _, wordstr = row2tblent[tblidx] 608 | if args.verbose: 609 | phrs.append(wordstr + " (c)") 610 | else: 611 | phrs.append(wordstr) 612 | if phrs[-1] == "": 613 | break 614 | phrases.append(phrs) 615 | tote_wscore += wscore 616 | tote_lscore += lscore 617 | tokes += len(phrs_idxs) + 1 # add 1 for token 618 | segs += 1 619 | 620 | return phrases, tote_wscore, tote_lscore, tokes, segs 621 | 622 | 623 | def temp_ar_bs(self, templt, row2tblent, row2feats, h0, c0, srcfieldenc, len_lps, K, 624 | corpus): 625 | assert self.unif_lenps # ignoring lenps 626 | exh0 = h0.view(1, 1, self.hid_size).expand(self.layers, 1, self.hid_size) 627 | exc0 = c0.view(1, 1, self.hid_size).expand(self.layers, 1, self.hid_size) 628 | start_inp = self.start_emb 629 | state_emb_sz = self.state_embs.size(3) 630 | i2w, w2i = corpus.dictionary.idx2word, corpus.dictionary.word2idx 631 | genset = corpus.genset 632 | unk_idx, eos_idx, pad_idx = w2i[""], w2i[""], w2i[""] 633 | 634 | curr_hyps = [(None, None, None)] 635 | nfeats = 4 636 | inps = Variable(torch.LongTensor(K, nfeats), volatile=True) 637 | curr_scores, curr_lens, nulens = torch.zeros(K, 1), torch.zeros(K, 1), torch.zeros(K, 1) 638 | if self.lut.weight.data.is_cuda: 639 | inps = inps.cuda() 640 | curr_scores, curr_lens, nulens = curr_scores.cuda(), curr_lens.cuda(), nulens.cuda() 641 | 642 | # start ar rnn; hackily use bos_idx 643 | rnnsz = self.ar_rnn.hidden_size 644 | thid, (thc, tcc) = self.ar_rnn(self.lut.weight[2].view(1, 1, -1)) # 1 x 1 x rnn_size 645 | 646 | for stidx, ss in enumerate(templt): 647 | final_state = (stidx == len(templt) - 1) 648 | minq = [] # so we can compare stuff of different lengths 649 | rul_ss = ss % self.K 650 | 651 | if self.one_rnn: 652 | cond_start_inp = torch.cat([start_inp, self.state_embs[rul_ss]], 2) # 1x1x cat_size 653 | hid, (hc, cc) = 
self.seg_rnns[0](cond_start_inp, (exh0, exc0)) # 1 x 1 x rnn_size 654 | else: 655 | hid, (hc, cc) = self.seg_rnns[rul_ss](start_inp, (exh0, exc0)) # 1 x 1 x rnn_size 656 | hid = hid.expand_as(thid) 657 | hc = hc.expand_as(thc) 658 | cc = cc.expand_as(tcc) 659 | 660 | for ell in xrange(self.L+1): 661 | new_hyps, anc_hs, anc_cs, anc_ths, anc_tcs = [], [], [], [], [] 662 | inps.data[:, 1].fill_(w2i[""]) 663 | inps.data[:, 2].fill_(w2i[""]) 664 | inps.data[:, 3].fill_(w2i[""]) 665 | 666 | wrd_dist = self.get_next_word_dist(hid + thid, rul_ss, srcfieldenc) # K x nwords 667 | currK = wrd_dist.size(0) 668 | # disallow unks and eos's 669 | wrd_dist[:, unk_idx].zero_() 670 | if not final_state: 671 | wrd_dist[:, eos_idx].zero_() 672 | self.collapse_word_probs(row2tblent, wrd_dist, corpus) 673 | wrd_dist.log_() 674 | curr_scores[:currK].mul_(curr_lens[:currK]) 675 | wrd_dist.add_(curr_scores[:currK].expand_as(wrd_dist)) 676 | wrd_dist.div_((curr_lens[:currK]+1).expand_as(wrd_dist)) 677 | maxprobs, top2k = torch.topk(wrd_dist.view(-1), 2*K) 678 | cols = wrd_dist.size(1) 679 | # used to check for eos here, but maybe we shouldn't 680 | 681 | for k in xrange(2*K): 682 | anc, wrd = top2k[k] / cols, top2k[k] % cols 683 | # check if any of the maxes are eop 684 | if wrd == self.eop_idx and ell > 0 and (not final_state or curr_hyps[anc][0] == eos_idx): 685 | ## add len score (and avg over num words *incl eop*) 686 | ## actually ignoring len score for now 687 | #wlenscore = maxprobs[k]/(ell+1) # + len_lps[ss][ell-1] 688 | #assert not final_state or curr_hyps[anc][0] == eos_idx # seems like should hold... 689 | heapitem = (maxprobs[k], curr_lens[anc][0]+1, curr_hyps[anc], 690 | thc.narrow(1, anc, 1), tcc.narrow(1, anc, 1)) 691 | if len(minq) < K: 692 | heapq.heappush(minq, heapitem) 693 | else: 694 | heapq.heappushpop(minq, heapitem) 695 | elif ell < self.L: # only allow non-eop if < L so far 696 | curr_scores[len(new_hyps)][0] = maxprobs[k] 697 | nulens[len(new_hyps)][0] = curr_lens[anc][0]+1 698 | if wrd >= self.decoder.out_features: # a copy 699 | tblidx = wrd - self.decoder.out_features 700 | inps.data[len(new_hyps)].copy_(row2feats[tblidx]) 701 | else: 702 | inps.data[len(new_hyps)][0] = wrd if i2w[wrd] in genset else unk_idx 703 | 704 | new_hyps.append((wrd, ss, curr_hyps[anc])) 705 | anc_hs.append(hc.narrow(1, anc, 1)) # layers x 1 x rnn_size 706 | anc_cs.append(cc.narrow(1, anc, 1)) # layers x 1 x rnn_size 707 | anc_ths.append(thc.narrow(1, anc, 1)) # layers x 1 x rnn_size 708 | anc_tcs.append(tcc.narrow(1, anc, 1)) # layers x 1 x rnn_size 709 | if len(new_hyps) == K: 710 | break 711 | 712 | if ell >= self.L: # don't want to put in eops 713 | break 714 | 715 | assert len(new_hyps) == K 716 | curr_hyps = new_hyps 717 | curr_lens.copy_(nulens) 718 | embs = self.lut(inps).view(1, K, -1) # 1 x K x nfeats*emb_size 719 | if self.word_ar: 720 | ar_embs = embs.view(1, K, nfeats, -1)[:, :, 0] # 1 x K x emb_size 721 | else: # ar on fields 722 | ar_embs = embs.view(1, K, nfeats, -1)[:, :, 1] # 1 x K x emb_size 723 | if self.mlpinp: 724 | embs = self.inpmlp(embs) # 1 x K x rnninsz 725 | if self.one_rnn: 726 | cond_embs = torch.cat([embs, self.state_embs[rul_ss].expand( 727 | 1, K, state_emb_sz)], 2) 728 | hid, (hc, cc) = self.seg_rnns[0](cond_embs, (torch.cat(anc_hs, 1), 729 | torch.cat(anc_cs, 1))) 730 | else: 731 | hid, (hc, cc) = self.seg_rnns[rul_ss](embs, (torch.cat(anc_hs, 1), 732 | torch.cat(anc_cs, 1))) 733 | thid, (thc, tcc) = self.ar_rnn(ar_embs, (torch.cat(anc_ths, 1), 734 | torch.cat(anc_tcs, 1))) 735 
| 736 | # retrieve topk for this segment (in reverse order) 737 | seghyps = [heapq.heappop(minq) for _ in xrange(len(minq))] 738 | if len(seghyps) == 0: 739 | return -float("inf"), None 740 | 741 | if len(seghyps) < K and not final_state: 742 | # haaaaaaaaaaaaaaack 743 | ugh = [] 744 | for ick in xrange(K-len(seghyps)): 745 | scoreick, lenick, hypick, thcick, tccick = seghyps[0] 746 | ugh.append((scoreick - 9999999.0 + ick, lenick, hypick, thcick, tccick)) 747 | # break ties for the comparison 748 | ugh.extend(seghyps) 749 | seghyps = ugh 750 | 751 | #assert final_state or len(seghyps) == K 752 | 753 | if final_state: 754 | if len(seghyps) > 0: 755 | scoreb, lenb, hypb, thcb, tccb = seghyps[-1] 756 | return scoreb, backtrace3(hypb) 757 | else: 758 | return -float("inf"), None 759 | else: 760 | thidlst, thclst, tcclst = [], [], [] 761 | for i in xrange(K): 762 | scorei, leni, hypi, thci, tcci = seghyps[K-i-1] 763 | curr_scores[i][0], curr_lens[i][0], curr_hyps[i] = scorei, leni, hypi 764 | thidlst.append(thci[-1:, :, :]) # each is 1 x 1 x rnn_size 765 | thclst.append(thci) # each is layers x 1 x rnn_size 766 | tcclst.append(tcci) # each is layers x 1 x rnn_size 767 | 768 | # we already have the state for the next word b/c we put it thru to also predict eop 769 | thid, (thc, tcc) = torch.cat(thidlst, 1), (torch.cat(thclst, 1), torch.cat(tcclst, 1)) 770 | 771 | 772 | def gen_one_ar(self, templt, h0, c0, srcfieldenc, len_lps, row2tblent, row2feats): 773 | """ 774 | src - 1 x nfields x nfeatures 775 | h0 - rnn_size vector 776 | c0 - rnn_size vector 777 | srcfieldenc - 1 x nfields x dim 778 | len_lps - K x L, log normalized 779 | returns a list of phrases 780 | """ 781 | nout_wrds = self.decoder.out_features 782 | i2w, w2i = corpus.dictionary.idx2word, corpus.dictionary.word2idx 783 | phrases, phrs = [], [] 784 | tokes = 0.0 785 | wscore, hyp = self.temp_ar_bs(templt, row2tblent, row2feats, h0, c0, srcfieldenc, len_lps, 786 | args.beamsz, corpus) 787 | if hyp is None: 788 | return None, -float("inf"), 0 789 | curr_labe = hyp[0][1] 790 | tokes = 0 791 | for widx, labe in hyp: 792 | if labe != curr_labe: 793 | phrases.append(phrs) 794 | tokes += len(phrs) 795 | phrs = [] 796 | curr_labe = labe 797 | if widx < nout_wrds: 798 | phrs.append(i2w[widx]) 799 | else: 800 | tblidx = widx - nout_wrds 801 | _, _, wordstr = row2tblent[tblidx] 802 | if args.verbose: 803 | phrs.append(wordstr + " (c)") 804 | else: 805 | phrs.append(wordstr) 806 | if len(phrs) > 0: 807 | phrases.append(phrs) 808 | tokes += len(phrs) 809 | 810 | return phrases, wscore, tokes 811 | 812 | def make_combo_targs(locs, x, L, nfields, ngen_types): 813 | """ 814 | combines word and copy targets into a single tensor. 
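(copy targets are indices into the source rows, offset by ngen_types+1 so that they share a single output index space with the word targets.)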
815 | locs - seqlen x bsz x max_locs 816 | x - seqlen x bsz 817 | assumes we have word indices, then fields, then a dummy 818 | returns L x bsz*seqlen x max_locs tensor corresponding to xsegs[1:] 819 | """ 820 | seqlen, bsz, max_locs = locs.size() 821 | # first replace -1s in first loc with target words 822 | addloc = locs + (ngen_types+1) # seqlen x bsz x max_locs 823 | firstloc = addloc[:, :, 0] # seqlen x bsz 824 | targmask = (firstloc == ngen_types) # -1 will now have value ngentypes 825 | firstloc[targmask] = x[targmask] 826 | # now replace remaining -1s w/ zero location 827 | addloc[addloc == ngen_types] = ngen_types+1+nfields # last index 828 | # finally put in same format as x_segs 829 | newlocs = torch.LongTensor(L, seqlen, bsz, max_locs).fill_(ngen_types+1+nfields) 830 | for i in xrange(L): 831 | newlocs[i][:seqlen-i].copy_(addloc[i:]) 832 | return newlocs.transpose(1, 2).contiguous().view(L, bsz*seqlen, max_locs) 833 | 834 | 835 | def get_uniq_fields(src, pad_idx, keycol=0): 836 | """ 837 | src - bsz x nfields x nfeats 838 | """ 839 | bsz = src.size(0) 840 | # get unique keys for each example 841 | keys = [torch.LongTensor(list(set(src[b, :, keycol]))) for b in xrange(bsz)] 842 | maxkeys = max(keyset.size(0) for keyset in keys) 843 | fields = torch.LongTensor(bsz, maxkeys).fill_(pad_idx) 844 | for b, keyset in enumerate(keys): 845 | fields[b][:len(keyset)].copy_(keyset) 846 | return fields 847 | 848 | 849 | def make_masks(src, pad_idx, max_pool=False): 850 | """ 851 | src - bsz x nfields x nfeats 852 | """ 853 | neginf = -1e38 854 | bsz, nfields, nfeats = src.size() 855 | fieldmask = (src.eq(pad_idx).sum(2) == nfeats) # binary bsz x nfields tensor 856 | avgmask = (1 - fieldmask).float() # 1s where not padding 857 | if not max_pool: 858 | avgmask.div_(avgmask.sum(1, True).expand(bsz, nfields)) 859 | fieldmask = fieldmask.float() * neginf # 0 where not all pad and -1e38 elsewhere 860 | return fieldmask, avgmask 861 | 862 | parser = argparse.ArgumentParser(description='') 863 | parser.add_argument('-data', type=str, default='', help='path to data dir') 864 | parser.add_argument('-epochs', type=int, default=40, help='upper epoch limit') 865 | parser.add_argument('-bsz', type=int, default=16, help='batch size') 866 | parser.add_argument('-seed', type=int, default=1111, help='random seed') 867 | parser.add_argument('-cuda', action='store_true', help='use CUDA') 868 | parser.add_argument('-log_interval', type=int, default=200, 869 | help='minibatches to wait before logging training status') 870 | parser.add_argument('-save', type=str, default='', help='path to save the final model') 871 | parser.add_argument('-load', type=str, default='', help='path to saved model') 872 | parser.add_argument('-test', action='store_true', help='use test data') 873 | parser.add_argument('-thresh', type=int, default=9, help='prune if occurs <= thresh') 874 | parser.add_argument('-max_mbs_per_epoch', type=int, default=35000, help='max minibatches per epoch') 875 | 876 | parser.add_argument('-emb_size', type=int, default=100, help='size of word embeddings') 877 | parser.add_argument('-hid_size', type=int, default=100, help='size of rnn hidden state') 878 | parser.add_argument('-layers', type=int, default=1, help='num rnn layers') 879 | parser.add_argument('-A_dim', type=int, default=64, 880 | help='dim of factors if factoring transition matrix') 881 | parser.add_argument('-cond_A_dim', type=int, default=32, 882 | help='dim of factors if factoring transition matrix') 883 | 
parser.add_argument('-smaller_cond_dim', type=int, default=64, 884 | help='dim of thing we feed into linear to get transitions') 885 | parser.add_argument('-yes_self_trans', action='store_true', help='') 886 | parser.add_argument('-mlpinp', action='store_true', help='') 887 | parser.add_argument('-mlp_sz_mult', type=int, default=2, help='mlp hidsz is this x emb_size') 888 | parser.add_argument('-max_pool', action='store_true', help='for word-fields') 889 | 890 | parser.add_argument('-constr_tr_epochs', type=int, default=100, help='') 891 | parser.add_argument('-no_ar_epochs', type=int, default=100, help='') 892 | 893 | parser.add_argument('-word_ar', action='store_true', help='') 894 | parser.add_argument('-ar_after_decay', action='store_true', help='') 895 | parser.add_argument('-no_ar_for_vit', action='store_true', help='') 896 | parser.add_argument('-fine_tune', action='store_true', help='only train ar rnn') 897 | 898 | parser.add_argument('-dropout', type=float, default=0.3, help='dropout') 899 | parser.add_argument('-emb_drop', action='store_true', help='dropout on embeddings') 900 | parser.add_argument('-lse_obj', action='store_true', help='') 901 | parser.add_argument('-sep_attn', action='store_true', help='') 902 | parser.add_argument('-max_seqlen', type=int, default=70, help='') 903 | 904 | parser.add_argument('-K', type=int, default=10, help='number of states') 905 | parser.add_argument('-Kmul', type=int, default=1, help='number of states multiplier') 906 | parser.add_argument('-L', type=int, default=10, help='max segment length') 907 | parser.add_argument('-unif_lenps', action='store_true', help='') 908 | parser.add_argument('-one_rnn', action='store_true', help='') 909 | 910 | parser.add_argument('-initrange', type=float, default=0.1, help='uniform init interval') 911 | parser.add_argument('-lr', type=float, default=1.0, help='initial learning rate') 912 | parser.add_argument('-lr_decay', type=float, default=0.5, help='learning rate decay') 913 | parser.add_argument('-optim', type=str, default="sgd", help='optimization algorithm') 914 | parser.add_argument('-onmt_decay', action='store_true', help='') 915 | parser.add_argument('-clip', type=float, default=5, help='gradient clipping') 916 | parser.add_argument('-interactive', action='store_true', help='') 917 | parser.add_argument('-label_train', action='store_true', help='') 918 | parser.add_argument('-gen_from_fi', type=str, default='', help='') 919 | parser.add_argument('-verbose', action='store_true', help='') 920 | parser.add_argument('-prev_loss', type=float, default=None, help='') 921 | parser.add_argument('-best_loss', type=float, default=None, help='') 922 | 923 | parser.add_argument('-tagged_fi', type=str, default='', help='path to tagged fi') 924 | parser.add_argument('-ntemplates', type=int, default=200, help='num templates for gen') 925 | parser.add_argument('-beamsz', type=int, default=1, help='') 926 | parser.add_argument('-gen_wts', type=str, default='1,1', help='') 927 | parser.add_argument('-min_gen_tokes', type=int, default=0, help='') 928 | parser.add_argument('-min_gen_states', type=int, default=0, help='') 929 | parser.add_argument('-gen_on_valid', action='store_true', help='') 930 | parser.add_argument('-align', action='store_true', help='') 931 | parser.add_argument('-wid_workers', type=str, default='', help='') 932 | 933 | if __name__ == "__main__": 934 | args = parser.parse_args() 935 | print args 936 | 937 | torch.manual_seed(args.seed) 938 | if torch.cuda.is_available(): 939 | if not args.cuda: 940 | 
print "WARNING: You have a CUDA device, so you should probably run with -cuda" 941 | else: 942 | torch.cuda.manual_seed(args.seed) 943 | 944 | # Load data 945 | corpus = labeled_data.SentenceCorpus(args.data, args.bsz, thresh=args.thresh, add_bos=False, 946 | add_eos=False, test=args.test) 947 | 948 | if not args.interactive and not args.label_train and len(args.gen_from_fi) == 0: 949 | # make constraint things from labels 950 | train_cidxs, train_fwd_cidxs = [], [] 951 | for i in xrange(len(corpus.train)): 952 | x, constrs, _, _, _ = corpus.train[i] 953 | train_cidxs.append(make_bwd_constr_idxs(args.L, x.size(0), constrs)) 954 | train_fwd_cidxs.append(make_fwd_constr_idxs(args.L, x.size(0), constrs)) 955 | 956 | saved_args, saved_state = None, None 957 | if len(args.load) > 0: 958 | saved_stuff = torch.load(args.load) 959 | saved_args, saved_state = saved_stuff["opt"], saved_stuff["state_dict"] 960 | for k, v in args.__dict__.iteritems(): 961 | if k not in saved_args.__dict__: 962 | saved_args.__dict__[k] = v 963 | net = HSMM(len(corpus.dictionary), corpus.ngen_types, saved_args) 964 | # for some reason selfmask breaks load_state 965 | del saved_state["selfmask"] 966 | net.load_state_dict(saved_state, strict=False) 967 | args.pad_idx = corpus.dictionary.word2idx[""] 968 | if args.fine_tune: 969 | for name, param in net.named_parameters(): 970 | if name in saved_state: 971 | param.requires_grad = False 972 | 973 | else: 974 | args.pad_idx = corpus.dictionary.word2idx[""] 975 | net = HSMM(len(corpus.dictionary), corpus.ngen_types, args) 976 | 977 | if args.cuda: 978 | net = net.cuda() 979 | 980 | if args.optim == "adagrad": 981 | optalg = optim.Adagrad(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr) 982 | for group in optalg.param_groups: 983 | for p in group['params']: 984 | optalg.state[p]['sum'].fill_(0.1) 985 | elif args.optim == "rmsprop": 986 | optalg = optim.RMSprop(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr) 987 | elif args.optim == "adam": 988 | optalg = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr) 989 | else: 990 | optalg = None 991 | 992 | def train(epoch): 993 | # Turn on training mode which enables dropout. 
994 | net.train() 995 | neglogev = 0.0 # negative log evidence 996 | nsents = 0 997 | trainperm = torch.randperm(len(corpus.train)) 998 | nmini_batches = min(len(corpus.train), args.max_mbs_per_epoch) 999 | for batch_idx in xrange(nmini_batches): 1000 | net.zero_grad() 1001 | x, _, src, locs, inps = corpus.train[trainperm[batch_idx]] 1002 | cidxs = train_cidxs[trainperm[batch_idx]] if epoch <= args.constr_tr_epochs else None 1003 | 1004 | seqlen, bsz = x.size() 1005 | nfields = src.size(1) 1006 | if seqlen < args.L or seqlen > args.max_seqlen: 1007 | continue 1008 | 1009 | combotargs = make_combo_targs(locs, x, args.L, nfields, corpus.ngen_types) 1010 | # get bsz x nfields, bsz x nfields masks 1011 | fmask, amask = make_masks(src, args.pad_idx, max_pool=args.max_pool) 1012 | 1013 | uniqfields = get_uniq_fields(src, args.pad_idx) # bsz x max_fields 1014 | 1015 | if args.cuda: 1016 | combotargs = combotargs.cuda() 1017 | if cidxs is not None: 1018 | cidxs = [tens.cuda() if tens is not None else None for tens in cidxs] 1019 | src = src.cuda() 1020 | inps = inps.cuda() 1021 | fmask, amask = fmask.cuda(), amask.cuda() 1022 | uniqfields = uniqfields.cuda() 1023 | 1024 | srcenc, srcfieldenc, uniqenc = net.encode(Variable(src), Variable(amask), # bsz x hid 1025 | Variable(uniqfields)) 1026 | init_logps, trans_logps = net.trans_logprobs(uniqenc, seqlen) # bsz x K, T-1 x bsz x KxK 1027 | len_logprobs, _ = net.len_logprobs() 1028 | fwd_obs_logps = net.obs_logprobs(Variable(inps), srcenc, srcfieldenc, Variable(fmask), 1029 | Variable(combotargs), bsz) # L x T x bsz x K 1030 | # get T+1 x bsz x K beta quantities 1031 | beta, beta_star = infc.just_bwd(trans_logps, fwd_obs_logps, 1032 | len_logprobs, constraints=cidxs) 1033 | log_marg = logsumexp1(beta_star[0] + init_logps).sum() # bsz x 1 -> 1 1034 | neglogev -= log_marg.data[0] 1035 | lossvar = -log_marg/bsz 1036 | lossvar.backward() 1037 | torch.nn.utils.clip_grad_norm(net.parameters(), args.clip) 1038 | if optalg is not None: 1039 | optalg.step() 1040 | else: 1041 | for p in net.parameters(): 1042 | if p.grad is not None: 1043 | p.data.add_(-args.lr, p.grad.data) 1044 | 1045 | nsents += bsz 1046 | 1047 | if (batch_idx+1) % args.log_interval == 0: 1048 | print "batch %d/%d | train neglogev %g " % (batch_idx+1, 1049 | nmini_batches, 1050 | neglogev/nsents) 1051 | print "epoch %d | train neglogev %g " % (epoch, neglogev/nsents) 1052 | return neglogev/nsents 1053 | 1054 | def test(epoch): 1055 | net.eval() 1056 | neglogev = 0.0 1057 | nsents = 0 1058 | 1059 | for i in xrange(len(corpus.valid)): 1060 | x, _, src, locs, inps = corpus.valid[i] 1061 | cidxs = None 1062 | 1063 | seqlen, bsz = x.size() 1064 | nfields = src.size(1) 1065 | if seqlen < args.L or seqlen > args.max_seqlen: 1066 | continue 1067 | 1068 | combotargs = make_combo_targs(locs, x, args.L, nfields, corpus.ngen_types) 1069 | # get bsz x nfields, bsz x nfields masks 1070 | fmask, amask = make_masks(src, args.pad_idx, max_pool=args.max_pool) 1071 | 1072 | uniqfields = get_uniq_fields(src, args.pad_idx) # bsz x max_fields 1073 | 1074 | if args.cuda: 1075 | combotargs = combotargs.cuda() 1076 | if cidxs is not None: 1077 | cidxs = [tens.cuda() if tens is not None else None for tens in cidxs] 1078 | src = src.cuda() 1079 | inps = inps.cuda() 1080 | fmask, amask = fmask.cuda(), amask.cuda() 1081 | uniqfields = uniqfields.cuda() 1082 | 1083 | srcenc, srcfieldenc, uniqenc = net.encode(Variable(src, volatile=True), # bsz x hid 1084 | Variable(amask, volatile=True), 1085 | Variable(uniqfields, 
volatile=True)) 1086 | init_logps, trans_logps = net.trans_logprobs(uniqenc, seqlen) # bsz x K, T-1 x bsz x KxK 1087 | len_logprobs, _ = net.len_logprobs() 1088 | fwd_obs_logps = net.obs_logprobs(Variable(inps, volatile=True), srcenc, 1089 | srcfieldenc, Variable(fmask, volatile=True), 1090 | Variable(combotargs, volatile=True), 1091 | bsz) # L x T x bsz x K 1092 | 1093 | # get T+1 x bsz x K beta quantities 1094 | beta, beta_star = infc.just_bwd(trans_logps, fwd_obs_logps, 1095 | len_logprobs, constraints=cidxs) 1096 | log_marg = logsumexp1(beta_star[0] + init_logps).sum() # bsz x 1 -> 1 1097 | neglogev -= log_marg.data[0] 1098 | nsents += bsz 1099 | print "epoch %d | valid ev %g" % (epoch, neglogev/nsents) 1100 | return neglogev/nsents 1101 | 1102 | def label_train(): 1103 | net.ar = saved_args.ar_after_decay and not args.no_ar_for_vit 1104 | print "btw, net.ar:", net.ar 1105 | for i in xrange(len(corpus.train)): 1106 | x, _, src, locs, inps = corpus.train[i] 1107 | fwd_cidxs = None 1108 | 1109 | seqlen, bsz = x.size() 1110 | nfields = src.size(1) 1111 | if seqlen <= saved_args.L: #or seqlen > args.max_seqlen: 1112 | continue 1113 | 1114 | combotargs = make_combo_targs(locs, x, saved_args.L, nfields, corpus.ngen_types) 1115 | # get bsz x nfields, bsz x nfields masks 1116 | fmask, amask = make_masks(src, saved_args.pad_idx, max_pool=saved_args.max_pool) 1117 | uniqfields = get_uniq_fields(src, args.pad_idx) # bsz x max_fields 1118 | 1119 | if args.cuda: 1120 | combotargs = combotargs.cuda() 1121 | if fwd_cidxs is not None: 1122 | fwd_cidxs = [tens.cuda() if tens is not None else None for tens in fwd_cidxs] 1123 | src = src.cuda() 1124 | inps = inps.cuda() 1125 | fmask, amask = fmask.cuda(), amask.cuda() 1126 | uniqfields = uniqfields.cuda() 1127 | 1128 | srcenc, srcfieldenc, uniqenc = net.encode(Variable(src, volatile=True), # bsz x hid 1129 | Variable(amask, volatile=True), 1130 | Variable(uniqfields, volatile=True)) 1131 | init_logps, trans_logps = net.trans_logprobs(uniqenc, seqlen) # bsz x K, T-1 x bsz x KxK 1132 | len_logprobs, _ = net.len_logprobs() 1133 | fwd_obs_logps = net.obs_logprobs(Variable(inps, volatile=True), srcenc, 1134 | srcfieldenc, Variable(fmask, volatile=True), 1135 | Variable(combotargs, volatile=True), bsz) # LxTxbsz x K 1136 | bwd_obs_logprobs = infc.bwd_from_fwd_obs_logprobs(fwd_obs_logps.data) 1137 | seqs = infc.viterbi(init_logps.data, trans_logps.data, bwd_obs_logprobs, 1138 | [t.data for t in len_logprobs], constraints=fwd_cidxs) 1139 | for b in xrange(bsz): 1140 | words = [corpus.dictionary.idx2word[w] for w in x[:, b]] 1141 | for (start, end, label) in seqs[b]: 1142 | print "%s|%d" % (" ".join(words[start:end]), label), 1143 | print 1144 | 1145 | def gen_from_srctbl(src_tbl, top_temps, coeffs, src_line=None): 1146 | net.ar = saved_args.ar_after_decay 1147 | #print "btw2", net.ar 1148 | i2w, w2i = corpus.dictionary.idx2word, corpus.dictionary.word2idx 1149 | best_score, best_phrases, best_templt = -float("inf"), None, None 1150 | best_len = 0 1151 | best_tscore, best_gscore = None, None 1152 | 1153 | # get srcrow 2 key, idx 1154 | #src_b = src.narrow(0, b, 1) # 1 x nfields x nfeats 1155 | src_b = corpus.featurize_tbl(src_tbl).unsqueeze(0) # 1 x nfields x nfeats 1156 | uniq_b = get_uniq_fields(src_b, saved_args.pad_idx) # 1 x max_fields 1157 | if args.cuda: 1158 | src_b = src_b.cuda() 1159 | uniq_b = uniq_b.cuda() 1160 | 1161 | srcenc, srcfieldenc, uniqenc = net.encode(Variable(src_b, volatile=True), None, 1162 | Variable(uniq_b, volatile=True)) 1163 | 
init_logps, trans_logps = net.trans_logprobs(uniqenc, 2) 1164 | _, len_scores = net.len_logprobs() 1165 | len_lps = net.lsm(len_scores).data 1166 | init_logps, trans_logps = init_logps.data.cpu(), trans_logps.data[0].cpu() 1167 | inits = net.h0_lin(srcenc) 1168 | h0, c0 = F.tanh(inits[:, :inits.size(1)/2]), inits[:, inits.size(1)/2:] 1169 | 1170 | nfields = src_b.size(1) 1171 | row2tblent = {} 1172 | for ff in xrange(nfields): 1173 | field, idx = i2w[src_b[0][ff][0]], i2w[src_b[0][ff][1]] 1174 | if (field, idx) in src_tbl: 1175 | row2tblent[ff] = (field, idx, src_tbl[field, idx]) 1176 | else: 1177 | row2tblent[ff] = (None, None, None) 1178 | 1179 | # get row to input feats 1180 | row2feats = {} 1181 | # precompute wrd stuff 1182 | fld_cntr = Counter([key for key, _ in src_tbl]) 1183 | for row, (k, idx, wrd) in row2tblent.iteritems(): 1184 | if k in w2i: 1185 | widx = w2i[wrd] if wrd in w2i else w2i[""] 1186 | keyidx = w2i[k] if k in w2i else w2i[""] 1187 | idxidx = w2i[idx] 1188 | cheatfeat = w2i[""] if fld_cntr[k] == idx else w2i[""] 1189 | #row2feats[row] = torch.LongTensor([keyidx, idxidx, cheatfeat]) 1190 | row2feats[row] = torch.LongTensor([widx, keyidx, idxidx, cheatfeat]) 1191 | 1192 | constr_sat = False 1193 | # search over all templates 1194 | for templt in top_temps: 1195 | #print "templt is", templt 1196 | # get templt transition prob 1197 | tscores = [init_logps[0][templt[0]]] 1198 | [tscores.append(trans_logps[0][templt[tt-1]][templt[tt]]) 1199 | for tt in xrange(1, len(templt))] 1200 | 1201 | if net.ar: 1202 | phrases, wscore, tokes = net.gen_one_ar(templt, h0[0], c0[0], srcfieldenc, 1203 | len_lps, row2tblent, row2feats) 1204 | rul_tokes = tokes 1205 | else: 1206 | phrases, wscore, lscore, tokes, segs = net.gen_one(templt, h0[0], c0[0], 1207 | srcfieldenc, len_lps, row2tblent, row2feats) 1208 | rul_tokes = tokes - segs # subtract imaginary toke for each 1209 | wscore /= tokes 1210 | segs = len(templt) 1211 | if (rul_tokes < args.min_gen_tokes or segs < args.min_gen_states) and constr_sat: 1212 | continue 1213 | if rul_tokes >= args.min_gen_tokes and segs >= args.min_gen_states: 1214 | constr_sat = True # satisfied our constraint 1215 | tscore = sum(tscores[:int(segs)])/segs 1216 | if not net.unif_lenps: 1217 | tscore += lscore/segs 1218 | 1219 | gscore = wscore 1220 | ascore = coeffs[0]*tscore + coeffs[1]*gscore 1221 | if (constr_sat and ascore > best_score) or (not constr_sat and rul_tokes > best_len) or (not constr_sat and rul_tokes == best_len and ascore > best_score): 1222 | # take if improves score or not long enough yet and this is longer... 
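            # In words: tscore is the average per-segment transition log-prob of the
            # template (plus the average length log-prob when lengths aren't uniform),
            # gscore is the average per-token word log-prob from gen_one/gen_one_ar,
            # and ascore mixes the two with the -gen_wts coefficients. Until some
            # candidate has at least -min_gen_tokes tokens and -min_gen_states
            # segments, longer candidates are preferred; once the constraint has been
            # satisfied, shorter candidates are skipped and the best ascore wins.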
1223 | #if ascore > best_score: #or (not constr_sat and rul_tokes > best_len): 1224 | best_score, best_tscore, best_gscore = ascore, tscore, gscore 1225 | best_phrases, best_templt = phrases, templt 1226 | best_len = rul_tokes 1227 | #str_phrases = [" ".join(phrs) for phrs in phrases] 1228 | #tmpltd = ["%s|%d" % (phrs, templt[k]) for k, phrs in enumerate(str_phrases)] 1229 | #statstr = "a=%.2f t=%.2f g=%.2f" % (ascore, tscore, gscore) 1230 | #print "%s|||%s" % (" ".join(str_phrases), " ".join(tmpltd)), statstr 1231 | #assert False 1232 | #assert False 1233 | 1234 | try: 1235 | str_phrases = [" ".join(phrs) for phrs in best_phrases] 1236 | except TypeError: 1237 | # sometimes it puts an actual number in 1238 | str_phrases = [" ".join([str(n) if type(n) is int else n for n in phrs]) for phrs in best_phrases] 1239 | tmpltd = ["%s|%d" % (phrs, best_templt[kk]) for kk, phrs in enumerate(str_phrases)] 1240 | if args.verbose: 1241 | print src_line 1242 | #print src_tbl 1243 | 1244 | print "%s|||%s" % (" ".join(str_phrases), " ".join(tmpltd)) 1245 | if args.verbose: 1246 | statstr = "a=%.2f t=%.2f g=%.2f" % (best_score, best_tscore, best_gscore) 1247 | print statstr 1248 | print 1249 | #assert False 1250 | 1251 | def gen_from_src(): 1252 | from template_extraction import extract_from_tagged_data, align_cntr 1253 | top_temps, _, _ = extract_from_tagged_data(args.data, args.bsz, args.thresh, 1254 | args.tagged_fi, args.ntemplates) 1255 | 1256 | with open(args.gen_from_fi) as f: 1257 | src_lines = f.readlines() 1258 | 1259 | if len(args.wid_workers) > 0: 1260 | wid, nworkers = [int(n.strip()) for n in args.wid_workers.split(',')] 1261 | chunksz = math.floor(len(src_lines)/float(nworkers)) 1262 | startln = int(wid*chunksz) 1263 | endln = int((wid+1)*chunksz) if wid < nworkers-1 else len(src_lines) 1264 | print >> sys.stderr, "worker", wid, "doing lines", startln, "thru", endln-1 1265 | src_lines = src_lines[startln:endln] 1266 | 1267 | net.eval() 1268 | coeffs = [float(flt.strip()) for flt in args.gen_wts.split(',')] 1269 | if args.gen_on_valid: 1270 | for i in xrange(len(corpus.valid)): 1271 | if i > 2: 1272 | break 1273 | x, _, src, locs, inps = corpus.valid[i] 1274 | seqlen, bsz = x.size() 1275 | #nfields = src.size(1) 1276 | # get bsz x nfields, bsz x nfields masks 1277 | #fmask, amask = make_masks(src, saved_args.pad_idx, max_pool=saved_args.max_pool) 1278 | #if args.cuda: 1279 | #src = src.cuda() 1280 | #amask = amask.cuda() 1281 | 1282 | for b in xrange(bsz): 1283 | src_line = src_lines[corpus.val_mb2linenos[i][b]] 1284 | if "wiki" in args.data: 1285 | src_tbl = get_wikibio_poswrds(src_line.strip().split()) 1286 | else: 1287 | src_tbl = get_e2e_poswrds(src_line.strip().split()) 1288 | 1289 | gen_from_srctbl(src_tbl, top_temps, coeffs, src_line=src_line) 1290 | else: 1291 | for ll, src_line in enumerate(src_lines): 1292 | if "wiki" in args.data: 1293 | src_tbl = get_wikibio_poswrds(src_line.strip().split()) 1294 | else: 1295 | src_tbl = get_e2e_poswrds(src_line.strip().split()) 1296 | 1297 | gen_from_srctbl(src_tbl, top_temps, coeffs, src_line=src_line) 1298 | 1299 | 1300 | def align_stuff(): 1301 | from template_extraction import extract_from_tagged_data 1302 | i2w = corpus.dictionary.idx2word 1303 | net.eval() 1304 | cop_counters = [Counter() for _ in xrange(net.K*net.Kmul)] 1305 | net.ar = saved_args.ar_after_decay and not args.no_ar_for_vit 1306 | top_temps, _, _ = extract_from_tagged_data(args.data, args.bsz, args.thresh, 1307 | args.tagged_fi, args.ntemplates) 1308 | top_temps = set(temp 
for temp in top_temps) 1309 | 1310 | with open(os.path.join(args.data, "train.txt")) as f: 1311 | tgtlines = [line.strip().split() for line in f] 1312 | 1313 | with open(os.path.join(args.data, "src_train.txt")) as f: 1314 | srclines = [line.strip().split() for line in f] 1315 | 1316 | assert len(srclines) == len(tgtlines) 1317 | 1318 | for i in xrange(len(corpus.train)): 1319 | x, _, src, locs, inps = corpus.train[i] 1320 | fwd_cidxs = None 1321 | 1322 | seqlen, bsz = x.size() 1323 | nfields = src.size(1) 1324 | if seqlen <= saved_args.L or seqlen > args.max_seqlen: 1325 | continue 1326 | 1327 | combotargs = make_combo_targs(locs, x, saved_args.L, nfields, corpus.ngen_types) 1328 | # get bsz x nfields, bsz x nfields masks 1329 | fmask, amask = make_masks(src, saved_args.pad_idx, max_pool=saved_args.max_pool) 1330 | uniqfields = get_uniq_fields(src, args.pad_idx) # bsz x max_fields 1331 | 1332 | if args.cuda: 1333 | combotargs = combotargs.cuda() 1334 | src = src.cuda() 1335 | inps = inps.cuda() 1336 | fmask, amask = fmask.cuda(), amask.cuda() 1337 | uniqfields = uniqfields.cuda() 1338 | 1339 | srcenc, srcfieldenc, uniqenc = net.encode(Variable(src, volatile=True), # bsz x hid 1340 | Variable(amask, volatile=True), 1341 | Variable(uniqfields, volatile=True)) 1342 | init_logps, trans_logps = net.trans_logprobs(uniqenc, seqlen) # bsz x K, T-1 x bsz x KxK 1343 | len_logprobs, _ = net.len_logprobs() 1344 | fwd_obs_logps = net.obs_logprobs(Variable(inps, volatile=True), srcenc, 1345 | srcfieldenc, Variable(fmask, volatile=True), 1346 | Variable(combotargs, volatile=True), bsz) # LxTxbsz x K 1347 | bwd_obs_logprobs = infc.bwd_from_fwd_obs_logprobs(fwd_obs_logps.data) 1348 | seqs = infc.viterbi(init_logps.data, trans_logps.data, bwd_obs_logprobs, 1349 | [t.data for t in len_logprobs], constraints=fwd_cidxs) 1350 | # get rid of stuff not in our top_temps 1351 | for bidx in xrange(bsz): 1352 | if tuple(labe for (start, end, labe) in seqs[bidx]) in top_temps: 1353 | lineno = corpus.train_mb2linenos[i][bidx] 1354 | tgttokes = tgtlines[lineno] 1355 | if "wiki" in args.data: 1356 | src_tbl = get_wikibio_poswrds(srclines[lineno]) 1357 | else: 1358 | src_tbl = get_e2e_poswrds(srclines[lineno]) # field, idx -> wrd 1359 | wrd2fields = defaultdict(list) 1360 | for (field, idx), wrd in src_tbl.iteritems(): 1361 | wrd2fields[wrd].append(field) 1362 | for (start, end, labe) in seqs[bidx]: 1363 | for wrd in tgttokes[start:end]: 1364 | if wrd in wrd2fields: 1365 | cop_counters[labe].update(wrd2fields[wrd]) 1366 | else: 1367 | cop_counters[labe]["other"] += 1 1368 | 1369 | return cop_counters 1370 | 1371 | if args.interactive: 1372 | pass 1373 | elif args.align: 1374 | from utils import calc_pur 1375 | cop_counters = align_stuff() 1376 | calc_pur(cop_counters) 1377 | elif args.label_train: 1378 | net.eval() 1379 | label_train() 1380 | elif len(args.gen_from_fi) > 0: 1381 | gen_from_src() 1382 | elif args.epochs == 0: 1383 | net.eval() 1384 | test(0) 1385 | else: 1386 | prev_valloss, best_valloss = float("inf"), float("inf") 1387 | decayed = False 1388 | if args.prev_loss is not None: 1389 | prev_valloss = args.prev_loss 1390 | if args.best_loss is None: 1391 | best_valloss = prev_valloss 1392 | else: 1393 | decayed = True 1394 | best_valloss = args.best_loss 1395 | print "starting with", prev_valloss, best_valloss 1396 | 1397 | for epoch in range(1, args.epochs + 1): 1398 | if epoch > args.no_ar_epochs and not net.ar and decayed: 1399 | net.ar = True 1400 | # hack 1401 | if args.word_ar and not net.word_ar: 1402 
| print "turning on word ar..." 1403 | net.word_ar = True 1404 | 1405 | print "ar:", net.ar 1406 | 1407 | train(epoch) 1408 | net.eval() 1409 | valloss = test(epoch) 1410 | 1411 | if valloss < best_valloss: 1412 | best_valloss = valloss 1413 | if len(args.save) > 0: 1414 | print "saving to", args.save 1415 | state = {"opt": args, "state_dict": net.state_dict(), 1416 | "lr": args.lr, "dict": corpus.dictionary} 1417 | torch.save(state, args.save + "." + str(int(decayed))) 1418 | 1419 | if (args.optim == "sgd" and valloss >= prev_valloss) or (args.onmt_decay and decayed): 1420 | decayed = True 1421 | args.lr *= args.lr_decay 1422 | if args.ar_after_decay and not net.ar: 1423 | net.ar = True 1424 | # hack 1425 | if args.word_ar and not net.word_ar: 1426 | print "turning on word ar..." 1427 | net.word_ar = True 1428 | print "decaying lr to:", args.lr 1429 | if args.lr < 1e-5: 1430 | break 1431 | prev_valloss = valloss 1432 | if args.cuda: 1433 | print "ugh...." 1434 | shmocals = locals() 1435 | for shk in shmocals.keys(): 1436 | shv = shmocals[shk] 1437 | if hasattr(shv, "is_cuda") and shv.is_cuda: 1438 | shv = shv.cpu() 1439 | print "done!" 1440 | print 1441 | else: 1442 | print 1443 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harvardnlp/neural-template-gen/dff72ca246e80fd0eef2ca0b8b1fbf21fe927df4/data/__init__.py -------------------------------------------------------------------------------- /data/e2e_aligned.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harvardnlp/neural-template-gen/dff72ca246e80fd0eef2ca0b8b1fbf21fe927df4/data/e2e_aligned.tar.gz -------------------------------------------------------------------------------- /data/make_e2e_labedata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from utils import get_e2e_fields, e2e_key2idx 5 | 6 | e2e_train_src = "src_train.txt" 7 | e2e_train_tgt = "train_tgt_lines.txt" # gold generations corresponding to src_train.txt 8 | e2e_val_src = "src_valid.txt" 9 | e2e_val_tgt = "valid_tgt_lines.txt" # gold generations corresponding to src_valid.txt 10 | 11 | punctuation = set(['.', '!', ',', ';', ':', '?']) 12 | 13 | def get_first_sent_tokes(tokes): 14 | try: 15 | first_per = tokes.index('.') 16 | return tokes[:first_per+1] 17 | except ValueError: 18 | return tokes 19 | 20 | def stupid_search(tokes, fields): 21 | """ 22 | greedily assigns longest labels to spans from left to right 23 | """ 24 | labels = [] 25 | i = 0 26 | while i < len(tokes): 27 | matched = False 28 | for j in xrange(len(tokes), i, -1): 29 | # first check if it's punctuation 30 | if all(toke in punctuation for toke in tokes[i:j]): 31 | labels.append((i, j, len(e2e_key2idx))) # first label after rul labels 32 | i = j 33 | matched = True 34 | break 35 | # then check if it matches stuff in the table 36 | for k, v in fields.iteritems(): 37 | # take an uncased match 38 | if " ".join(tokes[i:j]).lower() == " ".join(v).lower(): 39 | labels.append((i, j, e2e_key2idx[k])) 40 | i = j 41 | matched = True 42 | break 43 | if matched: 44 | break 45 | if not matched: 46 | i += 1 47 | return labels 48 | 49 | def print_data(srcfi, tgtfi): 50 | with open(srcfi) as f1: 51 | with open(tgtfi) as f2: 52 | for srcline in f1: 53 | tgttokes = f2.readline().strip().split() 54 | 
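                # print_data emits one line per example in the format that
                # labeled_data.py expects:  sentence|||s1,e1,k1 s2,e2,k2 ...
                # where each s,e,k triple is a token span [s, e) that stupid_search
                # matched to the field with index k (punctuation spans get the extra
                # index len(e2e_key2idx)), followed by a final end-of-sentence span.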
senttokes = tgttokes 55 | 56 | fields = get_e2e_fields(srcline.strip().split()) # fieldname -> tokens 57 | labels = stupid_search(senttokes, fields) 58 | labels = [(str(tup[0]), str(tup[1]), str(tup[2])) for tup in labels] 59 | 60 | # add eos stuff 61 | senttokes.append("") 62 | labels.append((str(len(senttokes)-1), str(len(senttokes)), '8')) # label doesn't matter 63 | 64 | labelstr = " ".join([','.join(label) for label in labels]) 65 | sentstr = " ".join(senttokes) 66 | 67 | outline = "%s|||%s" % (sentstr, labelstr) 68 | print outline 69 | 70 | 71 | if sys.argv[1] == "train": 72 | print_data(e2e_train_src, e2e_train_tgt) 73 | elif sys.argv[1] == "valid": 74 | print_data(e2e_val_src, e2e_val_tgt) 75 | else: 76 | assert False 77 | -------------------------------------------------------------------------------- /data/make_wikibio_labedata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | 5 | from utils import get_wikibio_fields 6 | 7 | train_dir = "wikipedia-biography-dataset/train" 8 | val_dir = "wikipedia-biography-dataset/valid" 9 | 10 | punctuation = set(['.', '!', ',', ';', ':', '?', '--', '-rrb-', '-lrb-']) 11 | 12 | # from wikipedia 13 | prepositions = set(['aboard', 'about', 'above', 'absent', 'across', 'after', 'against', 'along', 'alongside', 'amid', 'among', 14 | 'apropos', 'apud', 'around', 'as', 'astride', 'at', 'atop', 'bar', 'before', 'behind', 'below', 'beneath', 15 | 'beside', 'besides', 'between', 'beyond', 'but', 'by', 'chez', 'circa', 'come', 'despite', 'down', 'during', 16 | 'except', 'for', 'from', 'in', 'inside', 'into', 'less', 'like', 'minus', 'near', 'notwithstanding', 'of', 17 | 'off', 'on', 'onto', 'opposite', 'out', 'outside', 'over', 'pace', 'past', 'per', 'plus', 'post', 'pre', 18 | 'pro', 'qua', 're', 'sans', 'save', 'short', 'since', 'than', 'through', 'throughout', 'till', 'to', 'toward', 19 | 'under', 'underneath', 'unlike', 'until', 'unto', 'up', 'upon', 'upside', 'versus', 'via', 'vice', 'aboard', 20 | 'about', 'above', 'absent', 'across', 'after', 'against', 'along', 'alongside', 'amid', 'among', 'apropos', 21 | 'apud', 'around', 'as', 'astride', 'at', 'atop', 'bar', 'before', 'behind', 'below', 'beneath', 'beside', 'besides', 22 | 'between', 'beyond', 'but', 'by', 'chez', 'circa', 'come', 'despite', 'down', 'during', 'except', 'for', 'from', 'in', 23 | 'inside', 'into', 'less', 'like', 'minus', 'near', 'notwithstanding', 'of', 'off', 'on', 'onto', 'opposite', 'out', 24 | 'outside', 'over', 'pace', 'past', 'per', 'plus', 'post', 'pre', 'pro', 'qua', 're', 'sans', 'save', 'short', 'since', 25 | 'than', 'through', 'throughout', 'till', 'to', 'toward', 'under', 'underneath', 'unlike', 'until', 'unto', 'up', 'upon', 26 | 'upside', 'versus', 'via', 'vice', 'with', 'within', 'without', 'worth']) 27 | 28 | 29 | splitters = set(['and', ',', 'or', 'of', 'for', '--', 'also']) 30 | 31 | goodsplitters = set([',', 'of', 'for', '--', 'also']) # leaves out and and or 32 | 33 | def splitphrs(tokes, l, r, max_phrs_len, labelist): 34 | if r-l <= max_phrs_len: 35 | labelist.append((l, r, 0)) 36 | else: 37 | i = r-1 38 | found_a_split = False 39 | while i > l: 40 | if tokes[i] in goodsplitters or tokes[i] in prepositions: 41 | splitphrs(tokes, l, i, max_phrs_len, labelist) 42 | if i < r-1: 43 | splitphrs(tokes, i+1, r, max_phrs_len, labelist) 44 | found_a_split = True 45 | break 46 | i -= 1 47 | if not found_a_split: # add back in and and or 48 | i = r-1 49 | while i > l: 50 | if tokes[i] 
in splitters or tokes[i] in prepositions: 51 | splitphrs(tokes, l, i, max_phrs_len, labelist) 52 | if i < r-1: 53 | splitphrs(tokes, i+1, r, max_phrs_len, labelist) 54 | found_a_split = True 55 | break 56 | i -= 1 57 | if not found_a_split: # just do something 58 | i = r-1 59 | while i >= l: 60 | max_len = min(max_phrs_len, i-l+1) 61 | labelist.append((i-max_len+1, i+1, 0)) 62 | i = i-max_len 63 | 64 | 65 | def stupid_search(tokes, fields): 66 | """ 67 | greedily assigns longest labels to spans from right to left 68 | """ 69 | PFL = 4 70 | labels = [] 71 | i = len(tokes) 72 | wordsets = [set(toke for toke in v if toke not in punctuation) for k, v in fields.iteritems()] 73 | pfxsets = [set(toke[:PFL] for toke in v if toke not in punctuation) for k, v in fields.iteritems()] 74 | while i > 0: 75 | matched = False 76 | if tokes[i-1] in punctuation: 77 | labels.append((i-1, i, 0)) # all punctuation 78 | i -= 1 79 | continue 80 | if tokes[i-1] in punctuation or tokes[i-1] in prepositions or tokes[i-1] in splitters: 81 | i -= 1 82 | continue 83 | for j in xrange(i): 84 | if tokes[j] in punctuation or tokes[j] in prepositions or tokes[j] in splitters: 85 | continue 86 | # then check if it matches stuff in the table 87 | tokeset = set(toke for toke in tokes[j:i] if toke not in punctuation) 88 | for vset in wordsets: 89 | if tokeset == vset or (tokeset.issubset(vset) and len(tokeset) > 1): 90 | if i - j > max_phrs_len: 91 | nugz = [] 92 | splitphrs(tokes, j, i, max_phrs_len, nugz) 93 | labels.extend(nugz) 94 | else: 95 | labels.append((j, i, 0)) 96 | i = j 97 | matched = True 98 | break 99 | if matched: 100 | break 101 | pset = set(toke[:PFL] for toke in tokes[j:i] if toke not in punctuation) 102 | for pfxset in pfxsets: 103 | if pset == pfxset or (pset.issubset(pfxset)and len(pset) > 1): 104 | if i - j > max_phrs_len: 105 | nugz = [] 106 | splitphrs(tokes, j, i, max_phrs_len, nugz) 107 | labels.extend(nugz) 108 | else: 109 | labels.append((j, i, 0)) 110 | i = j 111 | matched = True 112 | break 113 | if matched: 114 | break 115 | if not matched: 116 | i -= 1 117 | labels.sort(key=lambda x: x[0]) 118 | return labels 119 | 120 | def print_data(direc): 121 | fis = os.listdir(direc) 122 | srcfi = [fi for fi in fis if fi.endswith('.box')][0] 123 | tgtfi = [fi for fi in fis if fi.endswith('.sent')][0] 124 | nbfi = [fi for fi in fis if fi.endswith('.nb')][0] 125 | 126 | with open(os.path.join(direc, srcfi)) as f: 127 | srclines = f.readlines() 128 | with open(os.path.join(direc, nbfi)) as f: 129 | nbs = [0] 130 | [nbs.append(int(line.strip())) for line in f.readlines()] 131 | nbs = set(torch.Tensor(nbs).cumsum(0)) 132 | 133 | tgtlines = [] 134 | with open(os.path.join(direc, tgtfi)) as f: 135 | for i, tgtline in enumerate(f): 136 | if i in nbs: 137 | tgtlines.append(tgtline) 138 | 139 | assert len(srclines) == len(tgtlines) 140 | for i in xrange(len(srclines)): 141 | fields = get_wikibio_fields(srclines[i].strip().split()) 142 | tgttokes = tgtlines[i].strip().split() 143 | labels = stupid_search(tgttokes, fields) 144 | labels = [(str(tup[0]), str(tup[1]), str(tup[2])) for tup in labels] 145 | # add eos stuff 146 | tgttokes.append("") 147 | labels.append((str(len(tgttokes)-1), str(len(tgttokes)), '0')) # label doesn't matter 148 | 149 | labelstr = " ".join([','.join(label) for label in labels]) 150 | sentstr = " ".join(tgttokes) 151 | 152 | outline = "%s|||%s" % (sentstr, labelstr) 153 | print outline 154 | 155 | if __name__ == "__main__": 156 | max_phrs_len = int(sys.argv[2]) 157 | if sys.argv[1] == 
"train": 158 | print_data(train_dir) 159 | elif sys.argv[1] == "valid": 160 | print_data(val_dir) 161 | else: 162 | assert False 163 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | # leaves out familyFriendly, which is a binary thing... 4 | e2e_keys = ["name", "eatType", "food", "priceRange", "customerrating", "area", "near"] 5 | e2e_key2idx = dict((key, i) for i, key in enumerate(e2e_keys)) 6 | 7 | def get_e2e_fields(tokes, keys=None): 8 | """ 9 | assumes a key only appears once per line... 10 | returns keyname -> list of words dict 11 | """ 12 | if keys is None: 13 | keys = e2e_keys 14 | fields = defaultdict(list) 15 | state = None 16 | for toke in tokes: 17 | if "__start" in toke: 18 | for key in keys: 19 | if toke == "__start_%s__" % key: 20 | assert state is None 21 | state = key 22 | elif "__end" in toke: 23 | for key in keys: 24 | if toke == "__end_%s__" % key: 25 | assert state == key 26 | state = None 27 | elif state is not None: 28 | fields[state].append(toke) 29 | 30 | return fields 31 | 32 | def get_e2e_poswrds(tokes): 33 | """ 34 | assumes a key only appears once per line... 35 | returns (key, num) -> word 36 | """ 37 | fields = {} 38 | state, num = None, 1 # 1-idx the numbering 39 | for toke in tokes: 40 | if "__start" in toke: 41 | assert state is None 42 | state = toke[7:-2] 43 | elif "__end" in toke: 44 | state, num = None, 1 45 | elif state is not None: 46 | fields[state, num] = toke 47 | num += 1 48 | return fields 49 | 50 | 51 | def get_wikibio_fields(tokes, keep_splits=None): 52 | """ 53 | key -> list of words 54 | """ 55 | fields = defaultdict(list) 56 | for toke in tokes: 57 | try: 58 | fullkey, val = toke.split(':') 59 | except ValueError: 60 | ugh = toke.split(':') # must be colons in the val 61 | fullkey = ugh[0] 62 | val = ''.join(ugh[1:]) 63 | if val == "": 64 | continue 65 | #try: 66 | keypieces = fullkey.split('_') 67 | if len(keypieces) == 1: 68 | key = fullkey 69 | else: 70 | keynum = keypieces[-1] 71 | key = '_'.join(keypieces[:-1]) 72 | #key, keynum = fullkey.split('_') 73 | #except ValueError: 74 | # key = fullkey 75 | if keep_splits is None or key not in keep_splits: 76 | fields[key].append(val) # assuming keys are ordered... 
77 | else: 78 | fields[fullkey].append(val) 79 | return fields 80 | 81 | 82 | def get_wikibio_poswrds(tokes): 83 | """ 84 | (key, num) -> word 85 | """ 86 | fields = {} 87 | for toke in tokes: 88 | try: 89 | fullkey, val = toke.split(':') 90 | except ValueError: 91 | ugh = toke.split(':') # must be colons in the val 92 | fullkey = ugh[0] 93 | val = ''.join(ugh[1:]) 94 | if val == "": 95 | continue 96 | #try: 97 | keypieces = fullkey.split('_') 98 | if len(keypieces) == 1: 99 | key = fullkey 100 | #keynum = '0' 101 | keynum = 1 102 | else: 103 | keynum = int(keypieces[-1]) 104 | key = '_'.join(keypieces[:-1]) 105 | fields[key, keynum] = val 106 | return fields 107 | -------------------------------------------------------------------------------- /data/wb_aligned.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harvardnlp/neural-template-gen/dff72ca246e80fd0eef2ca0b8b1fbf21fe927df4/data/wb_aligned.tar.gz -------------------------------------------------------------------------------- /infc.py: -------------------------------------------------------------------------------- 1 | """ 2 | all the inference stuff 3 | """ 4 | 5 | import math 6 | import torch 7 | from torch.autograd import Variable 8 | from utils import logsumexp0, logsumexp2 9 | 10 | 11 | def recover_bps(delt, bps, bps_star): 12 | """ 13 | delt, bps, bps_star - seqlen+1 x bsz x K 14 | returns: 15 | bsz-length list of lists with (start_idx, end_idx, label) entries 16 | """ 17 | seqlenp1, bsz, K = delt.size() 18 | seqlen = seqlenp1 - 1 19 | seqs = [] 20 | for b in xrange(bsz): 21 | seq = [] 22 | _, last_lab = delt[seqlen][b].max(0) 23 | last_lab = last_lab[0] 24 | curr_idx = seqlen # 1-indexed 25 | while True: 26 | last_len = bps[curr_idx][b][last_lab] 27 | seq.append((curr_idx-last_len, curr_idx, last_lab)) # start_idx, end_idx, label, 0-idxd 28 | curr_idx -= last_len 29 | if curr_idx == 0: 30 | break 31 | last_lab = bps_star[curr_idx][b][last_lab] 32 | seqs.append(seq[::-1]) 33 | return seqs 34 | 35 | 36 | def viterbi(pi, trans_logprobs, bwd_obs_logprobs, len_logprobs, constraints=None, ret_delt=False): 37 | """ 38 | pi - 1 x K 39 | bwd_obs_logprobs - L x T x bsz x K, obs probs ending at t 40 | trans_logprobs - T-1 x bsz x K x K, trans_logprobs[t] = p(q_{t+1} | q_t). 41 | see https://hal.inria.fr/hal-01064672v2/document 42 | """ 43 | neginf = -1e38 44 | L, seqlen, bsz, K = bwd_obs_logprobs.size() 45 | delt = trans_logprobs.new(seqlen+1, bsz, K).fill_(-float("inf")) 46 | delt_star = trans_logprobs.new(seqlen+1, bsz, K).fill_(-float("inf")) 47 | delt_star[0].copy_(pi.expand(bsz, K)) 48 | 49 | # currently len_logprobs contains tensors that are [1 step back; 2 steps back; ... L steps_back] 50 | # but we need to flip on the 0'th axis 51 | flipped_len_logprobs = [] 52 | for l in xrange(len(len_logprobs)): 53 | llps = len_logprobs[l] 54 | flipped_len_logprobs.append(torch.stack([llps[-i-1] for i in xrange(llps.size(0))])) 55 | 56 | bps = delt.long().fill_(L) 57 | bps_star = delt_star.long() 58 | bps_star[0].copy_(torch.arange(0, K).view(1, K).expand(bsz, K)) 59 | 60 | mask = trans_logprobs.new(L, bsz, K) 61 | 62 | for t in xrange(1, seqlen+1): 63 | steps_back = min(L, t) 64 | steps_fwd = min(L, seqlen-t+1) 65 | 66 | if steps_back <= steps_fwd: 67 | # steps_fwd x K -> steps_back x K 68 | len_terms = flipped_len_logprobs[min(L-1, steps_fwd-1)][-steps_back:] 69 | else: # we need to pick probs from different distributions... 
70 | len_terms = torch.stack([len_logprobs[min(L, seqlen+1-t+jj)-1][jj] 71 | for jj in xrange(L-1, -1, -1)]) 72 | 73 | if constraints is not None and constraints[t] is not None: 74 | tmask = mask.narrow(0, 0, steps_back).zero_() 75 | # steps_back x bsz x K -> steps_back*bsz x K 76 | tmask.view(-1, K).index_fill_(0, constraints[t], neginf) 77 | 78 | # delt_t(j) = log \sum_l p(x_{t-l+1:t}) delt*_{t-l} p(l_t) 79 | delt_terms = (delt_star[t-steps_back:t] # steps_back x bsz x K 80 | + bwd_obs_logprobs[-steps_back:, t-1]) # steps_back x bsz x K (0-idx) 81 | #delt_terms.sub_(bwd_maxlens[t-steps_back:t].expand_as(delt_terms)) # steps_back x bsz x K 82 | delt_terms.add_(len_terms.unsqueeze(1).expand(steps_back, bsz, K)) 83 | 84 | if constraints is not None and constraints[t] is not None: 85 | delt_terms.add_(tmask) 86 | 87 | maxes, argmaxes = torch.max(delt_terms, 0) # 1 x bsz x K, 1 x bsz x K 88 | delt[t] = maxes.squeeze(0) # bsz x K 89 | #bps[t] = argmaxes.squeeze(0) # bsz x K 90 | bps[t].sub_(argmaxes.squeeze(0)) # keep track of steps back taken: L - argmax 91 | if steps_back < L: 92 | bps[t].sub_(L - steps_back) 93 | if t < seqlen: 94 | # delt*_t(k) = log \sum_j delt_t(j) p(q_{t+1}=k | q_t = j) 95 | # get bsz x K x K trans logprobs, viz., p(q_{t+1}=j|i) w/ 0th dim i, 2nd dim j 96 | tps = trans_logprobs[t-1] # N.B. trans_logprobs[t] is p(q_{t+1}) and 0-indexed 97 | delt_t = delt[t] # bsz x K, viz, p(x, j) 98 | delt_star_terms = (tps.transpose(0, 1) # K x bsz x K 99 | + delt_t.unsqueeze(2).expand(bsz, K, K).transpose(0, 1)) 100 | maxes, argmaxes = torch.max(delt_star_terms, 0) # 1 x bsz x K, 1 x bsz x K 101 | delt_star[t] = maxes.squeeze(0) 102 | bps_star[t] = argmaxes.squeeze(0) 103 | 104 | #return delt, delt_star, bps, bps_star, recover_bps(delt, bps, bps_star) 105 | if ret_delt: 106 | return recover_bps(delt, bps, bps_star), delt[-1] # bsz x K total scores 107 | else: 108 | return recover_bps(delt, bps, bps_star) 109 | 110 | 111 | def just_fwd(pi, trans_logprobs, bwd_obs_logprobs, constraints=None): 112 | """ 113 | pi - bsz x K 114 | bwd_obs_logprobs - L x T x bsz x K, obs probs ending at t 115 | trans_logprobs - T-1 x bsz x K x K, trans_logprobs[t] = p(q_{t+1} | q_t) 116 | """ 117 | neginf = -1e38 # -float("inf") 118 | L, seqlen, bsz, K = bwd_obs_logprobs.size() 119 | # we'll be 1-indexed for alphas and betas 120 | alph = [None]*(seqlen+1) 121 | alph_star = [None]*(seqlen+1) 122 | alph_star[0] = pi 123 | mask = trans_logprobs.new(L, bsz, K) 124 | 125 | bwd_maxlens = trans_logprobs.new(seqlen).fill_(L) # store max possible length generated from t 126 | bwd_maxlens[-L:].copy_(torch.arange(L, 0, -1)) 127 | bwd_maxlens = bwd_maxlens.log_().view(seqlen, 1, 1) 128 | 129 | for t in xrange(1, seqlen+1): 130 | steps_back = min(L, t) 131 | 132 | if constraints is not None and constraints[t] is not None: 133 | tmask = mask.narrow(0, 0, steps_back).zero_() 134 | # steps_back x bsz x K -> steps_back*bsz x K 135 | tmask.view(-1, K).index_fill_(0, constraints[t], neginf) 136 | 137 | # alph_t(j) = log \sum_l p(x_{t-l+1:t}) alph*_{t-l} p(l_t) 138 | alph_terms = (torch.stack(alph_star[t-steps_back:t]) # steps_back x bsz x K 139 | + bwd_obs_logprobs[-steps_back:, t-1] # steps_back x bsz x K (0-idx) 140 | - bwd_maxlens[t-steps_back:t].expand(steps_back, bsz, K)) 141 | 142 | if constraints is not None and constraints[t] is not None: 143 | alph_terms = alph_terms + tmask #Variable(tmask) 144 | 145 | alph[t] = logsumexp0(alph_terms) # bsz x K 146 | 147 | if t < seqlen: 148 | # alph*_t(k) = log \sum_j alph_t(j) 
p(q_{t+1}=k | q_t = j) 149 | # get bsz x K x K trans logprobs, viz., p(q_{t+1}=j|i) w/ 0th dim i, 2nd dim j 150 | tps = trans_logprobs[t-1] # N.B. trans_logprobs[t] is p(q_{t+1}) and 0-indexed 151 | alph_t = alph[t] # bsz x K, viz, p(x, j) 152 | alph_star_terms = (tps.transpose(0, 1) # K x bsz x K 153 | + alph_t.unsqueeze(2).expand(bsz, K, K).transpose(0, 1)) 154 | alph_star[t] = logsumexp0(alph_star_terms) 155 | 156 | return alph, alph_star 157 | 158 | 159 | def just_bwd(trans_logprobs, fwd_obs_logprobs, len_logprobs, constraints=None): 160 | """ 161 | fwd_obs_logprobs - L x T x bsz x K, obs probs starting at t 162 | trans_logprobs - T-1 x bsz x K x K, trans_logprobs[t] = p(q_{t+1} | q_t) 163 | """ 164 | neginf = -1e38 # -float("inf") 165 | L, seqlen, bsz, K = fwd_obs_logprobs.size() 166 | 167 | # we'll be 1-indexed for alphas and betas 168 | beta = [None]*(seqlen+1) 169 | beta_star = [None]*(seqlen+1) 170 | beta[seqlen] = Variable(trans_logprobs.data.new(bsz, K).zero_()) 171 | mask = trans_logprobs.data.new(L, bsz, K) 172 | 173 | for t in xrange(1, seqlen+1): 174 | steps_fwd = min(L, t) 175 | 176 | len_terms = len_logprobs[min(L-1, steps_fwd-1)] # steps_fwd x K 177 | 178 | if constraints is not None and constraints[seqlen-t+1] is not None: 179 | tmask = mask.narrow(0, 0, steps_fwd).zero_() 180 | # steps_fwd x bsz x K -> steps_fwd*bsz x K 181 | tmask.view(-1, K).index_fill_(0, constraints[seqlen-t+1], neginf) 182 | 183 | # beta*_t(k) = log \sum_l beta_{t+l}(k) p(x_{t+1:t+l}) p(l_t) 184 | beta_star_terms = (torch.stack(beta[seqlen-t+1:seqlen-t+1+steps_fwd]) # steps_fwd x bsz x K 185 | + fwd_obs_logprobs[:steps_fwd, seqlen-t] # steps_fwd x bsz x K 186 | #- math.log(steps_fwd)) # steps_fwd x bsz x K 187 | + len_terms.unsqueeze(1).expand(steps_fwd, bsz, K)) 188 | 189 | if constraints is not None and constraints[seqlen-t+1] is not None: 190 | beta_star_terms = beta_star_terms + Variable(tmask) 191 | 192 | beta_star[seqlen-t] = logsumexp0(beta_star_terms) 193 | if seqlen-t > 0: 194 | # beta_t(j) = log \sum_k beta*_t(k) p(q_{t+1} = k | q_t=j) 195 | betastar_nt = beta_star[seqlen-t] # bsz x K 196 | # get bsz x K x K trans logprobs, viz., p(q_{t+1}=j|i) w/ 0th dim i, 2nd dim j 197 | tps = trans_logprobs[seqlen-t-1] # N.B. 
trans_logprobs[t] is p(q_{t+1}) and 0-idxed 198 | beta_terms = betastar_nt.unsqueeze(1).expand(bsz, K, K) + tps # bsz x K x K 199 | beta[seqlen-t] = logsumexp2(beta_terms) # bsz x K 200 | 201 | 202 | return beta, beta_star 203 | 204 | 205 | # [p0 p1 p2 p3 p4 206 | # p0:1 p1:2 p2:3 3:4 4:5 207 | # p0:2 p1:3 2:4 3:5 4:6 ] 208 | 209 | 210 | 211 | # so bwd log probs look like 212 | # -inf -inf p1:3 p2:4 213 | # -inf p1:2 p2:3 p3:4 214 | # p1 p2 p3 p4 215 | def bwd_from_fwd_obs_logprobs(fwd_obs_logprobs): 216 | """ 217 | fwd_obs_logprobs - L x T x bsz x K, 218 | where fwd_obs_logprobs[:,t,:,:] gives p(x_t), p(x_{t:t+1}), ..., p(x_{t:t+l}) 219 | returns: 220 | bwd_obs_logprobs - L x T x bsz x K, 221 | where bwd_obs_logprobs[:,t,:,:] gives p(x_{t-L+1:t}), ..., p(x_{t}) 222 | iow, fwd_obs_logprobs gives probs of segments starting at t, and bwd_obs_logprobs 223 | gives probs of segments ending at t 224 | """ 225 | L = fwd_obs_logprobs.size(0) 226 | bwd_obs_logprobs = fwd_obs_logprobs.new().resize_as_(fwd_obs_logprobs).fill_(-float("inf")) 227 | bwd_obs_logprobs[L-1].copy_(fwd_obs_logprobs[0]) 228 | for l in xrange(1, L): 229 | bwd_obs_logprobs[L-l-1, l:].copy_(fwd_obs_logprobs[l, :-l]) 230 | return bwd_obs_logprobs 231 | -------------------------------------------------------------------------------- /labeled_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | this file modified from the word_language_model example 3 | """ 4 | import os 5 | import torch 6 | 7 | from collections import Counter, defaultdict 8 | 9 | from data.utils import get_wikibio_poswrds, get_e2e_poswrds 10 | 11 | import random 12 | random.seed(1111) 13 | 14 | #punctuation = set(['.', '!', ',', ';', ':', '?', '--', '-rrb-', '-lrb-']) 15 | punctuation = set() # i don't know why i was so worried about punctuation 16 | 17 | class Dictionary(object): 18 | def __init__(self, unk_word=""): 19 | self.unk_word = unk_word 20 | self.idx2word = [unk_word, "", "", ""] # OpenNMT constants 21 | self.word2idx = {word: i for i, word in enumerate(self.idx2word)} 22 | 23 | def add_word(self, word, train=False): 24 | """ 25 | returns idx of word 26 | """ 27 | if train and word not in self.word2idx: 28 | self.idx2word.append(word) 29 | self.word2idx[word] = len(self.idx2word) - 1 30 | return self.word2idx[word] if word in self.word2idx else self.word2idx[self.unk_word] 31 | 32 | def bulk_add(self, words): 33 | """ 34 | assumes train=True 35 | """ 36 | self.idx2word.extend(words) 37 | self.word2idx = {word: i for i, word in enumerate(self.idx2word)} 38 | 39 | def __len__(self): 40 | return len(self.idx2word) 41 | 42 | 43 | class SentenceCorpus(object): 44 | def __init__(self, path, bsz, thresh=0, add_bos=False, add_eos=False, 45 | test=False): 46 | self.dictionary = Dictionary() 47 | self.bsz = bsz 48 | self.wiki = "wiki" in path 49 | 50 | train_src = os.path.join(path, "src_train.txt") 51 | 52 | if thresh > 0: 53 | self.get_vocabs(os.path.join(path, 'train.txt'), train_src, thresh=thresh) 54 | self.ngen_types = len(self.genset) + 4 # assuming didn't encounter any special tokens 55 | add_to_dict = False 56 | else: 57 | add_to_dict = True 58 | trsents, trlabels, trfeats, trlocs, inps = self.tokenize( 59 | os.path.join(path, 'train.txt'), train_src, add_to_dict=add_to_dict, 60 | add_bos=add_bos, add_eos=add_eos) 61 | print "using vocabulary of size:", len(self.dictionary) 62 | 63 | print self.ngen_types, "gen word types" 64 | self.train, self.train_mb2linenos = self.minibatchify( 65 | trsents, trlabels, 
trfeats, trlocs, inps, bsz) # list of minibatches 66 | 67 | if (os.path.isfile(os.path.join(path, 'valid.txt')) 68 | or os.path.isfile(os.path.join(path, 'test.txt'))): 69 | if not test: 70 | val_src = os.path.join(path, "src_valid.txt") 71 | vsents, vlabels, vfeats, vlocs, vinps = self.tokenize( 72 | os.path.join(path, 'valid.txt'), val_src, add_to_dict=False, 73 | add_bos=add_bos, add_eos=add_eos) 74 | else: 75 | print "using test data and whatnot...." 76 | test_src = os.path.join(path, "src_test.txt") 77 | vsents, vlabels, vfeats, vlocs, vinps = self.tokenize( 78 | os.path.join(path, 'test.txt'), test_src, add_to_dict=False, 79 | add_bos=add_bos, add_eos=add_eos) 80 | self.valid, self.val_mb2linenos = self.minibatchify( 81 | vsents, vlabels, vfeats, vlocs, vinps, bsz) 82 | 83 | 84 | def get_vocabs(self, path, src_path, thresh=2): 85 | """unks words occurring <= thresh times""" 86 | tgt_voc = Counter() 87 | assert os.path.exists(path) 88 | 89 | linewords = [] 90 | with open(src_path, 'r') as f: 91 | for line in f: 92 | tokes = line.strip().split() 93 | if self.wiki: 94 | fields = get_wikibio_poswrds(tokes) #key, pos -> wrd 95 | else: 96 | fields = get_e2e_poswrds(tokes) # key, pos -> wrd 97 | fieldvals = fields.values() 98 | tgt_voc.update(fieldvals) 99 | linewords.append(set(wrd for wrd in fieldvals 100 | if wrd not in punctuation)) 101 | tgt_voc.update([k for k, idx in fields]) 102 | tgt_voc.update([idx for k, idx in fields]) 103 | 104 | genwords = Counter() 105 | # Add words to the dictionary 106 | with open(path, 'r') as f: 107 | #tokens = 0 108 | for l, line in enumerate(f): 109 | words, spanlabels = line.strip().split('|||') 110 | words = words.split() 111 | genwords.update([wrd for wrd in words if wrd not in linewords[l]]) 112 | tgt_voc.update(words) 113 | 114 | # prune 115 | # N.B. 
it's possible a word appears enough times in total but not in genwords 116 | # so we need separate unking for generation 117 | #print "comeon", "aerobatic" in genwords 118 | for cntr in [tgt_voc, genwords]: 119 | for k in cntr.keys(): 120 | if cntr[k] <= thresh: 121 | del cntr[k] 122 | 123 | self.genset = set(genwords.keys()) 124 | tgtkeys = tgt_voc.keys() 125 | # make sure gen stuff is first 126 | tgtkeys.sort(key=lambda x: -(x in self.genset)) 127 | self.dictionary.bulk_add(tgtkeys) 128 | # make sure we did everything right (assuming didn't encounter any special tokens) 129 | assert self.dictionary.idx2word[4 + len(self.genset) - 1] in self.genset 130 | assert self.dictionary.idx2word[4 + len(self.genset)] not in self.genset 131 | self.dictionary.add_word("", train=True) 132 | self.dictionary.add_word("", train=True) 133 | self.dictionary.add_word("", train=True) 134 | self.dictionary.add_word("", train=True) 135 | self.dictionary.add_word("", train=True) 136 | 137 | 138 | def tokenize(self, path, src_path, add_to_dict=False, add_bos=False, add_eos=False): 139 | """Assumes fmt is sentence|||s1,e1,k1 s2,e2,k2 ....""" 140 | assert os.path.exists(path) 141 | 142 | src_feats, src_wrd2idxs, src_wrd2fields = [], [], [] 143 | w2i = self.dictionary.word2idx 144 | with open(src_path, 'r') as f: 145 | for line in f: 146 | tokes = line.strip().split() 147 | #fields = get_e2e_fields(tokes, keys=self.e2e_keys) #keyname -> list of words 148 | if self.wiki: 149 | fields = get_wikibio_poswrds(tokes) #key, pos -> wrd 150 | else: 151 | fields = get_e2e_poswrds(tokes) # key, pos -> wrd 152 | # wrd2things will be unordered 153 | feats, wrd2idxs, wrd2fields = [], defaultdict(list), defaultdict(list) 154 | # get total number of words per field 155 | fld_cntr = Counter([key for key, _ in fields]) 156 | for (k, idx), wrd in fields.iteritems(): 157 | if k in w2i: 158 | featrow = [self.dictionary.add_word(k, add_to_dict), 159 | self.dictionary.add_word(idx, add_to_dict), 160 | self.dictionary.add_word(wrd, add_to_dict)] 161 | wrd2idxs[wrd].append(len(feats)) 162 | #nflds = self.dictionary.add_word(fld_cntr[k], add_to_dict) 163 | cheatfeat = w2i[""] if fld_cntr[k] == idx else w2i[""] 164 | wrd2fields[wrd].append((featrow[2], featrow[0], featrow[1], cheatfeat)) 165 | feats.append(featrow) 166 | src_wrd2idxs.append(wrd2idxs) 167 | src_wrd2fields.append(wrd2fields) 168 | src_feats.append(feats) 169 | 170 | sents, labels, copylocs, inps = [], [], [], [] 171 | 172 | # Add words to the dictionary 173 | tgtline = 0 174 | with open(path, 'r') as f: 175 | #tokens = 0 176 | for line in f: 177 | words, spanlabels = line.strip().split('|||') 178 | words = words.split() 179 | sent, copied, insent = [], [], [] 180 | if add_bos: 181 | sent.append(self.dictionary.add_word('', True)) 182 | for word in words: 183 | # sent is just used for targets; we have separate inputs 184 | if word in self.genset: 185 | sent.append(w2i[word]) 186 | else: 187 | sent.append(w2i[""]) 188 | if word not in punctuation and word in src_wrd2idxs[tgtline]: 189 | copied.append(src_wrd2idxs[tgtline][word]) 190 | winps = [[widx, kidx, idxidx, nidx] 191 | for widx, kidx, idxidx, nidx in src_wrd2fields[tgtline][word]] 192 | insent.append(winps) 193 | else: 194 | #assert sent[-1] < self.ngen_types 195 | copied.append([-1]) 196 | # 1 x wrd, tokennum, totalnum 197 | #insent.append([[sent[-1], w2i[""], w2i[""]]]) 198 | insent.append([[sent[-1], w2i[""], w2i[""], w2i[""]]]) 199 | #sent.extend([self.dictionary.add_word(word, add_to_dict) for word in words]) 200 | 
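                # The loop above builds three parallel per-token structures: sent
                # holds the target vocab index (unked when the word is not in
                # genset), copied holds the source-table rows the token can be
                # copied from (or [-1]), and insent holds one 4-feature input row
                # (word, field, position, count feature) per candidate source row,
                # falling back to placeholder feature indices when the token does
                # not appear in the table.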
if add_eos: 201 | sent.append(self.dictionary.add_word('', True)) 202 | labetups = [tupstr.split(',') for tupstr in spanlabels.split()] 203 | labelist = [(int(tup[0]), int(tup[1]), int(tup[2])) for tup in labetups] 204 | sents.append(sent) 205 | labels.append(labelist) 206 | copylocs.append(copied) 207 | inps.append(insent) 208 | tgtline += 1 209 | assert len(sents) == len(labels) 210 | assert len(src_feats) == len(sents) 211 | assert len(copylocs) == len(sents) 212 | return sents, labels, src_feats, copylocs, inps 213 | 214 | def featurize_tbl(self, fields): 215 | """ 216 | fields are key, pos -> wrd maps 217 | returns: nrows x nfeats tensor 218 | """ 219 | feats = [] 220 | for (k, idx), wrd in fields.iteritems(): 221 | if k in self.dictionary.word2idx: 222 | featrow = [self.dictionary.add_word(k, False), 223 | self.dictionary.add_word(idx, False), 224 | self.dictionary.add_word(wrd, False)] 225 | feats.append(featrow) 226 | return torch.LongTensor(feats) 227 | 228 | def padded_loc_mb(self, curr_locs): 229 | """ 230 | curr_locs is a bsz-len list of tgt-len list of locations 231 | returns: 232 | a seqlen x bsz x max_locs tensor 233 | """ 234 | max_locs = max(len(locs) for blocs in curr_locs for locs in blocs) 235 | for blocs in curr_locs: 236 | for locs in blocs: 237 | if len(locs) < max_locs: 238 | locs.extend([-1]*(max_locs - len(locs))) 239 | return torch.LongTensor(curr_locs).transpose(0, 1).contiguous() 240 | 241 | def padded_feat_mb(self, curr_feats): 242 | """ 243 | curr_feats is a bsz-len list of nrows-len list of features 244 | returns: 245 | a bsz x max_nrows x nfeats tensor 246 | """ 247 | max_rows = max(len(feats) for feats in curr_feats) 248 | nfeats = len(curr_feats[0][0]) 249 | for feats in curr_feats: 250 | if len(feats) < max_rows: 251 | [feats.append([self.dictionary.word2idx[""] for _ in xrange(nfeats)]) 252 | for _ in xrange(max_rows - len(feats))] 253 | return torch.LongTensor(curr_feats) 254 | 255 | 256 | def padded_inp_mb(self, curr_inps): 257 | """ 258 | curr_inps is a bsz-len list of seqlen-len list of nlocs-len list of features 259 | returns: 260 | a bsz x seqlen x max_nlocs x nfeats tensor 261 | """ 262 | max_rows = max(len(feats) for seq in curr_inps for feats in seq) 263 | nfeats = len(curr_inps[0][0][0]) 264 | for seq in curr_inps: 265 | for feats in seq: 266 | if len(feats) < max_rows: 267 | # pick random rows 268 | randidxs = [random.randint(0, len(feats)-1) 269 | for _ in xrange(max_rows - len(feats))] 270 | [feats.append(feats[ridx]) for ridx in randidxs] 271 | return torch.LongTensor(curr_inps) 272 | 273 | 274 | def minibatchify(self, sents, labels, feats, locs, inps, bsz): 275 | """ 276 | this should result in there never being any padding. 
277 | each minibatch is: 278 | (seqlen x bsz, bsz-length list of lists of (start, end, label) constraints, 279 | bsz x nfields x nfeats, seqlen x bsz x max_locs, seqlen x bsz x max_locs x nfeats) 280 | """ 281 | # sort in ascending order 282 | sents, sorted_idxs = zip(*sorted(zip(sents, range(len(sents))), key=lambda x: len(x[0]))) 283 | minibatches, mb2linenos = [], [] 284 | curr_batch, curr_labels, curr_feats, curr_locs, curr_linenos = [], [], [], [], [] 285 | curr_inps = [] 286 | curr_len = len(sents[0]) 287 | for i in xrange(len(sents)): 288 | if len(sents[i]) != curr_len or len(curr_batch) == bsz: # we're done 289 | minibatches.append((torch.LongTensor(curr_batch).t().contiguous(), 290 | curr_labels, self.padded_feat_mb(curr_feats), 291 | self.padded_loc_mb(curr_locs), 292 | self.padded_inp_mb(curr_inps).transpose(0, 1).contiguous())) 293 | mb2linenos.append(curr_linenos) 294 | curr_batch = [sents[i]] 295 | curr_len = len(sents[i]) 296 | curr_labels = [labels[sorted_idxs[i]]] 297 | curr_feats = [feats[sorted_idxs[i]]] 298 | curr_locs = [locs[sorted_idxs[i]]] 299 | curr_inps = [inps[sorted_idxs[i]]] 300 | curr_linenos = [sorted_idxs[i]] 301 | else: 302 | curr_batch.append(sents[i]) 303 | curr_labels.append(labels[sorted_idxs[i]]) 304 | curr_feats.append(feats[sorted_idxs[i]]) 305 | curr_locs.append(locs[sorted_idxs[i]]) 306 | curr_inps.append(inps[sorted_idxs[i]]) 307 | curr_linenos.append(sorted_idxs[i]) 308 | # catch last 309 | if len(curr_batch) > 0: 310 | minibatches.append((torch.LongTensor(curr_batch).t().contiguous(), 311 | curr_labels, self.padded_feat_mb(curr_feats), 312 | self.padded_loc_mb(curr_locs), 313 | self.padded_inp_mb(curr_inps).transpose(0, 1).contiguous())) 314 | mb2linenos.append(curr_linenos) 315 | return minibatches, mb2linenos 316 | -------------------------------------------------------------------------------- /slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/harvardnlp/neural-template-gen/dff72ca246e80fd0eef2ca0b8b1fbf21fe927df4/slides.pdf -------------------------------------------------------------------------------- /template_extraction.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import defaultdict 3 | 4 | import torch 5 | 6 | import labeled_data 7 | 8 | seg_patt = re.compile('([^\|]+)\|(\d+)') # detects segments 9 | 10 | def group_by_template(fi, startlineno): 11 | """ 12 | returns a label-tup -> [(phrase-list, lineno), ...] 
map 13 | """ 14 | labes2sents = defaultdict(list) 15 | lineno = startlineno 16 | with open(fi) as f: 17 | for line in f: 18 | if '|' not in line: 19 | continue 20 | seq = seg_patt.findall(line.strip()) # list of 2-tuples 21 | wordseq, labeseq = zip(*seq) # 2 tuples 22 | wordseq = [phrs.strip() for phrs in wordseq] 23 | labeseq = tuple(int(labe) for labe in labeseq) 24 | labes2sents[labeseq].append((wordseq, lineno)) 25 | lineno += 1 26 | return labes2sents 27 | 28 | def remap_eos_states(top_temps, temps2sents): 29 | """ 30 | allocates a new state for any state that is also used for an 31 | """ 32 | used_states = set() 33 | [used_states.update(temp) for temp in top_temps] 34 | final_states = set() 35 | for temp in top_temps: 36 | final_state = temp[-1] 37 | assert any(sent[-1] == "" for sent, lineno in temps2sents[temp]) 38 | final_states.add(final_state) 39 | 40 | # make new states 41 | remap = {} 42 | for i, temp in enumerate(top_temps): 43 | nutemp = [] 44 | changed = False 45 | for j, t in enumerate(temp): 46 | if j < len(temp)-1 and t in final_states: 47 | changed = True 48 | if t not in remap: 49 | remap[t] = max(used_states) + len(remap) + 1 50 | nutemp.append(remap[t] if t in remap else t) 51 | if changed: 52 | nutuple = tuple(nutemp) 53 | top_temps[i] = nutuple 54 | temps2sents[nutuple] = temps2sents[temp] 55 | del temps2sents[temp] 56 | 57 | def just_state2phrases(temps, temps2sents): 58 | state2phrases = defaultdict(lambda: defaultdict(int)) # defaultdict of defaultdict 59 | for temp in temps: 60 | for sent, lineno in temps2sents[temp]: 61 | for i, state in enumerate(temp): 62 | #state2phrases[state].add(sent[i]) 63 | state2phrases[state][sent[i]] += 1 64 | 65 | nustate2phrases = {} 66 | for k, v in state2phrases.iteritems(): 67 | phrases = list(v) 68 | counts = torch.Tensor([state2phrases[k][phrs] for phrs in phrases]) 69 | counts.div_(counts.sum()) 70 | nustate2phrases[k] = (phrases, counts) 71 | state2phrases = nustate2phrases 72 | return state2phrases 73 | 74 | 75 | def extract_from_tagged_data(datadir, bsz, thresh, tagged_fi, ntemplates): 76 | corpus = labeled_data.SentenceCorpus(datadir, bsz, thresh=thresh, add_bos=False, 77 | add_eos=False, test=False) 78 | nskips = 0 79 | for i in xrange(len(corpus.train)): 80 | if corpus.train[i][0].size(0) <= 4: 81 | nskips += corpus.train[i][0].size(1) 82 | print "assuming we start on line", nskips, "of train" 83 | temps2sents = group_by_template(tagged_fi, nskips) 84 | top_temps = sorted(temps2sents.keys(), key=lambda x: -len(temps2sents[x]))[:ntemplates] 85 | #remap_eos_states(top_temps, temps2sents) 86 | state2phrases = just_state2phrases(top_temps, temps2sents) 87 | 88 | 89 | return top_temps, temps2sents, state2phrases 90 | 91 | 92 | def topk_phrases(pobj, k): 93 | phrases, probs = pobj 94 | thing = sorted(zip(phrases, list(probs)), key=lambda x: -x[1]) 95 | sphrases, sprobs = zip(*thing) 96 | return sphrases[:k] 97 | 98 | 99 | def align_cntr(cntr, thresh=0.4): 100 | tote = float(sum(cntr.values())) 101 | nug = {k : v/tote for k, v in cntr.iteritems()} 102 | best, bestp = None, 0 103 | for k, v in nug.iteritems(): 104 | if v > bestp: 105 | best, bestp = k, v 106 | if bestp >= thresh: 107 | return best 108 | else: 109 | return None 110 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | utils and whatnot 3 | """ 4 | import math 5 | from collections import defaultdict, Counter 6 | import torch 7 | 
from torch.autograd import Variable 8 | 9 | def logsumexp0(X): 10 | """ 11 | X - L x B x K 12 | returns: 13 | B x K 14 | """ 15 | if X.dim() == 2: 16 | X = X.unsqueeze(2) 17 | axis = 0 18 | X2d = X.view(X.size(0), -1) 19 | maxes, _ = torch.max(X2d, axis, True) 20 | lse = maxes + torch.log(torch.sum(torch.exp(X2d - maxes.expand_as(X2d)), axis, True)) 21 | lse = lse.view(X.size(1), -1) 22 | return lse 23 | 24 | def logsumexp2(X): 25 | """ 26 | X - L x B x K 27 | returns: 28 | L x B 29 | """ 30 | if X.dim() == 2: 31 | X = X.unsqueeze(0) 32 | X2d = X.view(-1, X.size(2)) 33 | maxes, _ = torch.max(X2d, 1, True) 34 | lse = maxes + torch.log(torch.sum(torch.exp(X2d - maxes.expand_as(X2d)), 1, True)) 35 | lse = lse.view(X.size(0), -1) 36 | return lse 37 | 38 | def logsumexp1(X): 39 | """ 40 | X - B x K 41 | returns: 42 | B x 1 43 | """ 44 | maxes, _ = torch.max(X, 1, True) 45 | lse = maxes + torch.log(torch.sum(torch.exp(X - maxes.expand_as(X)), 1, True)) 46 | return lse 47 | 48 | 49 | def vlogsumexp(v): 50 | """ 51 | for vectors 52 | """ 53 | maxv = v.max() 54 | return maxv + math.log(torch.sum(torch.exp(v-maxv))) 55 | 56 | 57 | def make_fwd_constr_idxs(L, T, constrs): 58 | """ 59 | for use w/ fwd alg. 60 | constrs are 0-indexed 61 | """ 62 | cidxs = [set() for t in xrange(T)] 63 | bsz = len(constrs) 64 | for b in xrange(bsz): 65 | for tup in constrs[b]: 66 | if len(tup) == 2: 67 | start, end = tup 68 | else: 69 | start, end = tup[0], tup[1] 70 | clen = end - start 71 | # for last thing in segment only allow segment length 72 | end_steps_back = min(L, end) 73 | cidxs[end-1].update([(end_steps_back-l-1)*bsz + b 74 | for l in xrange(end_steps_back) if l+1 != clen]) 75 | # now disallow everything for everything else in the segment 76 | for i in xrange(start, end-1): 77 | steps_back = min(L, i+1) 78 | cidxs[i].update([(steps_back-l-1)*bsz + b for l in xrange(steps_back)]) 79 | # now disallow things w/in L of the end 80 | for i in xrange(end, min(T, end+L-1)): 81 | steps_back = min(L, i+1) 82 | cidxs[i].update([(steps_back-l+end-1)*bsz + b for l in xrange(i+1, end+steps_back)]) 83 | oi_cidxs = [None] # make 1-indexed 84 | oi_cidxs.extend([torch.LongTensor(list(idxs)) if len(idxs) > 0 else None for idxs in cidxs]) 85 | return oi_cidxs 86 | 87 | 88 | def make_bwd_constr_idxs(L, T, constrs): 89 | """ 90 | for use w/ bwd alg. 
91 | constrs are a bsz-length list of lists of (start, end, label) 0-indexed tups 92 | """ 93 | cidxs = [set() for t in xrange(T)] 94 | bsz = len(constrs) 95 | for b in xrange(bsz): 96 | for tup in constrs[b]: 97 | if len(tup) == 2: 98 | start, end = tup 99 | else: 100 | start, end = tup[0], tup[1] 101 | clen = end - start 102 | steps_fwd = min(L, T-start) 103 | # for first thing only allow segment length 104 | cidxs[start].update([l*bsz + b for l in xrange(steps_fwd) if l+1 != clen]) 105 | 106 | # now disallow everything for everything else in the segment 107 | for i in xrange(start+1, end): 108 | steps_fwd = min(L, T-i) 109 | cidxs[i].update([l*bsz + b for l in xrange(steps_fwd)]) 110 | 111 | # now disallow things w/in L of the start 112 | for i in xrange(max(start-L+1, 0), start): 113 | steps_fwd = min(L, T-i) 114 | cidxs[i].update([l*bsz + b for l in xrange(steps_fwd) if i+l >= start]) 115 | 116 | oi_cidxs = [None] 117 | oi_cidxs.extend([torch.LongTensor(list(idxs)) if len(idxs) > 0 else None for idxs in cidxs]) 118 | return oi_cidxs 119 | 120 | 121 | def backtrace(node): 122 | """ 123 | assumes a node is (word, node) and that every history starts with (None, None) 124 | """ 125 | hyp = [node[0]] 126 | while node[1] is not None: 127 | node = node[1] 128 | hyp.append(node[0]) 129 | return hyp[-2::-1] # returns all but last element, reversed 130 | 131 | def backtrace3(node): 132 | """ 133 | assumes a node is (word, seg-label, node) etc 134 | """ 135 | hyp = [(node[0], node[1])] 136 | while node[2] is not None: 137 | node = node[2] 138 | hyp.append((node[0], node[1])) 139 | return hyp[-2::-1] 140 | 141 | def beam_search2(net, corpus, ss, start_inp, exh0, exc0, srcfieldenc, 142 | len_lps, row2tblent, row2feats, K, final_state=False): 143 | """ 144 | ss - discrete state index 145 | exh0 - layers x 1 x rnn_size 146 | exc0 - layers x 1 x rnn_size 147 | start_inp - 1 x 1 x emb_size 148 | len_lps - K x L, log normalized 149 | """ 150 | rul_ss = ss % net.K 151 | i2w = corpus.dictionary.idx2word 152 | w2i = corpus.dictionary.word2idx 153 | genset = corpus.genset 154 | unk_idx, eos_idx, pad_idx = w2i[""], w2i[""], w2i[""] 155 | state_emb_sz = net.state_embs.size(3) if net.one_rnn else 0 156 | if net.one_rnn: 157 | cond_start_inp = torch.cat([start_inp, net.state_embs[rul_ss]], 2) # 1 x 1 x cat_size 158 | hid, (hc, cc) = net.seg_rnns[0](cond_start_inp, (exh0, exc0)) 159 | else: 160 | hid, (hc, cc) = net.seg_rnns[rul_ss](start_inp, (exh0, exc0)) 161 | curr_hyps = [(None, None)] 162 | best_wscore, best_lscore = None, None # so we can truly average over words etc later 163 | best_hyp, best_hyp_score = None, -float("inf") 164 | curr_scores = torch.zeros(K, 1) 165 | # N.B. we assume we have a single feature row for each timestep rather than avg 166 | # over them as at training time. probably better, but could conceivably average like 167 | # at training time. 168 | inps = Variable(torch.LongTensor(K, 4), volatile=True) 169 | for ell in xrange(net.L): 170 | wrd_dist = net.get_next_word_dist(hid, rul_ss, srcfieldenc).cpu() # K x nwords 171 | # disallow unks 172 | wrd_dist[:, unk_idx].zero_() 173 | if not final_state: 174 | wrd_dist[:, eos_idx].zero_() 175 | #if not ss == 25 or not ell == 3: 176 | net.collapse_word_probs(row2tblent, wrd_dist) 177 | wrd_dist.log_() 178 | if ell > 0: # add previous scores 179 | wrd_dist.add_(curr_scores.expand_as(wrd_dist)) 180 | maxprobs, top2k = torch.topk(wrd_dist.view(-1), 2*K) 181 | cols = wrd_dist.size(1) 182 | # we'll break as soon as is at the top of the beam. 
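        # Taking the top 2*K expansions leaves room to finalize hypotheses that end
        # in the end-of-phrase token (scored per word, with the segment-length
        # log-prob added) while still filling K live hypotheses for the next step;
        # if the end-of-sequence token itself tops the beam, the current best
        # hypothesis is returned right away.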
183 | # this ignores but whatever 184 | if top2k[0] == eos_idx: 185 | final_hyp = backtrace(curr_hyps[0]) 186 | final_hyp.append(eos_idx) 187 | return final_hyp, maxprobs[0], len_lps[ss][ell] 188 | 189 | new_hyps, anc_hs, anc_cs = [], [], [] 190 | #inps.data.fill_(pad_idx) 191 | inps.data[:, 1].fill_(w2i[""]) 192 | inps.data[:, 2].fill_(w2i[""]) 193 | inps.data[:, 3].fill_(w2i[""]) 194 | for k in xrange(2*K): 195 | anc, wrd = top2k[k] / cols, top2k[k] % cols 196 | # check if any of the maxes are eop 197 | if wrd == net.eop_idx and ell > 0: 198 | # add len score (and avg over num words incl eop i guess) 199 | wlenscore = maxprobs[k]/(ell+1) + len_lps[ss][ell-1] 200 | if wlenscore > best_hyp_score: 201 | best_hyp_score = wlenscore 202 | best_hyp = backtrace(curr_hyps[anc]) 203 | best_wscore, best_lscore = maxprobs[k], len_lps[ss][ell-1] 204 | else: 205 | curr_scores[len(new_hyps)][0] = maxprobs[k] 206 | if wrd >= net.decoder.out_features: # a copy 207 | tblidx = wrd - net.decoder.out_features 208 | inps.data[len(new_hyps)].copy_(row2feats[tblidx]) 209 | else: 210 | inps.data[len(new_hyps)][0] = wrd if i2w[wrd] in genset else unk_idx 211 | new_hyps.append((wrd, curr_hyps[anc])) 212 | anc_hs.append(hc.narrow(1, anc, 1)) # layers x 1 x rnn_size 213 | anc_cs.append(cc.narrow(1, anc, 1)) # layers x 1 x rnn_size 214 | if len(new_hyps) == K: 215 | break 216 | assert len(new_hyps) == K 217 | curr_hyps = new_hyps 218 | if net.lut.weight.data.is_cuda: 219 | inps = inps.cuda() 220 | embs = net.lut(inps).view(1, K, -1) # 1 x K x nfeats*emb_size 221 | if net.mlpinp: 222 | embs = net.inpmlp(embs) # 1 x K x rnninsz 223 | if net.one_rnn: 224 | cond_embs = torch.cat([embs, net.state_embs[rul_ss].expand(1, K, state_emb_sz)], 2) 225 | hid, (hc, cc) = net.seg_rnns[0](cond_embs, (torch.cat(anc_hs, 1), torch.cat(anc_cs, 1))) 226 | else: 227 | hid, (hc, cc) = net.seg_rnns[rul_ss](embs, (torch.cat(anc_hs, 1), torch.cat(anc_cs, 1))) 228 | # hypotheses of length L still need their end probs added 229 | # N.B. if the falls off the beam we could end up with situations 230 | # where we take an L-length phrase w/ a lower score than 1-word followed by eos. 231 | wrd_dist = net.get_next_word_dist(hid, rul_ss, srcfieldenc).cpu() # K x nwords 232 | #wrd_dist = net.get_next_word_dist(hid, ss, srcfieldenc).cpu() # K x nwords 233 | wrd_dist.log_() 234 | wrd_dist.add_(curr_scores.expand_as(wrd_dist)) 235 | for k in xrange(K): 236 | wlenscore = wrd_dist[k][net.eop_idx]/(net.L+1) + len_lps[ss][net.L-1] 237 | if wlenscore > best_hyp_score: 238 | best_hyp_score = wlenscore 239 | best_hyp = backtrace(curr_hyps[k]) 240 | best_wscore, best_lscore = wrd_dist[k][net.eop_idx], len_lps[ss][net.L-1] 241 | #if ss == 80: 242 | # print "going with", best_hyp 243 | return best_hyp, best_wscore, best_lscore 244 | 245 | 246 | def calc_pur(counters): 247 | purs, purs2 = [], [] 248 | for counter in counters: 249 | if len(counter) > 0: 250 | vals = counter.values() 251 | if len(vals) > 0: 252 | nonothers = [val for k, val in counter.items() if k != "other"] 253 | oval = counter["other"] if "other" in counter else 0 254 | if len(nonothers) > 0: 255 | total = float(sum(nonothers)) 256 | maxval = max(nonothers) 257 | if oval < total: 258 | purs.append(maxval/total) 259 | purs2.append(maxval/(total+oval)) 260 | purs, purs2 = torch.Tensor(purs), torch.Tensor(purs2) 261 | print purs.mean(), purs.std() 262 | print purs2.mean(), purs2.std() 263 | --------------------------------------------------------------------------------
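As a quick illustration of how the `-label_train` output feeds template extraction, here is a minimal sketch (not from the repository; the sentence and state ids are made up) of how `template_extraction.group_by_template` parses one segmentation line into a phrase list and a template:

```
import re

# same segment pattern as template_extraction.py
seg_patt = re.compile('([^\|]+)\|(\d+)')

# one line of -label_train output: space-separated "phrase|state" chunks
line = "the mill|22 is a|14 coffee shop|9 near the sorrento|35 .|43"

pairs = seg_patt.findall(line)               # [('the mill', '22'), (' is a', '14'), ...]
phrases = [phrs.strip() for phrs, _ in pairs]
template = tuple(int(labe) for _, labe in pairs)

print phrases    # ['the mill', 'is a', 'coffee shop', 'near the sorrento', '.']
print template   # (22, 14, 9, 35, 43), i.e. one key of the temps2sents map
```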