├── models ├── __init__.py ├── .DS_Store ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── attention.cpython-37.pyc │ └── feedback.cpython-37.pyc ├── attention.py ├── feedback.py └── variational_template_machine.py ├── data_utils ├── __init__.py ├── label_spnlg.py └── label_wiki.py ├── .gitignore ├── requirements.txt ├── README.md ├── data └── README.md ├── generate.py └── train.py /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/SPNLG 2 | data/Wiki -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | nltk -------------------------------------------------------------------------------- /models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReneeYe/VariationalTemplateMachine/HEAD/models/.DS_Store -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReneeYe/VariationalTemplateMachine/HEAD/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/attention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReneeYe/VariationalTemplateMachine/HEAD/models/__pycache__/attention.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/feedback.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ReneeYe/VariationalTemplateMachine/HEAD/models/__pycache__/feedback.cpython-37.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Variational Template Machine 2 | 3 | Code for [Variational Template Machine for Data-to-text generation](https://openreview.net/forum?id=HkejNgBtPB) (Ye et al., ICLR 2020). 
4 | 5 | ## Requirements 6 | All dependencies (Python 3) can be installed via: 7 | ``` 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | ## Running 12 | - **Datasets**: 13 | 14 | Two datasets (SPNLG and Wiki) can be downloaded from: https://drive.google.com/drive/folders/1FsNlFh2aUbuBl45zEjgvAXDkp_e4hQmV?usp=sharing 15 | 16 | Detailed info: [HERE](./data/README.md) 17 | 18 | - **Training**: 19 | 20 | ``` 21 | DATASET_PATH= 22 | MODEL_PATH= 23 | 24 | python train.py -data ${DATASET_PATH} -max_vocab_cnt 50000 -emb_size 786 -hid_size 512 -table_hid_size 256 -pool_type max -sent_represent last_hid -z_latent_size 128 -c_latent_size 256 -dec_attention -drop_emb -add_preserving_content_loss -pc_weight 1.0 -add_preserving_template_loss -pt_weight 1.0 -anneal_function_z const -anneal_k_z 0.8 -anneal_function_c const -anneal_k_c 0.8 -add_mi_z -mi_z_weight 0.5 -add_mi_c -mi_c_weight 0.5 -lr 0.001 -clip 5.0 -cuda -log_interval 500 -bsz 16 -paired_epochs 5 -raw_epochs 2 -epochs 20 -save ${MODEL_PATH} 25 | ``` 26 | Note that these arguments may need to be adjusted for each dataset. 27 | - **Generation**: 28 | ``` 29 | OUTPUT_PATH= 30 | python generate.py -data ${DATASET_PATH} -max_vocab_cnt 50000 -load ${MODEL_PATH} -various_gen 5 -mask_prob 0.0 -cuda -decode_method temp_sample -sample_temperature 0.2 -gen_to_fi ${OUTPUT_PATH} 31 | ``` -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ## Dataset 2 | Two datasets (SPNLG and Wiki) can be downloaded from https://drive.google.com/drive/folders/1FsNlFh2aUbuBl45zEjgvAXDkp_e4hQmV?usp=sharing 3 | 4 | 5 | ## Statistics 6 |
| Dataset | Train (paired) | Train (raw) | Valid (paired) | Valid (raw) | Test (paired) |
| ------- | -------------- | ----------- | -------------- | ----------- | ------------- |
| SPNLG   | 14k            | 150k        | 21k            | /           | 21k           |
| Wiki    | 84k            | 842k        | 73k            | 43k         | 73k           |
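For reference, the data loaders in `data_utils/label_spnlg.py`, `data_utils/label_wiki.py`, and `generate.py` read the following files from the directory passed via `-data` (file names are taken from the code; the exact layout inside the downloaded archive may differ slightly):

```
${DATASET_PATH}/
├── pair_src.train   # tables of the paired training data
├── pair_tgt.train   # sentences of the paired training data
├── raw_tgt.train    # raw training sentences (no tables)
├── pair_src.valid
├── pair_tgt.valid
├── raw_tgt.valid    # only loaded for Wiki; SPNLG has no raw valid split
└── src_test.txt     # test tables used by generate.py
```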
38 | 39 | 40 | ## How the datasets were constructed 41 | - **SPNLG** 42 | - The dataset comes from the [sentence-planning-NLG dataset](https://nlds.soe.ucsc.edu/sentence-planning-NLG), which describes restaurant information and is distributed as 3 CSV files. 43 | - We aggregate all 3 CSV files and split them into `train:valid:test=8:1:1`, with `paired:raw=1:10` within the train set. 44 | 45 | - **Wiki** 46 | - The dataset is constructed from both the [*Wiki-Bio* Dataset](https://github.com/DavidGrangier/wikipedia-biography-dataset) and the [*Wikipedia Person and Animal* Dataset](https://drive.google.com/file/d/1TzcNdjZ0EsLh_rC1pBC7dU70jINcsVJd/view). 47 | - We use the same valid and test sets as *Wiki-Bio*. 48 | - For the train set, we randomly select 84k samples from *Wiki-Bio*-train as paired data. The remaining sentences in *Wiki-Bio*-train, together with the person descriptions from *Wikipedia Person and Animal*, are used as raw data (842k in total). 49 | 50 | ## Related links: 51 | - Sentence planning NLG dataset: https://nlds.soe.ucsc.edu/sentence-planning-NLG 52 | - Wikipedia biography dataset (Wiki-Bio): https://github.com/DavidGrangier/wikipedia-biography-dataset 53 | - Wikipedia Person and Animal Dataset: https://drive.google.com/file/d/1TzcNdjZ0EsLh_rC1pBC7dU70jINcsVJd/view 54 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class AttentionBase(nn.Module): 7 | def __init__(self, attn_type, query_dim, key_dim): 8 | super(AttentionBase, self).__init__() 9 | self.attn_type = attn_type 10 | self.query_dim = query_dim 11 | self.key_dim = key_dim 12 | 13 | def _score(self, query, keys): 14 | raise NotImplementedError 15 | 16 | def forward(self, query, keys, values, mask=None, return_logits=False): 17 | # query: target_len x b x query_dim 18 | # keys: b x src_len x key_dim 19 | # values: b x src_len x value_dim 20 | # mask: b x src_len, mask = 0/-inf 21 | 22 | weight = self._score(query, keys) # b x tgt_len x src_len 23 | if mask is not None: 24 | weight = weight + mask.unsqueeze(1).expand_as(weight) 25 | score = F.softmax(weight, dim=-1) # b x tgt_len x src_len 26 | ctx = torch.bmm(score, values) # b x tgt_len x value_dim 27 | if return_logits: 28 | return score, ctx, weight 29 | else: 30 | return score, ctx 31 | 32 | class DotAttention(AttentionBase): 33 | def __init__(self, dim): 34 | super(DotAttention, self).__init__("dot", dim, dim) 35 | # no other parameters for DotAttention 36 | 37 | def _score(self, query, keys): 38 | # query: tgt_len x b x query_dim 39 | # keys: b x src_len x key_dim 40 | qdim, kdim = query.size(-1), keys.size(-1) 41 | # check size and dimension 42 | assert qdim == kdim 43 | assert query.dim() == 3 44 | assert keys.dim() == 3 45 | 46 | return torch.bmm(query.transpose(0, 1), keys.transpose(1, 2)) 47 | 48 | class GeneralAttention(AttentionBase): 49 | def __init__(self, query_dim, key_dim): 50 | super(GeneralAttention, self).__init__("general", query_dim, key_dim) 51 | self.W = nn.Linear(query_dim, key_dim, bias=False) 52 | 53 | def _score(self, query, keys): 54 | qdim, kdim = query.size(-1), keys.size(-1) 55 | # check size and dimension 56 | assert qdim == self.query_dim 57 | assert kdim == self.key_dim 58 | assert query.dim() == 3 59 | assert keys.dim() == 3 60 | 61 | return torch.bmm(self.W(query).transpose(0, 1), keys.transpose(1, 2)) 62 | 63 | 64 | class ConcatAttention(AttentionBase):
65 | def __init__(self, query_dim, key_dim): 66 | super(ConcatAttention, self).__init__("concat", query_dim, key_dim) 67 | self.Wa = nn.Linear(query_dim + key_dim, key_dim, bias=False) 68 | self.va = nn.Linear(key_dim, 1, bias=False) 69 | 70 | def _score(self, query, keys): 71 | qdim, kdim = query.size(-1), keys.size(-1) 72 | # check size and dimension 73 | assert qdim == self.query_dim 74 | assert kdim == self.key_dim 75 | assert query.dim() == 3 76 | assert keys.dim() == 3 77 | 78 | # query: target_len x batch_size x query_dim 79 | # keys: batch_size x src_len x key_dim 80 | 81 | tgt_len, src_len = query.size(0), keys.size(1) 82 | 83 | query = query.transpose(0, 1).unsqueeze(2).repeat(1, 1, src_len, 1) 84 | keys = keys.unsqueeze(1).repeat(1, tgt_len, 1, 1) 85 | 86 | return self.va(F.tanh(self.Wa(torch.cat([query, keys], dim=-1)))).squeeze(-1) # batch_size x tar_len x src_len x 1 87 | 88 | 89 | class CopyAttention(nn.Module): 90 | def __init__(self, attn_type, query_dim, key_dim): 91 | super(CopyAttention, self).__init__() 92 | if attn_type == "dot": 93 | assert query_dim == key_dim 94 | self.attention = DotAttention(query_dim) 95 | elif attn_type == "general": 96 | self.attention = GeneralAttention(query_dim, key_dim) 97 | elif attn_type == "concat": 98 | self.attention = ConcatAttention(query_dim, key_dim) 99 | else: 100 | raise NotImplementedError 101 | 102 | def forward(self, query, keys, values, mask=None): 103 | # query: target_len x batch_size x query_dim 104 | # keys: batch_size x src_len x key_dim 105 | # values: batch_size x src_len x value_dim 106 | # mask: batch_size x src_len, mask = 0/-inf 107 | 108 | weight = self.attention._score(query, keys) # batch_size x target_len x src_len 109 | if mask is not None: 110 | weight = weight + mask.unsqueeze(1).expand_as(weight) 111 | return weight 112 | 113 | 114 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import numpy as np 4 | import torch 5 | from data_utils import label_wiki, label_spnlg 6 | from models import variational_template_machine 7 | import os 8 | 9 | def make_masks(src, pad_idx, max_pool=False): 10 | """ 11 | src - bsz x nfields x nfeats(3) 12 | """ 13 | neginf = -1e38 14 | bsz, nfields, nfeats = src.size() 15 | fieldmask = (src.eq(pad_idx).sum(2) == nfeats) # binary bsz x nfields tensor 16 | avgmask = (1 - fieldmask).float() # 1s where not padding 17 | if not max_pool: 18 | avgmask.div_(avgmask.sum(1, True).expand(bsz, nfields)) 19 | fieldmask = fieldmask.float() * neginf # 0 where not all pad and -1e38 elsewhere 20 | return fieldmask, avgmask 21 | 22 | def random_mask(src, pad_idx, prob=0.4, seed=1): 23 | # src: b x nfield x 3(key, pos, wrd) 24 | neginf = -1e38 25 | bsz, nfield, _ = src.size() 26 | fieldmask = (src.eq(pad_idx).sum(2) == 3) # b x nfield, 0 for has, 1 for pad 27 | mask_matrix = torch.rand(bsz, nfield, generator=torch.manual_seed(seed)) 28 | mask_matrix = torch.max((mask_matrix < prob), fieldmask) # 1 for pad, and 0 for not pad 29 | mask_matrix = mask_matrix.float() * neginf # 0 for not pad, -inf for pad 30 | return mask_matrix 31 | 32 | 33 | parser = argparse.ArgumentParser(description='') 34 | # basic data setups 35 | parser.add_argument('-data', type=str, default='', help='path to data dir') 36 | parser.add_argument('-bsz', type=int, default=16, help='batch size') 37 | parser.add_argument('-seed', type=int, default=1111, help='set 
random seed, ' 38 | 'when training, it is to shuffle training batch, ' 39 | 'when testing, it is to define the latent samples') 40 | parser.add_argument('-cuda', action='store_true', help='use CUDA') 41 | parser.add_argument('-max_vocab_cnt', type=int, default=50000) 42 | parser.add_argument('-max_seqlen', type=int, default=70, help='') 43 | 44 | # model saves 45 | parser.add_argument('-load', type=str, default='', help='path to saved model') 46 | 47 | # for generation and test 48 | parser.add_argument('-gen_to_fi', type=str, default=None, help='generate to which file') 49 | parser.add_argument('-various_gen', type=int, default=1, help='define generation how many sentence, and the result is saved in gen_to_fi') 50 | parser.add_argument('-mask_prob', type=float, default=0.0, help='mask item at prob') 51 | 52 | # decode method 53 | parser.add_argument('-decode_method', type=str, default='beam_search', help="beam_seach/temp_sample/topk_sample/nucleus_sample") 54 | parser.add_argument('-beamsz', type=int, default=1, help='beam size') 55 | parser.add_argument('-sample_temperature', type=float, default=1.0, help='set sample_temperature for decode_method=temp_sample') 56 | parser.add_argument('-topk', type=int, default=5, help='for topk_sample, if topk=1, it is greedy') 57 | parser.add_argument('-topp', type=float, default=1.0, help='for nucleus(top-p) sampleing, if topp=1, then its fwd_sample') 58 | 59 | 60 | if __name__ == "__main__": 61 | args = parser.parse_args() 62 | print(args) 63 | torch.manual_seed(args.seed) 64 | 65 | if torch.cuda.is_available(): 66 | if not args.cuda: 67 | print("WARNING: You have a CUDA device, so you should probably run with -cuda") 68 | sys.stdout.flush() 69 | else: 70 | torch.cuda.manual_seed_all(args.seed) 71 | else: 72 | if args.cuda: 73 | print("No CUDA device.") 74 | args.cuda = False 75 | 76 | # data loader 77 | if 'wiki' in args.data.lower(): 78 | corpus = label_wiki.Corpus(args.data, args.bsz, max_count=args.max_vocab_cnt, 79 | add_bos=False, add_eos=False) 80 | elif 'spnlg' in args.data.lower(): 81 | corpus = label_spnlg.Corpus(args.data, args.bsz, max_count=args.max_vocab_cnt, 82 | add_bos=False, add_eos=True) 83 | else: 84 | raise NotImplementedError 85 | print("data loaded!") 86 | print("Vocabulary size:", len(corpus.dictionary)) 87 | args.pad_idx = corpus.dictionary.word2idx[''] 88 | 89 | # load model 90 | if len(args.load) > 0: 91 | print("load model ...") 92 | saved_stuff = torch.load(args.load) 93 | saved_args, saved_state = saved_stuff["opt"], saved_stuff["state_dict"] 94 | for k, v in args.__dict__.items(): 95 | if k not in saved_args.__dict__: 96 | saved_args.__dict__[k] = v 97 | if k in ["decode_method", "beamsz", "sample_temperature", "topk", "topp"]: 98 | saved_args.__dict__[k] = v 99 | net = variational_template_machine.VariationalTemplateMachine(corpus, saved_args) 100 | net.load_state_dict(saved_state, strict=False) 101 | del saved_args, saved_state, saved_stuff 102 | else: 103 | print("WARNING: No model load! 
Random init.") 104 | net = variational_template_machine.VariationalTemplateMachine(corpus, args) 105 | if args.cuda: 106 | net = net.cuda() 107 | 108 | def generation(test_out, num=3): 109 | output_fn = open(test_out, 'w') 110 | # read source table 111 | table_path = os.path.join(args.data, "src_test.txt") 112 | paired_src_feat_tst, origin_src_tst, lineno_tst = corpus.get_test_data(table_path) 113 | for i in range(len(paired_src_feat_tst)): 114 | paired_src_feat = paired_src_feat_tst[i] 115 | for j in range(num): 116 | if j == 0: 117 | paired_mask, _ = make_masks(paired_src_feat, args.pad_idx) 118 | else: # you may set args.mask_prob=0 119 | paired_mask = random_mask(paired_src_feat.cpu(), args.pad_idx, prob=args.mask_prob, 120 | seed=np.random.randint(5000)) 121 | if args.cuda: 122 | paired_src_feat, paired_mask = paired_src_feat.cuda(), paired_mask.cuda() 123 | if args.decode_method != "beam_search": 124 | sentence_ids = net.predict(paired_src_feat, paired_mask) 125 | else: 126 | sentence_ids = net.predict(paired_src_feat, paired_mask, beam_size=j+1) 127 | 128 | sentence_ids = sentence_ids.data.cpu() 129 | sent_words = [] 130 | for t, wid in enumerate(sentence_ids[:, 0]): 131 | word = corpus.dictionary.idx2word[wid] 132 | if word != '': 133 | sent_words.append(word) 134 | else: 135 | break 136 | output_fn.write(" ".join(str(w) for w in sent_words) + '\n') 137 | output_fn.write("\n") 138 | output_fn.close() 139 | 140 | net.eval() 141 | generation(args.gen_to_fi, num=args.various_gen) 142 | -------------------------------------------------------------------------------- /models/feedback.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FeedBackBase(nn.Module): 6 | def __init__(self, type, lookup): 7 | super(FeedBackBase, self).__init__() 8 | self.type = type 9 | self.lookup = lookup 10 | 11 | def prepare(self, decoder_out): 12 | raise NotImplementedError 13 | 14 | def collect(self, from_torch=None): 15 | raise NotImplementedError 16 | 17 | 18 | class GreedyFeedBack(FeedBackBase): 19 | def __init__(self, lookup, unk_idx=-1): 20 | super(GreedyFeedBack, self).__init__(type="greedy", lookup=lookup) 21 | self.ids = [] 22 | self.unk_idx = unk_idx 23 | 24 | def prepare(self, decoder_out): 25 | # decoder_out : 1 x b x vocab 26 | word_prob = F.softmax(decoder_out, dim=-1) 27 | word_prob[:, :, self.unk_idx].fill_(0) # disvow unk 28 | return word_prob 29 | 30 | def forward(self, decoder_out, keep_ids=True): 31 | # decoder_out: 1 x b x vocab 32 | log_p = self.prepare(decoder_out) 33 | # log_p: [1, batch_size, vocab_size] 34 | max_id = torch.argmax(log_p[0], dim=-1).unsqueeze(0) # [1 x batch_size] 35 | if keep_ids: 36 | self.ids.append(max_id) 37 | return max_id 38 | 39 | def clear_ids(self): 40 | self.ids = [] 41 | 42 | def collect(self, from_torch=None): 43 | # from_torch: 44 | if from_torch is None: 45 | ret = torch.cat(self.ids, dim=0) # [seq_len, batch_size, kmul] 46 | self.ids = [] 47 | return ret 48 | else: 49 | return torch.argmax(from_torch, dim=-1) 50 | 51 | 52 | class SampleFeedBack(FeedBackBase): 53 | def __init__(self, lookup, unk_idx=-1): 54 | super(SampleFeedBack, self).__init__(type="sample", lookup=lookup) 55 | self.ids = [] 56 | self.unk_idx = unk_idx 57 | 58 | def prepare(self, decoder_out): 59 | # decoder_out : 1 x b x vocab 60 | word_prob = F.softmax(decoder_out, dim=-1) 61 | word_prob[:, :, self.unk_idx].fill_(0) # disenable unk 62 | return word_prob 63 | 
64 | def forward(self, decoder_out): 65 | log_p = self.prepare(decoder_out) 66 | # log_p: [1, batch_size*beam_size, vocab_size] 67 | sample_idx = torch.multinomial(log_p[0], 1) # b x 1 68 | self.ids.append(sample_idx.transpose(0, 1)) # append 1 x b 69 | return sample_idx 70 | # return self.lookup(sample_idx) 71 | 72 | def clear_ids(self): 73 | self.ids = [] 74 | 75 | def collect(self): 76 | ret = torch.cat(self.ids, dim=0) # seq x b 77 | self.ids = [] 78 | return ret 79 | 80 | class SampleFeedBackWithTemperature(FeedBackBase): 81 | def __init__(self, lookup, unk_idx=-1, temperature=1.0): 82 | super(SampleFeedBackWithTemperature, self).__init__(type="temperature_sample", lookup=lookup) 83 | self.ids = [] 84 | self.unk_idx = unk_idx 85 | self.temperature = temperature 86 | 87 | def prepare(self, decoder_out): 88 | word_prob = F.softmax(decoder_out/self.temperature, dim=-1) 89 | word_prob[:, :, self.unk_idx].fill_(0) 90 | return word_prob 91 | 92 | def forward(self, decoder_out): 93 | log_p = self.prepare(decoder_out) 94 | # log_p: [1, batch_size, vocab_size] 95 | sample_idx = torch.multinomial(log_p[0], 1) # b x 1 96 | self.ids.append(sample_idx.transpose(0, 1)) # append 1 x b 97 | return sample_idx 98 | 99 | def clear_ids(self): 100 | self.ids = [] 101 | 102 | def collect(self): 103 | ret = torch.cat(self.ids, dim=0) 104 | self.ids = [] 105 | return ret 106 | 107 | class TopkSampleFeedBack(FeedBackBase): 108 | def __init__(self, lookup, unk_id=-1, topk=1): 109 | super(TopkSampleFeedBack, self).__init__(type="topk_sample", lookup=lookup) 110 | self.ids = [] 111 | self.unk_idx = unk_id 112 | self.topk = topk 113 | 114 | def prepare(self, decoder_out): 115 | # 1 x b x vocab 116 | idx_to_remove = decoder_out < torch.topk(decoder_out, self.topk)[0][..., -1, None] 117 | decoder_out[idx_to_remove] = -float("inf") 118 | word_prob = F.softmax(decoder_out, dim=-1) 119 | word_prob[:, :, self.unk_idx].fill_(0) 120 | return word_prob 121 | 122 | def forward(self, decoder_out): 123 | self.topk = min(self.topk, decoder_out.size(-1)) # safety check 124 | log_p = self.prepare(decoder_out) # 1 x b x vocab 125 | sample_idx = torch.multinomial(log_p[0], 1) # b x 1 126 | self.ids.append(sample_idx.transpose(0, 1)) # append 1 x b 127 | return sample_idx 128 | 129 | def clear_ids(self): 130 | self.ids = [] 131 | 132 | def collect(self): 133 | ret = torch.cat(self.ids, dim=0) 134 | self.ids = [] 135 | return ret 136 | 137 | 138 | class NucleusSampleFeedBack(FeedBackBase): 139 | def __init__(self, lookup, unk_id=-1, topp=1.0): 140 | super(NucleusSampleFeedBack, self).__init__(type="nucleus_sample", lookup=lookup) 141 | self.ids = [] 142 | self.unk_idx = unk_id 143 | self.topp = topp 144 | 145 | def prepare(self, decoder_out): 146 | sorted_value, sorted_idx = torch.sort(decoder_out, descending=True) 147 | cumulated_p = torch.cumsum(F.softmax(sorted_value, dim=-1), dim=-1) 148 | sort_idx_to_remove = cumulated_p > self.topp 149 | # assure there must be one element in each batch 150 | sort_idx_to_remove[..., 1:] = sort_idx_to_remove[..., :-1].clone() 151 | sort_idx_to_remove[..., 0] = 0 152 | id_to_remove = torch.gather(sort_idx_to_remove, -1, sorted_idx) 153 | decoder_out[id_to_remove] = -float("inf") 154 | word_prob = F.softmax(decoder_out, dim=-1) 155 | word_prob[:, :, self.unk_idx].fill_(0) 156 | return word_prob 157 | 158 | def forward(self, decoder_out): 159 | log_p = self.prepare(decoder_out) # 1 x b x vocab 160 | sample_idx = torch.multinomial(log_p[0], 1) 161 | self.ids.append(sample_idx.transpose(0, 1)) 162 | 
return sample_idx 163 | 164 | def clear_ids(self): 165 | self.ids = [] 166 | 167 | def collect(self): 168 | ret = torch.cat(self.ids, dim=0) 169 | self.ids = [] 170 | return ret 171 | 172 | class BeamFeedBack(FeedBackBase): 173 | """ 174 | a helper class for inferenece with beam search 175 | """ 176 | def __init__(self, lookup, beam_size, unk_idx=-1): 177 | super(BeamFeedBack, self).__init__(type="beam", lookup=lookup) 178 | self.beam_size = beam_size 179 | self.output_size = lookup.num_embeddings 180 | self.unk_idx = unk_idx 181 | self.back_pointers = [] 182 | self.symbols = [] 183 | 184 | def repeat(self, v): 185 | # v: batch_size x ? 186 | return v.repeat(1, self.beam_size, 1) # v: 1x bx voc 187 | # return v.unsqueeze(1).repeat(1, self.beam_size, 1) # 1 x b*beam x ? 188 | 189 | def forward(self, past_p, cur_p, batch_size, step, keep_ids=True): 190 | # cur_p: [batch*beam, vocab_size] 191 | # past_p: [batch*beam, 1] 192 | if step == 0: 193 | score = cur_p.view(batch_size, -1)[:, 0:self.output_size] 194 | else: 195 | score = (cur_p + past_p).view(batch_size, -1) 196 | top_v, top_id = score.topk(self.beam_size, dim=1) 197 | 198 | back_ptr = top_id.div(self.output_size) # which beam 199 | symbols = top_id.fmod(self.output_size) # which word 200 | past_p = top_v.view(-1, 1) 201 | if keep_ids: 202 | self.back_pointers.append(back_ptr.view(-1, 1)) 203 | self.symbols.append(symbols.view(-1, 1)) 204 | 205 | return past_p, symbols 206 | 207 | def clear_ids(self): 208 | self.symbols = [] 209 | self.back_pointers = [] 210 | 211 | def collect(self, past_p, batch_size): 212 | # past_p: b*beam x 1 213 | final_seq_symbols = [] 214 | cum_sum = past_p.view(-1, self.beam_size) # b x beam 215 | 216 | max_seq_ids = cum_sum.topk(self.beam_size)[1] # batch_size x beam_size # .data.cpu().view(-1).numpy() 217 | 218 | rev_seq_symbols = self.symbols[::-1] 219 | rev_back_ptrs = self.back_pointers[::-1] 220 | 221 | for symbols, back_ptrs in zip(rev_seq_symbols, rev_back_ptrs): 222 | symbol2ds = symbols.view(-1, self.beam_size) 223 | back2ds = back_ptrs.view(-1, self.beam_size) 224 | 225 | selected_symbols = [] 226 | selected_parents = [] 227 | for b_id in range(batch_size): 228 | selected_parents.append(back2ds[b_id, max_seq_ids[b_id]]) 229 | selected_symbols.append(symbol2ds[b_id, max_seq_ids[b_id]]) 230 | # print(back2ds[b_id, max_seq_ids[b_id]]) 231 | # print(symbol2ds[b_id, max_seq_ids[b_id]]) 232 | 233 | final_seq_symbols.append(torch.stack(selected_symbols).unsqueeze(1)) 234 | max_seq_ids = torch.stack(selected_parents)# .data.cpu().numpy() 235 | 236 | sequence_symbols = torch.cat(final_seq_symbols[::-1], dim=1) # batch_size x seq_len x beam_size 237 | sequence_symbols = sequence_symbols[:, :, 0].transpose(0, 1) 238 | return sequence_symbols 239 | 240 | 241 | if __name__ == "__main__": 242 | vocab = 100 243 | seq = 5 244 | emb = 30 245 | pad_idx = 1 246 | batch = 4 247 | beam_size = 3 248 | lut = nn.Embedding(vocab, emb, padding_idx=pad_idx) 249 | print(lut.num_embeddings) 250 | # print("use greedy") 251 | feedback = GreedyFeedBack(lut, unk_idx=0) 252 | beam_fb = BeamFeedBack(lut, beam_size) 253 | past_p = torch.zeros(batch * beam_size, 1) 254 | 255 | for t in range(seq): 256 | word_dis = torch.randn(1, batch, vocab) 257 | # max_ids = feedback(word_dis.squeeze(0)) 258 | max_ids = feedback(word_dis)[0][0].item() 259 | print("greedy - max_ids", max_ids) 260 | 261 | # print("greedy - max_ids", max_ids.tolist()) 262 | cur_p = beam_fb.repeat(word_dis).squeeze(0) 263 | # print("cur_p size", cur_p.size()) 264 | 
past_p, symbol = beam_fb(past_p, cur_p, batch, t) 265 | print(symbol[0][0].item()) 266 | print("beam - most possible symbol:", symbol[:, 0].tolist()) 267 | # print(beam_fb.back_pointers) 268 | 269 | greedy_ids = feedback.collect() # seq x b 270 | # print(sentences_ids) 271 | print(greedy_ids.size()) 272 | 273 | beam_ids = beam_fb.collect(past_p, batch) 274 | print(beam_ids) 275 | -------------------------------------------------------------------------------- /data_utils/label_spnlg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import Counter 4 | import numpy as np 5 | import re 6 | from nltk.tokenize import word_tokenize 7 | 8 | 9 | def get_spnlg_field(tableStr): 10 | fdict = {} # (field, pos) -> value 11 | all_items = re.findall(r',?(.*?)[[](.*?)[]]', tableStr) # [('name', 'nameVariable'), (' food', 'Chinese food')] 12 | for item in all_items: # eatType[pub] 13 | field, field_value = item 14 | field = field.strip() 15 | values = word_tokenize(field_value) 16 | for i, v in enumerate(values): 17 | fdict[(field, i)] = v 18 | return fdict 19 | 20 | 21 | class Dictionary(object): 22 | def __init__(self, unk_word=""): 23 | self.unk_word = unk_word 24 | self.idx2word = [unk_word, "", "", "", ""] # OpenNMT constants 25 | self.word2idx = {word: i for i, word in enumerate(self.idx2word)} 26 | 27 | def add_word(self, word, train=False): 28 | """add extra word, returns idx of word 29 | :param word: a str 30 | :param train: bool, if true, then update self.idx2word and w2i; if false, just update w2i 31 | """ 32 | if train and word not in self.word2idx: 33 | self.idx2word.append(word) 34 | self.word2idx[word] = len(self.idx2word) - 1 35 | return self.word2idx[word] if word in self.word2idx else self.word2idx[self.unk_word] 36 | 37 | def bulk_add(self, words): 38 | """add lots of words, assumes train=True 39 | :param words: a list of words 40 | """ 41 | self.idx2word.extend(words) 42 | self.word2idx = {word: i for i, word in enumerate(self.idx2word)} 43 | 44 | def __len__(self): 45 | return len(self.idx2word) 46 | 47 | 48 | class Corpus(object): 49 | def __init__(self, path, bsz, max_count=50000, add_bos=False, add_eos=False): 50 | self.dictionary = Dictionary() 51 | self.value_dict = Dictionary() 52 | assert 'spnlg' in path.lower() 53 | self.path = path 54 | pair_train_src = os.path.join(path, "pair_src.train") 55 | pair_valid_src = os.path.join(path, "pair_src.valid") 56 | pair_train_tgt = os.path.join(path, "pair_tgt.train") 57 | pair_valid_tgt = os.path.join(path, "pair_tgt.valid") 58 | raw_train_text = os.path.join(path, "raw_tgt.train") 59 | raw_valid_text = os.path.join(path, "raw_tgt.valid") 60 | 61 | self.gen_vocab = Dictionary() 62 | self.make_vocab(pair_train_tgt, pair_train_src, raw_train_text, max_count=max_count) 63 | self.genset.add("") 64 | # self.value_dict.add_word("", train=True) 65 | 66 | # load training data 67 | pair_sents_train, pair_sk_sents_train, \ 68 | pair_src_feats_train = self.load_paired_data(pair_train_src, pair_train_tgt, 69 | add_to_dict=True, add_bos=add_bos, add_eos=add_eos) 70 | raw_sents_train = self.load_raw_data(raw_train_text) 71 | 72 | self.paired_train, _ = self._minibatchify_pair(pair_sents_train, pair_sk_sents_train, pair_src_feats_train, bsz) 73 | self.raw_train, _ = self._minibatchify_raw(raw_sents_train, bsz) 74 | del pair_sents_train, pair_sk_sents_train, pair_src_feats_train, raw_sents_train 75 | 76 | # load valid data 77 | pair_sents_valid, pair_sk_sents_valid, \ 
78 | pair_src_feats_valid = self.load_paired_data(pair_valid_src, pair_valid_tgt, 79 | add_to_dict=False, add_bos=add_bos, add_eos=add_eos) 80 | 81 | self.paired_valid, self.paired_lineno_valid = self._minibatchify_pair(pair_sents_valid, pair_sk_sents_valid, 82 | pair_src_feats_valid, bsz) 83 | self.raw_valid = None 84 | del pair_sents_valid, pair_sk_sents_valid, pair_src_feats_valid 85 | 86 | def make_vocab(self, pair_tgt, pair_src, raw_tgt, max_count=50000): 87 | self.word_cnt = Counter() 88 | genwords, value_vocab = self.get_vocab_from_paired(pair_tgt, pair_src) 89 | raw_vocab = self.get_vocab_from_raw(raw_tgt) # just to update self.word_cnt 90 | self.genset = set(genwords.keys()) 91 | tgtkeys = list(self.word_cnt.keys()) 92 | # make sure gen stuff is first 93 | tgtkeys.sort(key=lambda x: -(x in self.genset)) # in genset first 94 | voc = tgtkeys[:max_count] 95 | self.dictionary.bulk_add(voc) 96 | self.value_dict.bulk_add(list([i for i in value_vocab.keys() if i in voc])) 97 | self.gen_vocab.bulk_add(list([i for i in genwords.keys() if i in voc])) 98 | # make sure we did everything right (assuming didn't encounter any special tokens) 99 | assert self.dictionary.idx2word[5 + len(self.genset) - 1] in self.genset 100 | assert self.dictionary.idx2word[5 + len(self.genset)] not in self.genset 101 | self.dictionary.add_word("", train=True) 102 | self.dictionary.add_word("", train=True) 103 | 104 | def get_vocab_from_paired(self, tgt_path, src_path): 105 | assert os.path.exists(tgt_path) 106 | linewords = [] 107 | with open(src_path, 'r') as f: 108 | for line in f: 109 | fields = get_spnlg_field(line.strip()) # key, pos -> word 110 | fieldvals = fields.values() 111 | self.word_cnt.update(fieldvals) 112 | linewords.append(set(wrd for wrd in fieldvals)) 113 | self.word_cnt.update([k for k, idx in fields]) 114 | self.word_cnt.update([idx for k, idx in fields]) 115 | 116 | genwords = Counter() # a Counter that records all the vocab in target 117 | value_words = Counter() 118 | with open(tgt_path, 'r') as f: 119 | for l, line in enumerate(f): 120 | words = word_tokenize(line.strip()) 121 | genwords.update([wrd for wrd in words if wrd not in linewords[l]]) 122 | value_words.update([wrd for wrd in words if wrd in linewords[l]]) 123 | self.word_cnt.update(words) 124 | return genwords, value_words 125 | 126 | def get_vocab_from_raw(self, path): 127 | assert os.path.exists(path) 128 | raw_vocab = Counter() 129 | with open(path, 'r') as f: 130 | for l, line in enumerate(f): 131 | words = word_tokenize(line.strip()) 132 | self.word_cnt.update(words) 133 | raw_vocab.update(words) 134 | return raw_vocab 135 | 136 | def get_test_data(self, table_path): 137 | w2i = self.dictionary.word2idx 138 | src_feats = [] 139 | original_feats = [] 140 | with open(table_path, 'r') as f: 141 | for line in f: 142 | feats = [] 143 | orig = [] 144 | fields = get_spnlg_field(line.strip()) # (key, pos) -> word 145 | for (key, pos), wrd in fields.items(): 146 | if key in w2i: 147 | featrow = [self.dictionary.add_word(key, False), 148 | self.dictionary.add_word(pos, False), 149 | self.dictionary.add_word(wrd, False)] 150 | feats.append(featrow) 151 | orig.append((key, pos, wrd)) 152 | src_feats.append(feats) 153 | original_feats.append(orig) 154 | 155 | src_feat_batches = [] 156 | line_no_tst = [] 157 | for i in range(len(src_feats)): 158 | # src = torch.LongTensor(src_feats[i]).unsqueeze(0) # 1 x nfield x 3 159 | # src_feat_batches.append(src) 160 | src_feat_batches.append(self._pad_srcfeat([src_feats[i]])) 161 | 
line_no_tst.append([i]) 162 | 163 | return src_feat_batches, original_feats, line_no_tst 164 | 165 | def get_raw_temp(self, raw_fn_in, fn_out=None, num=5, seed=1): 166 | np.random.seed(seed) # define random seed for select certain sentence 167 | with open(raw_fn_in, 'r') as f: 168 | all_contents = f.read().strip().split('\n') 169 | select_num = np.random.randint(0, len(all_contents) - 1, (num,)) 170 | 171 | if fn_out is not None: 172 | with open(fn_out, 'w') as fout: 173 | for i in select_num: 174 | fout.write(all_contents[i] + '\n') 175 | 176 | all_raw_tmps = [] 177 | w2i = self.dictionary.word2idx 178 | for i in select_num: 179 | line = all_contents[i] 180 | words = word_tokenize(line.strip(())) 181 | token = [] 182 | for word in words: 183 | if word in w2i: 184 | token.append(w2i[word]) 185 | else: 186 | token.append(w2i['']) 187 | # token list to tensor 188 | token = torch.LongTensor(token) 189 | all_raw_tmps.append(token) 190 | return all_raw_tmps 191 | 192 | def get_raw_temp_from_file(self, fn_in): 193 | all_raw_tmps = [] 194 | with open(fn_in, 'r') as f: 195 | for line in f: 196 | words = word_tokenize(line.strip()) 197 | token = [] 198 | for word in words: 199 | if word in self.dictionary.word2idx: 200 | token.append(self.dictionary.word2idx[word]) 201 | else: 202 | token.append(self.dictionary.word2idx['']) 203 | token = torch.LongTensor(token) 204 | all_raw_tmps.append(token) 205 | return all_raw_tmps 206 | 207 | def load_paired_data(self, table_path, text_path, add_to_dict=False, add_bos=False, add_eos=False): 208 | w2i = self.dictionary.word2idx 209 | sents = [] 210 | sk_sents = [] 211 | raw_sentences = [] 212 | src_feats = [] 213 | linewords = [] 214 | with open(table_path, 'r') as f: 215 | for line in f: 216 | fields = get_spnlg_field(line.strip()) 217 | feats = [] 218 | linewords.append(set(fields.values())) 219 | for (key, pos), wrd in fields.items(): 220 | if key in w2i: 221 | featrow = [self.dictionary.add_word(key, add_to_dict), 222 | self.dictionary.add_word(pos, add_to_dict), 223 | self.dictionary.add_word(wrd, False)] # value can not update, but key can 224 | feats.append(featrow) 225 | src_feats.append(feats) 226 | 227 | with open(text_path, 'r') as f: 228 | for l, line in enumerate(f): 229 | words = word_tokenize(line.strip()) 230 | raw_sentences.append(words) 231 | token = [] 232 | sk_tokens = [] 233 | if add_bos: 234 | token.append(self.dictionary.add_word('', True)) 235 | sk_tokens.append(self.dictionary.add_word('', True)) 236 | 237 | for word in words: 238 | if word in w2i: 239 | token.append(w2i[word]) 240 | else: 241 | token.append(w2i['']) 242 | 243 | if word in linewords[l]: 244 | if len(sk_tokens) == 0 or (len(sk_tokens) == 1 and sk_tokens[0] == w2i['']): # first word 245 | sk_tokens.append(self.dictionary.word2idx['']) 246 | elif sk_tokens[-1] == self.dictionary.word2idx['']: 247 | continue 248 | else: 249 | sk_tokens.append(self.dictionary.word2idx['']) 250 | else: # ordinary word 251 | if word in w2i: 252 | sk_tokens.append(w2i[word]) 253 | else: 254 | sk_tokens.append(w2i['']) 255 | if add_eos: 256 | token.append(self.dictionary.add_word('', True)) 257 | sk_tokens.append(self.dictionary.add_word('', True)) 258 | sents.append(token) 259 | sk_sents.append(sk_tokens) 260 | del token, sk_tokens 261 | assert len(raw_sentences) == len(src_feats) 262 | 263 | return sents, sk_sents, src_feats 264 | 265 | def load_raw_data(self, text_path, add_bos=False, add_eos=False): 266 | w2i = self.dictionary.word2idx 267 | sents = [] 268 | with open(text_path, 'r') as f: 
269 | for line in f: 270 | words = word_tokenize(line.strip()) 271 | token = [] 272 | if add_bos: 273 | token.append(self.dictionary.add_word('', True)) 274 | for word in words: 275 | if word in w2i: 276 | token.append(w2i[word]) 277 | else: 278 | token.append(w2i['']) 279 | if add_eos: 280 | token.append(self.dictionary.add_word('', True)) 281 | sents.append(token) 282 | return sents 283 | 284 | def _pad_sequence(self, sequence): 285 | # sequence: b x seq 286 | max_row = max(len(i) for i in sequence) 287 | for item in sequence: 288 | if len(item) < max_row: 289 | item.extend([self.dictionary.word2idx[""]] * (max_row - len(item))) 290 | return torch.LongTensor(sequence) 291 | 292 | def _pad_srcfeat(self, curr_feats): 293 | # return a b x nfield(max) x 3 294 | max_rows = max(len(feats) for feats in curr_feats) 295 | nfeats = len(curr_feats[0][0]) 296 | for feats in curr_feats: 297 | if len(feats) < max_rows: 298 | [feats.append([self.dictionary.word2idx[""] for _ in range(nfeats)]) 299 | for _ in range(max_rows - len(feats))] 300 | return torch.LongTensor(curr_feats) 301 | 302 | # def _pad_loc(self, curr_locs): 303 | # """ 304 | # curr_locs is a bsz-len list of tgt-len list of locations 305 | # returns: 306 | # a seqlen x bsz x max_locs tensor 307 | # """ 308 | # max_locs = max(len(locs) for blocs in curr_locs for locs in blocs) 309 | # max_seq = max(len(blocs) for blocs in curr_locs) 310 | # for blocs in curr_locs: 311 | # for locs in blocs: 312 | # if len(locs) < max_locs: 313 | # locs.extend([-1] * (max_locs - len(locs))) 314 | # if len(blocs) < max_seq: 315 | # blocs.extend([[-1] * max_locs] * (max_seq - len(blocs))) 316 | # return torch.LongTensor(curr_locs).transpose(0, 1).contiguous() 317 | 318 | # def _pad_inp(self, curr_inps): 319 | # """ 320 | # curr_inps is a bsz-len list of seqlen-len list of nlocs-len list of features 321 | # returns: 322 | # a bsz x seqlen x max_nlocs x nfeats tensor 323 | # """ 324 | # max_locs = max(len(feats) for seq in curr_inps for feats in seq) 325 | # max_seq = max(len(seq) for seq in curr_inps) 326 | # nfeats = len(curr_inps[0][0][0]) # default: 3 327 | # for seq in curr_inps: 328 | # for feats in seq: 329 | # if len(feats) < max_locs: 330 | # randidxs = [random.randint(0, len(feats) - 1) for _ in 331 | # range(max_locs - len(feats))] # random from on of feat 332 | # [feats.append(feats[ridx]) for ridx in randidxs] 333 | # 334 | # if len(seq) < max_seq: 335 | # seq.extend([seq[-1] * (max_seq - len(seq))]) 336 | # return torch.LongTensor(curr_inps) 337 | 338 | def _minibatchify_pair(self, sents, sk_sents, src_feats, bsz): 339 | 340 | sents, sorted_idxs = zip( 341 | *sorted(zip(sents, range(len(sents))), key=lambda x: len(x[0]))) # from shortest to longest 342 | minibatches, mb2linenos = [], [] 343 | curr_sent, curr_sk, curr_len, curr_srcfeat = [], [], [], [] 344 | curr_line = [] 345 | 346 | for i in range(len(sents)): 347 | if len(curr_sent) == bsz: # one batch is done! 
348 | minibatches.append((self._pad_sequence(curr_sent).t().contiguous(), 349 | self._pad_sequence(curr_sk).t().contiguous(), 350 | torch.IntTensor(curr_len), 351 | self._pad_srcfeat(curr_srcfeat))) 352 | 353 | mb2linenos.append(curr_line) 354 | # init 355 | curr_line = [sorted_idxs[i]] 356 | curr_sent, curr_len = [sents[i]], [len(sents[i])] 357 | curr_sk = [sk_sents[sorted_idxs[i]]] 358 | curr_srcfeat = [src_feats[sorted_idxs[i]]] 359 | 360 | else: 361 | curr_sent.append(sents[i]) 362 | curr_len.append(len(sents[i])) 363 | curr_line.append(sorted_idxs[i]) 364 | curr_sk.append(sk_sents[sorted_idxs[i]]) 365 | curr_srcfeat.append(src_feats[sorted_idxs[i]]) 366 | 367 | if len(curr_sent) > 0: # last 368 | minibatches.append((self._pad_sequence(curr_sent).t().contiguous(), 369 | self._pad_sequence(curr_sk).t().contiguous(), 370 | torch.IntTensor(curr_len), 371 | self._pad_srcfeat(curr_srcfeat))) 372 | mb2linenos.append(curr_line) 373 | 374 | return minibatches, mb2linenos 375 | 376 | def _minibatchify_raw(self, sents, bsz): 377 | sents, sorted_idxs = zip( 378 | *sorted(zip(sents, range(len(sents))), key=lambda x: len(x[0]))) # from shortest to longest 379 | minibatches, mb2linenos = [], [] 380 | curr_sent, curr_len, curr_line = [], [], [] 381 | for i in range(len(sents)): 382 | if len(curr_sent) == bsz: # one batch is done! 383 | minibatches.append((self._pad_sequence(curr_sent).t().contiguous(), 384 | torch.IntTensor(curr_len))) 385 | mb2linenos.append(curr_line) 386 | # init 387 | curr_sent, curr_line, curr_len = [sents[i]], [sorted_idxs[i]], [len(sents[i])] 388 | else: 389 | curr_sent.append(sents[i]) 390 | curr_len.append(len(sents[i])) 391 | # curr_bow.append(sent_bow[sorted_idxs[i]]) 392 | curr_line.append(sorted_idxs[i]) 393 | 394 | if len(curr_sent) > 0: # last 395 | minibatches.append((self._pad_sequence(curr_sent).t().contiguous(), 396 | torch.IntTensor(curr_len))) 397 | mb2linenos.append(curr_line) 398 | 399 | return minibatches, mb2linenos 400 | 401 | 402 | if __name__ == "__main__": 403 | pass 404 | -------------------------------------------------------------------------------- /data_utils/label_wiki.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import Counter 4 | import numpy as np 5 | 6 | 7 | def get_wiki_poswrds(tokes): 8 | """(key, num) -> word""" 9 | fields = {} 10 | for toke in tokes: 11 | try: 12 | fullkey, val = toke.split(':') 13 | except ValueError: 14 | ugh = toke.split(':') # must be colons in the val 15 | fullkey = ugh[0] 16 | val = ''.join(ugh[1:]) 17 | if val == "": 18 | continue 19 | keypieces = fullkey.split('_') 20 | if len(keypieces) == 1: 21 | key = fullkey 22 | keynum = 1 23 | else: 24 | keynum = int(keypieces[-1]) 25 | key = '_'.join(keypieces[:-1]) 26 | fields[key, keynum] = val 27 | return fields 28 | 29 | class Dictionary(object): 30 | def __init__(self, unk_word=""): 31 | self.unk_word = unk_word 32 | self.idx2word = [unk_word, "", "", "", ""] # OpenNMT constants 33 | self.word2idx = {word: i for i, word in enumerate(self.idx2word)} 34 | 35 | def add_word(self, word, train=False): 36 | """add extra word, returns idx of word 37 | :param word: a str 38 | :param train: bool, if true, then update self.idx2word and w2i; if false, just update w2i 39 | """ 40 | if train and word not in self.word2idx: 41 | self.idx2word.append(word) 42 | self.word2idx[word] = len(self.idx2word) - 1 43 | return self.word2idx[word] if word in self.word2idx else self.word2idx[self.unk_word] 44 
| 45 | def bulk_add(self, words): 46 | """add lots of words, assumes train=True 47 | :param words: a list of words 48 | """ 49 | self.idx2word.extend(words) 50 | self.word2idx = {word: i for i, word in enumerate(self.idx2word)} 51 | 52 | def __len__(self): 53 | return len(self.idx2word) 54 | 55 | 56 | class Corpus(object): 57 | def __init__(self, path, bsz, max_count=50000, add_bos=False, add_eos=False): 58 | self.dictionary = Dictionary() 59 | self.value_dict = Dictionary() 60 | self.path = path 61 | pair_train_src = os.path.join(path, "pair_src.train") 62 | pair_valid_src = os.path.join(path, "pair_src.valid") 63 | pair_train_tgt = os.path.join(path, "pair_tgt.train") 64 | pair_valid_tgt = os.path.join(path, "pair_tgt.valid") 65 | 66 | # if not ner_fake: 67 | # pair_train_src = os.path.join(path, "pair_src_train.txt") 68 | # pair_valid_src = os.path.join(path, "pair_src_valid.txt") 69 | # pair_train_text = os.path.join(path, "pair_train.txt") 70 | # pair_valid_text = os.path.join(path, "pair_valid.txt") 71 | # else: 72 | # pair_train_src = os.path.join(path, "pair_src_train_all_include.txt") 73 | # pair_valid_src = os.path.join(path, "pair_src_valid.txt") 74 | # pair_train_text = os.path.join(path, "pair_train_all_include.txt") 75 | # pair_valid_text = os.path.join(path, "pair_valid.txt") 76 | 77 | raw_train_text = os.path.join(path, "raw_tgt.train") 78 | raw_valid_text = os.path.join(path, "raw_tgt.valid") 79 | 80 | self.gen_vocab = Dictionary() 81 | self.make_vocab(pair_train_tgt, pair_train_src, raw_train_text, max_count=max_count) 82 | self.genset.add("") 83 | 84 | # load training data 85 | pair_sents_train, pair_sk_sents_train, \ 86 | pair_src_feats_train = self.load_paired_data(pair_train_src, pair_train_tgt, 87 | add_to_dict=False, add_bos=add_bos, add_eos=add_eos) 88 | raw_sents_train = self.load_raw_data(raw_train_text) 89 | 90 | self.paired_train, _ = self._minibatchify_pair(pair_sents_train, pair_sk_sents_train, pair_src_feats_train, bsz) 91 | self.raw_train, _ = self._minibatchify_raw(raw_sents_train, bsz) 92 | del pair_sents_train, pair_sk_sents_train, pair_src_feats_train, raw_sents_train 93 | 94 | # load valid data 95 | pair_sents_valid, pair_sk_sents_valid, \ 96 | pair_src_feats_valid = self.load_paired_data(pair_valid_src, pair_valid_tgt, 97 | add_to_dict=False, add_bos=add_bos, add_eos=add_eos) 98 | raw_sents_valid = self.load_raw_data(raw_valid_text) 99 | self.paired_valid, self.paired_lineno_valid = self._minibatchify_pair(pair_sents_valid, pair_sk_sents_valid, 100 | pair_src_feats_valid, bsz) 101 | if len(raw_sents_valid) == 0: 102 | self.raw_valid = None 103 | else: 104 | self.raw_valid, _ = self._minibatchify_raw(raw_sents_valid, bsz) 105 | 106 | del pair_sents_valid, pair_sk_sents_valid, pair_src_feats_valid, raw_sents_valid 107 | 108 | def make_vocab(self, pair_tgt, pair_src, raw_tgt, max_count=50000): 109 | self.word_cnt = Counter() 110 | genwords, value_vocab = self.get_vocab_from_paired(pair_tgt, pair_src) 111 | raw_vocab = self.get_vocab_from_raw(raw_tgt) # just to update self.word_cnt 112 | self.genset = set(genwords.keys()) 113 | tgtkeys = list(self.word_cnt.keys()) 114 | tgtkeys.sort(key=lambda x: -(x in self.genset)) # add genset first 115 | voc = tgtkeys[:max_count] 116 | self.dictionary.bulk_add(voc) 117 | self.value_dict.bulk_add(list([i for i in value_vocab.keys() if i in voc])) 118 | self.gen_vocab.bulk_add(list([i for i in genwords.keys() if i in voc])) 119 | # make sure we did everything right (assuming didn't encounter any special tokens) 120 | 
assert self.dictionary.idx2word[5 + len(self.genset) - 1] in self.genset 121 | assert self.dictionary.idx2word[5 + len(self.genset)] not in self.genset 122 | self.dictionary.add_word("", train=True) 123 | self.dictionary.add_word("", train=True) 124 | 125 | def get_vocab_from_paired(self, path, src_path): 126 | assert os.path.exists(path) 127 | linewords = [] 128 | with open(src_path, 'r') as f: 129 | for line in f: 130 | tokes = line.strip().split() 131 | fields = get_wiki_poswrds(tokes) # key, pos -> wrd 132 | fieldvals = fields.values() 133 | self.word_cnt.update(fieldvals) 134 | linewords.append(set(wrd for wrd in fieldvals)) 135 | self.word_cnt.update([k for k, idx in fields]) 136 | self.word_cnt.update([idx for k, idx in fields]) 137 | 138 | genwords = Counter() # a Counter that records all the vocab in target 139 | value_words = Counter() 140 | with open(path, 'r') as f: 141 | for l, line in enumerate(f): 142 | words, spanlabels = line.strip().split('|||') 143 | words = words.split() 144 | genwords.update([wrd for wrd in words if wrd not in linewords[l]]) 145 | value_words.update([wrd for wrd in words if wrd in linewords[l]]) 146 | self.word_cnt.update(words) 147 | 148 | genwords = {k: v for k,v in genwords.items() if v > 5} 149 | 150 | return genwords, value_words 151 | 152 | def get_vocab_from_raw(self, path): 153 | assert os.path.exists(path) 154 | raw_vocab = Counter() 155 | with open(path, 'r') as f: 156 | for l, line in enumerate(f): 157 | words = line.strip().split() 158 | self.word_cnt.update(words) 159 | raw_vocab.update(words) 160 | return raw_vocab 161 | 162 | # sent_voc = {k: v for k, v in wrdcnt.items() if v > thresh} 163 | # # rare_words = {k: v for k, v in wrdcnt.items() if v > value_thresh1 and v<10} # set rare words as value words 164 | # # rare_words = {k:v for k,v in wrdcnt.items() if v==1} 165 | # self.genset.update(set(sent_voc.keys())) 166 | # self.gen_vocab.bulk_add(list(sent_voc.keys())) 167 | # self.dictionary.bulk_add(list(sent_voc.keys())) 168 | # # self.value_dict.bulk_add(list(rare_words.keys())) 169 | # del sent_voc 170 | 171 | def get_test_data(self, table_path): 172 | w2i = self.dictionary.word2idx 173 | src_feats = [] 174 | original_feats = [] 175 | with open(table_path, 'r') as f: 176 | for line in f: 177 | feats = [] 178 | orig = [] 179 | items = line.strip().split() 180 | fields = get_wiki_poswrds(items) # (key, pos) -> wrd 181 | for (key, pos), wrd in fields.items(): 182 | if key in w2i: 183 | featrow = [self.dictionary.add_word(key, False), 184 | self.dictionary.add_word(pos, False), 185 | self.dictionary.add_word(wrd, False)] 186 | feats.append(featrow) 187 | orig.append((key, pos, wrd)) 188 | src_feats.append(feats) 189 | original_feats.append(orig) 190 | 191 | src_feat_batches = [] 192 | line_no_tst = [] 193 | for i in range(len(src_feats)): 194 | # src = torch.LongTensor(src_feats[i]).unsqueeze(0) # 1 x nfield x 3 195 | # src_feat_batches.append(src) 196 | src_feat_batches.append(self._pad_srcfeat([src_feats[i]])) 197 | line_no_tst.append([i]) 198 | 199 | return src_feat_batches, original_feats, line_no_tst 200 | 201 | def get_raw_temp(self, raw_fn_in, fn_out=None, num=5, seed=1): 202 | np.random.seed(seed) # define random seed for select certain sentence 203 | with open(raw_fn_in,'r') as f: 204 | all_contents = f.read().strip().split('\n') 205 | select_num = np.random.randint(0, len(all_contents)-1, (num,)) 206 | 207 | if fn_out is not None: 208 | with open(fn_out, 'w') as fout: 209 | for i in select_num: 210 | if '|||' in all_contents[i]: 
211 | words, labels = all_contents[i].strip().split('|||') 212 | fout.write(words + '\n') 213 | else: 214 | fout.write(all_contents[i] + '\n') 215 | 216 | all_raw_tmps = [] 217 | w2i = self.dictionary.word2idx 218 | for i in select_num: 219 | line = all_contents[i] 220 | if '|||' in line: # for paired data 221 | words, labels = line.strip().split('|||') 222 | words = words.split() 223 | else: # for raw data 224 | words = line.strip().split() 225 | token = [] 226 | for word in words: 227 | if word in w2i: 228 | token.append(w2i[word]) 229 | else: 230 | token.append(w2i['']) 231 | # token list to tensor 232 | token = torch.LongTensor(token) 233 | all_raw_tmps.append(token) 234 | return all_raw_tmps 235 | 236 | def get_raw_temp_from_file(self, fn_in): 237 | all_raw_tmps = [] 238 | with open(fn_in, 'r') as f: 239 | for line in f: 240 | words = line.strip().split() 241 | token = [] 242 | for word in words: 243 | if word in self.dictionary.word2idx: 244 | token.append(self.dictionary.word2idx[word]) 245 | else: 246 | token.append(self.dictionary.word2idx['']) 247 | token = torch.LongTensor(token) 248 | all_raw_tmps.append(token) 249 | return all_raw_tmps 250 | 251 | def load_paired_data(self, table_path, text_path, add_to_dict=False, add_bos=False, add_eos=False): 252 | w2i = self.dictionary.word2idx 253 | sents = [] 254 | sk_sents = [] 255 | raw_sentences = [] 256 | src_feats = [] 257 | linewords = [] 258 | with open(table_path, 'r') as f: 259 | for line in f: 260 | items = line.strip().split() 261 | fields = get_wiki_poswrds(items) # dict: (key, pos) -> word 262 | feats = [] 263 | linewords.append(set(fields.values())) 264 | for (key, pos), wrd in fields.items(): 265 | if key in w2i: 266 | featrow = [self.dictionary.add_word(key, add_to_dict), 267 | self.dictionary.add_word(pos, add_to_dict), 268 | self.dictionary.add_word(wrd, False)] # word can not update, but key can 269 | feats.append(featrow) 270 | src_feats.append(feats) 271 | 272 | with open(text_path, 'r') as f: 273 | for l, line in enumerate(f): 274 | words, labels = line.strip().split('|||') 275 | words = words.split() 276 | raw_sentences.append(words) 277 | token = [] 278 | sk_tokens = [] 279 | if add_bos: 280 | token.append(self.dictionary.add_word('', True)) 281 | sk_tokens.append(self.dictionary.add_word('', True)) 282 | 283 | for word in words: 284 | if word in w2i: 285 | token.append(w2i[word]) 286 | else: 287 | token.append(w2i['']) 288 | 289 | if word in linewords[l]: 290 | if len(sk_tokens) == 0 or (len(sk_tokens) == 1 and sk_tokens[0] == w2i['']): # first word 291 | sk_tokens.append(self.dictionary.word2idx['']) 292 | elif sk_tokens[-1] == self.dictionary.word2idx['']: 293 | continue 294 | else: 295 | sk_tokens.append(self.dictionary.word2idx['']) 296 | else: # ordinary word 297 | if word in w2i: 298 | sk_tokens.append(w2i[word]) 299 | else: 300 | sk_tokens.append(w2i['']) 301 | if add_eos: 302 | token.append(self.dictionary.add_word('', True)) 303 | sk_tokens.append(self.dictionary.add_word('', True)) 304 | sents.append(token) 305 | sk_sents.append(sk_tokens) 306 | del token, sk_tokens 307 | assert len(raw_sentences) == len(src_feats) 308 | return sents, sk_sents, src_feats 309 | 310 | def load_raw_data(self, text_path, add_bos=False, add_eos=False): 311 | w2i = self.dictionary.word2idx 312 | sents = [] 313 | with open(text_path, 'r') as f: 314 | for line in f: 315 | words = line.strip().split() 316 | token = [] 317 | if add_bos: 318 | token.append(self.dictionary.add_word('', True)) 319 | for word in words: 320 | if word 
in w2i: 321 | token.append(w2i[word]) 322 | else: 323 | token.append(w2i['']) 324 | 325 | if add_eos: 326 | token.append(self.dictionary.add_word('', True)) 327 | sents.append(token) 328 | return sents 329 | 330 | def _pad_sequence(self, sequence, ner=False): 331 | # sequence: b x seq 332 | max_row = max(len(i) for i in sequence) 333 | for item in sequence: 334 | if len(item) < max_row: 335 | if ner: 336 | item.extend([self.tag_dict.word2idx['']]*(max_row-len(item))) 337 | else: 338 | item.extend([self.dictionary.word2idx[""]]*(max_row-len(item))) 339 | return torch.LongTensor(sequence) 340 | 341 | def _pad_srcfeat(self, curr_feats): 342 | # return a b x nfield(max) x 3 343 | max_rows = max(len(feats) for feats in curr_feats) 344 | nfeats = len(curr_feats[0][0]) 345 | for feats in curr_feats: 346 | if len(feats) < max_rows: 347 | [feats.append([self.dictionary.word2idx[""] for _ in range(nfeats)]) 348 | for _ in range(max_rows - len(feats))] 349 | return torch.LongTensor(curr_feats) 350 | 351 | def _minibatchify_pair(self, sents, sk_sents, src_feats, bsz): 352 | sents, sorted_idxs = zip( 353 | *sorted(zip(sents, range(len(sents))), key=lambda x: len(x[0]))) # from shortest to longest 354 | minibatches, mb2linenos = [], [] 355 | curr_sent, curr_sk, curr_len, curr_srcfeat = [], [], [], [] 356 | curr_line = [] 357 | 358 | for i in range(len(sents)): 359 | if len(curr_sent) == bsz: # one batch is done! 360 | minibatches.append((self._pad_sequence(curr_sent).t().contiguous(), 361 | self._pad_sequence(curr_sk).t().contiguous(), 362 | torch.IntTensor(curr_len), 363 | self._pad_srcfeat(curr_srcfeat))) 364 | mb2linenos.append(curr_line) 365 | # init 366 | curr_line = [sorted_idxs[i]] 367 | curr_sent, curr_len = [sents[i]], [len(sents[i])] 368 | curr_sk = [sk_sents[sorted_idxs[i]]] 369 | curr_srcfeat = [src_feats[sorted_idxs[i]]] 370 | else: 371 | curr_sent.append(sents[i]) 372 | curr_len.append(len(sents[i])) 373 | curr_line.append(sorted_idxs[i]) 374 | curr_sk.append(sk_sents[sorted_idxs[i]]) 375 | curr_srcfeat.append(src_feats[sorted_idxs[i]]) 376 | 377 | if len(curr_sent) > 0: # last 378 | minibatches.append((self._pad_sequence(curr_sent).t().contiguous(), 379 | self._pad_sequence(curr_sk).t().contiguous(), 380 | torch.IntTensor(curr_len), 381 | self._pad_srcfeat(curr_srcfeat))) 382 | mb2linenos.append(curr_line) 383 | 384 | return minibatches, mb2linenos 385 | 386 | def _minibatchify_raw(self, sents, bsz): 387 | sents, sorted_idxs = zip( 388 | *sorted(zip(sents, range(len(sents))), key=lambda x: len(x[0]))) # from shortest to longest 389 | minibatches, mb2linenos = [], [] 390 | curr_sent, curr_len, curr_line = [], [], [] 391 | for i in range(len(sents)): 392 | if len(curr_sent) == bsz: # one batch is done! 
393 | minibatches.append((self._pad_sequence(curr_sent).t().contiguous(), 394 | torch.IntTensor(curr_len))) 395 | mb2linenos.append(curr_line) 396 | # init 397 | curr_sent, curr_line, curr_len = [sents[i]], [sorted_idxs[i]], [len(sents[i])] 398 | else: 399 | curr_sent.append(sents[i]) 400 | curr_len.append(len(sents[i])) 401 | # curr_bow.append(sent_bow[sorted_idxs[i]]) 402 | curr_line.append(sorted_idxs[i]) 403 | 404 | if len(curr_sent) > 0: # last 405 | minibatches.append((self._pad_sequence(curr_sent).t().contiguous(), 406 | torch.IntTensor(curr_len))) 407 | mb2linenos.append(curr_line) 408 | 409 | return minibatches, mb2linenos 410 | 411 | 412 | if __name__ == "__main__": 413 | # table = "name_1:walter name_2:extra image: image_size: caption: birth_name: birth_date_1:1954 birth_place: death_date: death_place: death_cause: resting_place: resting_place_coordinates: residence: nationality_1:german ethnicity: citizenship: other_names: known_for: education: alma_mater: employer: occupation_1:aircraft occupation_2:designer occupation_3:and occupation_4:manufacturer home_town: title: salary: networth: height: weight: term: predecessor: successor: party: boards: religion: spouse: partner: children: parents: relations: signature: website: footnotes: article_title_1:walter article_title_2:extra" 414 | # sentence = "walter extra is a german award-winning aerobatic pilot , chief aircraft designer and founder of extra flugzeugbau -lrb- extra aircraft construction -rrb- , a manufacturer of aerobatic aircraft . " 415 | 416 | # table = "image: name_1:carlene name_2:m. name_3:walker caption: state_senate_1:utah term_start_1:january term_start_2:15 term_start_3:, term_start_4:2001 term_end_1:present predecessor_1:scott predecessor_2:n. predecessor_3:howell successor_1:karen successor_2:morgan district_1:8th birth_date_1:2 birth_date_2:september birth_date_3:1947 birth_place_1:san birth_place_2:francisco birth_place_3:, birth_place_4:ca party_1:republican party_2:party spouse_1:gordon residence_1:salt residence_2:lake residence_3:city residence_4:, residence_5:ut occupation_1:businesswoman religion_1:latter-day religion_2:saint website_1:-lsb- website_2:http://www.utahsenate.org/perl/spage/distbio2007.pl?dist8 website_3:legislative website_4:website website_5:-rsb- article_title_1:carlene article_title_2:m. article_title_3:walker" 417 | # sentence = "carlene m. walker is an american politician and businesswoman from utah . " 418 | 419 | # table = "name_1:bundit name_2:ungrangsee image_1:bundit image_2:ungrangsee image_3:19072007 image_4:bkkiff.jpg image_size_1:200px caption_1:bundit caption_2:attends caption_3:the caption_4:opening caption_5:party caption_6:for caption_7:the caption_8:2007 caption_9:bangkok caption_10:international caption_11:film caption_12:festival caption_13:. birth_date_1:7 birth_date_2:december birth_date_3:1970 birth_place_1:thailand death_date: death_place: education_1:university education_2:of education_3:michigan occupation_1:conductor title: spouse_1:mary spouse_2:ungrangsee parents: children: nationality_1:thai website_1:-lsb- website_2:http://www.bundit.org website_3:www.bundit.org website_4:-rsb- article_title_1:bundit article_title_2:ungrangsee" 420 | # sentence = "bundit ungrangsee -lrb- ; , born december 7 , 1970 -rrb- is an international symphonic conductor . 
" 421 | 422 | # table = "name_1:jole name_2:fierro image_1:jole image_2:fierro.jpg image_size: caption: birth_name: birth_date_1:22 birth_date_2:november birth_date_3:1926 birth_place_1:salerno birth_place_2:, birth_place_3:italy death_date_1:27 death_date_2:march death_date_3:1988 death_place_1:rome death_place_2:, death_place_3:italy occupation_1:actress spouse: article_title_1:jole article_title_2:fierro" 423 | # sentence = "jole fierro -lrb- 22 november 1926 - 27 march 1988 -rrb- was an italian actress . " 424 | # 425 | # table = "name_1:roberto name_2:la name_3:rocca image: imagesize: caption: nationality_1:ven nationality_2:venezuelan birth_date_1:09 birth_date_2:july birth_date_3:1992 birth_place_1:caracas birth_place_2:, birth_place_3:venezuela starts: wins: poles: year: titles_1:f2000 titles_2:championship titles_3:series article_title_1:roberto article_title_2:la article_title_3:rocca" 426 | # sentence = "roberto la rocca -lrb- born 9 july 1991 -rrb- is a venezuelan racing driver . " 427 | # table = "name_1:george name_2:ratcliffe image: country_1:england fullname_1:george fullname_2:ratcliff fullname_3:-lrb- fullname_4:e fullname_5:-rrb- height: nickname: birth_date_1:1 birth_date_2:april birth_date_3:1856 birth_place_1:ilkeston birth_place_2:, birth_place_3:derbyshire birth_place_4:, birth_place_5:england death_date_1:7 death_date_2:march death_date_3:1928 death_place_1:nottingham death_place_2:, death_place_3:england batting_1:left-handed batting_2:batsman bowling: role: family: international: testdebutdate: testdebutyear: testdebutagainst: testcap: lasttestdate: lasttestyear: lasttestagainst: club_1:derbyshire year_1:1887 year_2:& year_3:ndash year_4:; year_5:1889 type_1:first-class debutdate_1:27 debutdate_2:june debutyear_1:1887 debutfor_1:derbyshire debutagainst_1:yorkshire lastdate_1:15 lastdate_2:august lastyear_1:1887 lastfor_1:derbyshire lastagainst_1:surrey deliveries_1:balls deliveries_2:12 columns_1:1 column_1:first-class matches_1:5 runs_1:145 s/s_1:0/1 wickets_1:0 fivefor: tenfor: catches/stumpings_1:0 catches/stumpings_2:/ catches/stumpings_3:- date_1:1856 date_2:4 date_3:1 date_4:yes source_1:http://cricketarchive.com/players/32/32252/32252.html article_title_1:george article_title_2:ratcliffe article_title_3:-lrb- article_title_4:cricketer article_title_5:, article_title_6:born article_title_7:1856 article_title_8:-rrb-" 428 | # sentence = "`` another cricketer who played for derbyshire during the 1919 season was named george ratcliffe . '' " 429 | 430 | table = "name_1:paul name_2:of name_3:thebes birth_date_1:c. birth_date_2:227 birth_date_3:ad death_date_1:c. death_date_2:342 death_date_3:ad feast_day_1:february feast_day_2:9 feast_day_3:-lrb- feast_day_4:oriental feast_day_5:orthodox feast_day_6:churches feast_day_7:-rrb- group_1:note group_2:-rcb- group_3:-rcb- venerated_in_1:oriental venerated_in_2:orthodox venerated_in_3:churches venerated_in_4:catholic venerated_in_5:church venerated_in_6:eastern venerated_in_7:orthodox venerated_in_8:church image_1:anba_bola_1 image_2:. 
image_3:gif caption_1:'' caption_2:saint caption_3:paul caption_4:, caption_5:` caption_6:the caption_7:first caption_8:hermit caption_9:' caption_10:'' birth_place_1:egypt death_place_1:monastery death_place_2:of death_place_3:saint death_place_4:paul death_place_5:the death_place_6:anchorite death_place_7:, death_place_8:egypt titles_1:the titles_2:first titles_3:hermit beatified_date: beatified_place: beatified_by: canonized_date: canonized_place: canonized_by: attributes_1:two attributes_2:lions attributes_3:, attributes_4:palm attributes_5:tree attributes_6:, attributes_7:raven patronage: major_shrine_1:monastery major_shrine_2:of major_shrine_3:saint major_shrine_4:paul major_shrine_5:the major_shrine_6:anchorite major_shrine_7:, major_shrine_8:egypt suppressed_date: issues: article_title_1:paul article_title_2:of article_title_3:thebes" 431 | sentence = "paul of thebes , commonly known as paul , the first hermit or paul the anchorite -lrb- d. c. 341 -rrb- is regarded as the first christian hermit . " 432 | 433 | table = table.strip().split() 434 | fields = get_wiki_poswrds(table) # key, pos -> wrd 435 | line_words = set(fields.values()) 436 | words = sentence.strip().split() 437 | sk_words = [] 438 | for word in words: 439 | if word in line_words: 440 | if len(sk_words) == 0: 441 | sk_words.append('') 442 | elif sk_words[-1] == '': 443 | continue 444 | else: 445 | sk_words.append('') 446 | else: 447 | sk_words.append(word) 448 | print("skeleton: {}".format(" ".join(sk_words))) 449 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | from collections import defaultdict 4 | import numpy as np 5 | import torch 6 | import torch.optim as optim 7 | from torch.autograd import Variable 8 | from data_utils import label_wiki, label_spnlg 9 | from models.variational_template_machine import VariationalTemplateMachine 10 | import random 11 | 12 | def set_optimizer(net): 13 | if args.optim == "adagrad": 14 | optalg = optim.Adagrad(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr) 15 | for group in optalg.param_groups: 16 | for p in group['params']: 17 | optalg.state[p]['sum'].fill_(0.1) 18 | elif args.optim == "rmsprop": 19 | optalg = optim.RMSprop(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr) 20 | elif args.optim == "adam": 21 | optalg = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr) 22 | else: 23 | optalg = None 24 | return optalg 25 | 26 | def cuda2cpu(): 27 | if args.cuda: 28 | shmocals = locals() 29 | for shk in list(shmocals): 30 | shv = shmocals[shk] 31 | if hasattr(shv, "is_cuda") and shv.is_cuda: 32 | shv = shv.cpu() 33 | 34 | def make_skeleton(sentence, sent_len, vocab, gen_dict): 35 | # sentence: seq x b LongTensor 36 | # sent_len: dim = b 37 | totseqlen, bsz = sentence.size() 38 | skeleton_sents = [] 39 | for b in range(bsz): 40 | sk = [] 41 | # sk = [vocab.word2idx[""] for i in range(sent_len[b])] 42 | for t in range(sent_len[b]): 43 | if sentence[t, b].item() in gen_dict.idx2word: 44 | sk.append(sentence[t, b].item()) 45 | else: 46 | if len(sk) == 0: # first item 47 | sk.append(vocab.word2idx['']) 48 | elif sk[-1] == vocab.word2idx['']: 49 | continue 50 | else: 51 | sk.append(vocab.word2idx['']) 52 | if len(sk) < totseqlen: 53 | sk.extend([vocab.word2idx['']]*(totseqlen-len(sk))) 54 | 55 | skeleton_sents.append(sk) 56 | return 
torch.LongTensor(skeleton_sents).transpose(0, 1).contiguous() # seq x b 57 | 58 | def make_masks(src, pad_idx, max_pool=False): 59 | """ 60 | src - bsz x nfields x nfeats(3) 61 | """ 62 | neginf = -1e38 63 | bsz, nfields, nfeats = src.size() 64 | fieldmask = (src.eq(pad_idx).sum(2) == nfeats) # binary bsz x nfields tensor 65 | avgmask = (1 - fieldmask).float() # 1s where not padding 66 | if not max_pool: 67 | avgmask.div_(avgmask.sum(1, True).expand(bsz, nfields)) 68 | fieldmask = fieldmask.float() * neginf # 0 where not all pad and -1e38 elsewhere 69 | return fieldmask, avgmask 70 | 71 | def make_sent_msk(sentence, pad_idx): 72 | return torch.ByteTensor(sentence != pad_idx).transpose(0, 1) 73 | 74 | 75 | parser = argparse.ArgumentParser(description='') 76 | # basic data setups 77 | parser.add_argument('-data', type=str, default='', help='path to data dir') 78 | parser.add_argument('-bsz', type=int, default=16, help='batch size') 79 | parser.add_argument('-seed', type=int, default=1111, help='set random seed, ' 80 | 'when training, it is to shuffle training batch, ' 81 | 'when testing, it is to define the latent samples') 82 | parser.add_argument('-cuda', action='store_true', help='use CUDA') 83 | parser.add_argument('-log_interval', type=int, default=200, help='minibatches to wait before logging training status') 84 | parser.add_argument('-max_vocab_cnt', type=int, default=50000) 85 | parser.add_argument('-max_seqlen', type=int, default=70, help='') 86 | 87 | # epochs 88 | parser.add_argument('-epochs', type=int, default=40, help='epochs that train together') 89 | parser.add_argument('-paired_epochs', type=int, default=10, help='epochs that train paired data') 90 | parser.add_argument('-raw_epochs', type=int, default=10, help='epochs that train raw data') 91 | parser.add_argument('-warm_up_epoch', type=int, default=0) 92 | 93 | # model saves 94 | parser.add_argument('-load', type=str, default='', help='path to saved model') 95 | parser.add_argument('-save', type=str, default='', help='path to save the model') 96 | 97 | # global setups 98 | parser.add_argument('-emb_size', type=int, default=100, help='size of word embeddings') 99 | parser.add_argument('-dropout', type=float, default=0.3, help='dropout') 100 | parser.add_argument('-drop_emb', action='store_true', help='dropout in embedding') 101 | parser.add_argument('-initrange', type=float, default=0.1, help='uniform init interval') 102 | # table encoder setups 103 | parser.add_argument('-table_hid_size', type=int, default=128, help='size of table hidden size') 104 | parser.add_argument('-pool_type', type=str, default="max", help='max/mean pooling') 105 | 106 | # sentence embedding setups 107 | parser.add_argument('-hid_size', type=int, default=128, help='size of rnn hidden state') 108 | parser.add_argument('-layers', type=int, default=1, help='num rnn layers') 109 | parser.add_argument('-sent_represent', type=str, default='last_hid', help='last_hid/seq_avg') 110 | 111 | # generator setups use attention or not 112 | parser.add_argument('-dec_attention', action='store_true', help='store attention to h when generating') 113 | 114 | # latent setups 115 | parser.add_argument('-z_latent_size', type=int, default=200, help="size of latent variable z") 116 | parser.add_argument('-c_latent_size', type=int, default=200, help="size of latent variable c") 117 | 118 | 119 | # mse loss to train q(c|x) 120 | parser.add_argument('-add_preserving_content_loss', action='store_true', help='add preserving-content loss') 121 | 
parser.add_argument('-pc_weight', type=float, default=1.0, help='pc loss weight') 122 | 123 | # add skeleton loss E_z~q(z|x)logp(x_sk|z) 124 | parser.add_argument('-add_preserving_template_loss', action='store_true', help='add skeleton preserving-template loss') 125 | parser.add_argument('-pt_weight', type=float, default=1.0, help='weight for preserving-template loss') 126 | 127 | # annealing tricks for latent variables z and c 128 | parser.add_argument('-anneal_function_z', type=str, default='linear', help='logistic/const/linear') 129 | parser.add_argument('-anneal_k_z', type=float, default=0.1) 130 | parser.add_argument('-anneal_x0_z', type=int, default=6000) 131 | parser.add_argument('-anneal_function_c', type=str, default='linear', help='logistic/const/linear') 132 | parser.add_argument('-anneal_k_c', type=float, default=0.1) 133 | parser.add_argument('-anneal_x0_c', type=int, default=6000) 134 | # additional bow losses (including adverserial loss and multitask loss) 135 | parser.add_argument('-add_mi_z', action='store_true',help='add mi loss for latent z') 136 | parser.add_argument('-add_mi_c', action='store_true', help='add mi loss for latent c') 137 | parser.add_argument('-mi_z_weight', type=float, default=1.0, help='mi weight for z') 138 | parser.add_argument('-mi_c_weight', type=float, default=1.0, help='mi weight for c') 139 | 140 | # loss for unlabeled raw data 141 | parser.add_argument('-rawloss_weight', type=float, default=1.0, help='weight for raw data training') 142 | 143 | # learning tricks: lr, optimizer 144 | parser.add_argument('-lr', type=float, default=0.001, help='initial learning rate') 145 | parser.add_argument('-lr_decay', type=float, default=0.5, help='learning rate decay') 146 | parser.add_argument('-optim', type=str, default="adam", help='optimization algorithm') 147 | parser.add_argument('-clip', type=float, default=5, help='gradient clipping') 148 | 149 | # decode method 150 | parser.add_argument('-decode_method', type=str, default='beam_search', help="beam_seach / temp_sample / topk_sample / nucleus_sample") 151 | parser.add_argument('-beamsz', type=int, default=1, help='beam size') 152 | parser.add_argument('-sample_temperature', type=float, default=1.0, help='set sample_temperature for decode_method=temp_sample') 153 | parser.add_argument('-topk', type=int, default=5, help='for topk_sample, if topk=1, it is greedy') 154 | parser.add_argument('-topp', type=float, default=1.0, help='for nucleus(top-p) sampleing, if topp=1, then its fwd_sample') 155 | 156 | 157 | if __name__ == "__main__": 158 | args = parser.parse_args() 159 | print(args) 160 | sys.stdout.flush() 161 | torch.manual_seed(args.seed) 162 | 163 | if torch.cuda.is_available(): 164 | if not args.cuda: 165 | print("WARNING: You have a CUDA device, so you should probably run with -cuda") 166 | sys.stdout.flush() 167 | else: 168 | torch.cuda.manual_seed_all(args.seed) 169 | else: 170 | if args.cuda: 171 | print("No CUDA device.") 172 | args.cuda = False 173 | 174 | if 'wiki' in args.data.lower(): 175 | corpus = label_wiki.Corpus(args.data, args.bsz, max_count=args.max_vocab_cnt, 176 | add_bos=False, add_eos=False) 177 | elif 'spnlg' in args.data.lower(): 178 | corpus = label_spnlg.Corpus(args.data, args.bsz, max_count=args.max_vocab_cnt, 179 | add_bos=False, add_eos=True) 180 | else: 181 | raise NotImplementedError 182 | 183 | print("data loaded!") 184 | print("total vocabulary size:", len(corpus.dictionary)) 185 | args.pad_idx = corpus.dictionary.word2idx[''] 186 | 187 | if len(args.load) > 0: 188 | 
print("load model ...") 189 | saved_stuff = torch.load(args.load) 190 | saved_args, saved_state = saved_stuff["opt"], saved_stuff["state_dict"] 191 | for k, v in args.__dict__.items(): 192 | if k not in saved_args.__dict__: 193 | saved_args.__dict__[k] = v 194 | if k in ["decode_method", "beamsz", "sample_temperature", "topk", "topp"]: 195 | saved_args.__dict__[k] = v 196 | net = VariationalTemplateMachine(corpus, saved_args) 197 | net.load_state_dict(saved_state, strict=False) 198 | del saved_args, saved_state, saved_stuff 199 | else: 200 | net = VariationalTemplateMachine(corpus, args) 201 | if args.cuda: 202 | net = net.cuda() 203 | 204 | trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad) 205 | print("Num of parameters in vae model: {:.2f} M. ".format(trainable_num/1000/1000)) 206 | optalg = set_optimizer(net) 207 | 208 | 209 | def train_pair(epoch): 210 | net.train() 211 | global tot_steps 212 | loss_record = defaultdict(float) 213 | nsents = 0 214 | trainperm = torch.randperm(len(corpus.paired_train)) 215 | 216 | for batch_idx in range(len(corpus.paired_train)): 217 | net.zero_grad() 218 | sentence, paired_skeleton_sent, paired_sentlen, paired_src_feat = corpus.paired_train[trainperm[batch_idx]] 219 | paired_mask, _ = make_masks(paired_src_feat, args.pad_idx) 220 | sentence_mask = make_sent_msk(sentence, args.pad_idx) 221 | 222 | # if not args.add_genbow and not args.add_vbow: 223 | # paired_genbow, paired_vbow = None, None 224 | # else: 225 | # paired_genbow, paired_vbow = make_bow(sentence, paired_sentlen, corpus.value_dict, corpus.gen_vocab) 226 | # if args.cuda: 227 | # paired_genbow, paired_vbow = paired_genbow.cuda(), paired_vbow.cuda() 228 | if args.cuda: 229 | paired_src_feat, paired_mask = paired_src_feat.cuda(), paired_mask.cuda() 230 | sentence, paired_skeleton_sent, sentence_mask = sentence.cuda(), paired_skeleton_sent.cuda(), sentence_mask.cuda() 231 | 232 | paired_table_enc = net.encode_table(Variable(paired_src_feat), Variable(paired_mask)) 233 | 234 | tot_steps += 1 235 | if hasattr(net, 'set_kl_weight'): 236 | net.set_kl_weight(tot_steps) 237 | 238 | paired_loss_dict = net.decode_pair(paired_table_enc, Variable(sentence), 239 | Variable(paired_skeleton_sent), Variable(paired_mask), 240 | Variable(sentence_mask), valid=False) 241 | loss = paired_loss_dict['pair_loss'] 242 | loss.backward() 243 | torch.nn.utils.clip_grad_norm_(net.parameters(), args.clip) 244 | if optalg is not None: 245 | optalg.step() 246 | else: 247 | for p in net.parameters(): 248 | if p.grad is not None: 249 | p.data.add_(-args.lr, p.grad.data) 250 | 251 | for l, v in paired_loss_dict.items(): 252 | loss_record[l] += v.item() 253 | nsents += 1 254 | 255 | if (batch_idx+1) % args.log_interval == 0: 256 | log_str = "batch %d/%d " % (batch_idx+1, len(corpus.paired_train)) 257 | for k, v in paired_loss_dict.items(): 258 | log_str += ("| train %s %g ") % (k, v.item()) 259 | # writer.add_scalar('Train/{}'.format(k), v.item(), tot_steps) 260 | print(log_str) 261 | sys.stdout.flush() 262 | 263 | 264 | del sentence, paired_src_feat, paired_mask, paired_sentlen, paired_loss_dict 265 | 266 | log_str = "paired epoch %d " % (epoch) 267 | for k, v in loss_record.items(): 268 | log_str += ("| train %s %g ") % (k, v / nsents) 269 | print(log_str) 270 | 271 | def train_raw(epoch): 272 | net.train() 273 | global tot_steps 274 | tot_loss = 0.0 275 | loss_record = defaultdict(float) 276 | nsents = 0 277 | trainperm = torch.randperm(len(corpus.raw_train)) 278 | 279 | for batch_idx in 
range(len(corpus.raw_train)): 280 | net.zero_grad() 281 | raw_sentence, raw_sentlen = corpus.raw_train[trainperm[batch_idx]] 282 | sentence_mask = make_sent_msk(raw_sentence, args.pad_idx) 283 | # if not args.add_genbow and not args.add_vbow: 284 | # raw_genbow, raw_vbow = None, None 285 | # else: 286 | # raw_genbow, raw_vbow = make_bow(raw_sentence, raw_sentlen, corpus.value_dict, corpus.gen_vocab) 287 | # if args.cuda: 288 | # raw_genbow, raw_vbow = raw_genbow.cuda(), raw_vbow.cuda() 289 | 290 | if args.cuda: 291 | raw_sentence, raw_sentlen = raw_sentence.cuda(), raw_sentlen.cuda() 292 | sentence_mask = sentence_mask.cuda() 293 | 294 | raw_loss_dict = net.decode_raw(Variable(raw_sentence), Variable(sentence_mask), valid=False) 295 | # raw_loss_dict = net.decode_raw(Variable(raw_sentence), Variable(raw_genbow), Variable(raw_vbow), Variable(sentence_mask), valid=False) 296 | 297 | loss = raw_loss_dict['raw_loss'] 298 | loss.backward() 299 | torch.nn.utils.clip_grad_norm_(net.parameters(), args.clip) 300 | if optalg is not None: 301 | optalg.step() 302 | else: 303 | for p in net.parameters(): 304 | if p.grad is not None: 305 | p.data.add_(-args.lr, p.grad.data) 306 | for l, v in raw_loss_dict.items(): 307 | loss_record[l] += v.item() 308 | nsents += 1 309 | 310 | if (batch_idx+1) % args.log_interval == 0: 311 | log_str = "batch %d/%d " % (batch_idx+1, len(corpus.raw_train)) 312 | for k, v in raw_loss_dict.items(): 313 | log_str += ("| train %s %g ") % (k, v.item()) 314 | # writer.add_scalar('Train/{}'.format(k), v.item(), tot_steps) 315 | print(log_str) 316 | 317 | del raw_sentence, raw_sentlen, raw_loss_dict 318 | 319 | log_str = "raw epoch %d " % (epoch) 320 | for k, v in loss_record.items(): 321 | log_str += ("| train %s %g ") % (k, v / nsents) 322 | print(log_str) 323 | sys.stdout.flush() 324 | 325 | def train_together(epoch): 326 | net.train() 327 | global tot_steps 328 | nsents = 0 329 | loss_record = defaultdict(float) 330 | train_size = min(len(corpus.paired_train), len(corpus.raw_train)) 331 | 332 | paired_perm = np.random.choice(len(corpus.paired_train), size=train_size, replace=False) 333 | raw_perm = np.random.choice(len(corpus.raw_train), size=train_size, replace=False) 334 | 335 | for batch_idx in range(train_size): 336 | net.zero_grad() 337 | # load pair data 338 | pair_sentence, paired_skeleton_sent, paired_sentlen, paired_src_feat = corpus.paired_train[paired_perm[batch_idx]] 339 | paired_mask, _ = make_masks(paired_src_feat, args.pad_idx) 340 | paired_sentence_mask = make_sent_msk(pair_sentence, args.pad_idx) 341 | # if not args.add_genbow and not args.add_vbow: 342 | # paired_genbow, paired_vbow = None, None 343 | # else: 344 | # paired_genbow, paired_vbow = make_bow(pair_sentence, paired_sentlen, corpus.value_dict, corpus.gen_vocab) 345 | # if args.cuda: 346 | # paired_genbow, paired_vbow = paired_genbow.cuda(), paired_vbow.cuda() 347 | 348 | # load raw data 349 | raw_sentence, raw_sentlen = corpus.raw_train[raw_perm[batch_idx]] 350 | raw_sentence_mask = make_sent_msk(raw_sentence, args.pad_idx) 351 | # if not args.add_genbow and not args.add_vbow: 352 | # raw_genbow, raw_vbow = None, None 353 | # else: 354 | # raw_genbow, raw_vbow = make_bow(raw_sentence, raw_sentlen, corpus.value_dict, corpus.gen_vocab) 355 | # if args.cuda: 356 | # raw_genbow, raw_vbow = raw_genbow.cuda(), raw_vbow.cuda() 357 | 358 | if args.cuda: 359 | paired_src_feat, paired_mask = paired_src_feat.cuda(), paired_mask.cuda() 360 | pair_sentence, paired_skeleton_sent = pair_sentence.cuda(), 
paired_skeleton_sent.cuda() 361 | paired_sentence_mask = paired_sentence_mask.cuda() 362 | raw_sentence, raw_sentlen = raw_sentence.cuda(), raw_sentlen.cuda() 363 | raw_sentence_mask = raw_sentence_mask.cuda() 364 | 365 | tot_steps += 1 366 | if hasattr(net, 'set_kl_weight'): 367 | net.set_kl_weight(tot_steps) 368 | 369 | total_loss, all_loss_dict = net.forward(Variable(paired_src_feat), Variable(paired_mask), Variable(pair_sentence), 370 | Variable(paired_skeleton_sent), Variable(paired_sentence_mask), 371 | Variable(raw_sentence), Variable(raw_sentence_mask), valid=False) 372 | 373 | # total_loss, all_loss_dict = net.forward(Variable(paired_src_feat), Variable(paired_mask), Variable(pair_sentence), 374 | # Variable(paired_skeleton_sent), Variable(paired_genbow), Variable(paired_vbow), Variable(paired_sentence_mask), 375 | # Variable(raw_sentence), Variable(raw_genbow), Variable(raw_vbow), Variable(raw_sentence_mask), 376 | # valid=False) 377 | 378 | total_loss.backward() 379 | torch.nn.utils.clip_grad_norm_(net.parameters(), args.clip) 380 | if optalg is not None: 381 | optalg.step() 382 | else: 383 | for p in net.parameters(): 384 | if p.grad is not None: 385 | p.data.add_(-args.lr, p.grad.data) 386 | 387 | for l, v in all_loss_dict.items(): 388 | loss_record[l] += v.item() 389 | nsents += 1 390 | 391 | if (batch_idx+1) % args.log_interval == 0: 392 | log_str = "batch %d/%d " % (batch_idx+1, train_size) 393 | for k, v in all_loss_dict.items(): 394 | log_str += ("| train %s %g ") % (k, v.item()) 395 | # writer.add_scalar('Train/{}'.format(k), v.item(), tot_steps) 396 | print(log_str) 397 | 398 | del pair_sentence, paired_src_feat, paired_sentlen, paired_mask, paired_skeleton_sent 399 | del raw_sentlen, raw_sentence, all_loss_dict 400 | 401 | log_str = "together epoch %d " % (epoch) 402 | for k, v in loss_record.items(): 403 | log_str += ("| train %s %g ") % (k, v / nsents) 404 | print(log_str) 405 | 406 | def valid(epoch): 407 | net.eval() 408 | pairnsents = 0 409 | rawnsents = 0 410 | loss_record = defaultdict(float) 411 | for i in range(len(corpus.paired_valid)): 412 | paired_sentence, paired_skeleton_sent, paired_sentlen, paired_src_feat = corpus.paired_valid[i] 413 | paired_mask, _ = make_masks(paired_src_feat, args.pad_idx) 414 | sentence_mask = make_sent_msk(paired_sentence, args.pad_idx) 415 | # if not args.add_genbow and not args.add_vbow: 416 | # paired_genbow, paired_vbow = None, None 417 | # else: 418 | # paired_genbow, paired_vbow = make_bow(paired_sentence, paired_sentlen, corpus.value_dict, corpus.gen_vocab) 419 | # if args.cuda: 420 | # paired_genbow, paired_vbow = paired_genbow.cuda(), paired_vbow.cuda() 421 | 422 | if args.cuda: 423 | paired_src_feat, paired_mask = paired_src_feat.cuda(), paired_mask.cuda() 424 | paired_sentence, paired_skeleton_sent = paired_sentence.cuda(), paired_skeleton_sent.cuda() 425 | sentence_mask = sentence_mask.cuda() 426 | 427 | paired_table_enc = net.encode_table(paired_src_feat, paired_mask) 428 | paired_loss_dict = net.decode_pair(paired_table_enc, paired_sentence, paired_skeleton_sent, paired_mask, 429 | sentence_mask, valid=True) 430 | 431 | for l, v in paired_loss_dict.items(): 432 | loss_record[l] += v.item() 433 | pairnsents += 1 434 | 435 | del paired_sentence, paired_src_feat, paired_mask, paired_skeleton_sent, paired_loss_dict 436 | 437 | if corpus.raw_valid is None: 438 | log_str = "No raw data for valid; epoch %d " % (epoch) 439 | for k, v in loss_record.items(): 440 | log_str += ("| valid %s %g ") % (k, v / pairnsents) 441 | 
print(log_str) 442 | sys.stdout.flush() 443 | return loss_record['pair_loss'] / pairnsents 444 | 445 | else: 446 | for i in range(len(corpus.raw_valid)): 447 | raw_sentence, raw_sentlen = corpus.raw_valid[i] 448 | sentence_mask = make_sent_msk(raw_sentence, args.pad_idx) 449 | # if not args.add_genbow and not args.add_vbow: 450 | # raw_genbow, raw_vbow = None, None 451 | # else: 452 | # raw_genbow, raw_vbow = make_bow(raw_sentence, raw_sentlen, corpus.value_dict, corpus.gen_vocab) 453 | # if args.cuda: 454 | # raw_genbow, raw_vbow = raw_genbow.cuda(), raw_vbow.cuda() 455 | if args.cuda: 456 | raw_sentence, raw_sentlen = raw_sentence.cuda(), raw_sentlen.cuda() 457 | sentence_mask = sentence_mask.cuda() 458 | 459 | raw_loss_dict = net.decode_raw(raw_sentence, sentence_mask, valid=True) 460 | for l, v in raw_loss_dict.items(): 461 | loss_record[l] += v.item() 462 | rawnsents += 1 463 | 464 | log_str = "epoch %d " % (epoch) 465 | for k, v in loss_record.items(): 466 | if "pair" in k: 467 | log_str += ("| valid %s %g ") % (k, v / pairnsents) 468 | # writer.add_scalar('Valid/{}'.format(k), v / pairnsents, epoch) 469 | if "raw" in k: 470 | log_str += ("| valid %s %g ") % (k, v / rawnsents) 471 | # writer.add_scalar('Valid/{}'.format(k), v / rawnsents, epoch) 472 | print(log_str) 473 | valid_loss = loss_record["pair_loss"] / pairnsents + loss_record["raw_loss"] / rawnsents 474 | return valid_loss 475 | 476 | def generate_samples(num=3): 477 | net.eval() 478 | valid_candidates = random.sample(range(len(corpus.paired_valid)), num) 479 | for i in valid_candidates: 480 | paired_sentence, paired_skeleton_sent, _, paired_src_feat = corpus.paired_valid[i] 481 | src_str = "" 482 | for keyid, widx in zip(paired_src_feat[0, :, 0], paired_src_feat[0, :, 2]): 483 | src_str += corpus.dictionary.idx2word[keyid] + "_" + corpus.dictionary.idx2word[widx] + "|" 484 | print("Source: ", src_str) 485 | 486 | ref = [] 487 | for t in range(paired_sentence.size(0)): 488 | word = corpus.dictionary.idx2word[paired_sentence[t, 0].item()] 489 | if word != '': 490 | ref.append(word) 491 | print("Reference: {}".format(" ".join(ref))) 492 | 493 | paired_mask, _ = make_masks(paired_src_feat, args.pad_idx) 494 | if args.cuda: 495 | paired_src_feat, paired_mask = paired_src_feat.cuda(), paired_mask.cuda() 496 | sentence_ids = net.predict(paired_src_feat, paired_mask) 497 | sentence_ids = sentence_ids.data.cpu() 498 | sent_words = [] 499 | for t, wid in enumerate(sentence_ids[:, 0]): 500 | word = corpus.dictionary.idx2word[wid] 501 | sent_words.append(word) 502 | print("Predict: {}".format(" ".join(str(w) for w in sent_words))) 503 | 504 | 505 | prev_valloss, best_valloss = float("inf"), float("inf") 506 | tot_steps = 0 507 | 508 | total_epoch = 1 509 | for epoch in range(1, args.paired_epochs + 1): 510 | train_pair(epoch) 511 | valloss = valid(epoch) 512 | generate_samples() # show some generation samples 513 | if valloss < best_valloss and total_epoch > args.warm_up_epoch: 514 | print("save best valid loss = {:.4f}".format(valloss)) 515 | best_valloss = valloss 516 | if len(args.save) > 0: 517 | print("saving to {}...".format(args.save)) 518 | state = {"opt": args, "state_dict": net.state_dict(), "lr": args.lr, 519 | "all_dict": corpus.dictionary, "gen_dict": corpus.gen_vocab, 520 | "value_dict": corpus.value_dict} 521 | torch.save(state, args.save) 522 | prev_valloss = valloss 523 | cuda2cpu() 524 | total_epoch += 1 525 | 526 | for epoch in range(1, args.raw_epochs + 1): 527 | train_raw(epoch) 528 | valloss = valid(epoch) 529 | 
generate_samples() 530 | if valloss < best_valloss and total_epoch > args.warm_up_epoch: 531 | print("save best valid loss = {:.4f}".format(valloss)) 532 | best_valloss = valloss 533 | if len(args.save) > 0: 534 | print("saving to {}...".format(args.save)) 535 | state = {"opt": args, "state_dict": net.state_dict(), "lr": args.lr, 536 | "all_dict": corpus.dictionary, "gen_dict": corpus.gen_vocab, 537 | "value_dict": corpus.value_dict} 538 | torch.save(state, args.save) 539 | prev_valloss = valloss 540 | cuda2cpu() 541 | total_epoch += 1 542 | 543 | for epoch in range(1, args.epochs + 1): 544 | train_together(epoch) 545 | valloss = valid(epoch) 546 | generate_samples() 547 | if valloss < best_valloss and total_epoch > args.warm_up_epoch: 548 | print("save best valid loss = {:.4f}".format(valloss)) 549 | best_valloss = valloss 550 | if len(args.save) > 0: 551 | print("saving to {}...".format(args.save)) 552 | state = {"opt": args, "state_dict": net.state_dict(), "lr": args.lr, 553 | "all_dict": corpus.dictionary, "gen_dict": corpus.gen_vocab, 554 | "value_dict": corpus.value_dict} 555 | torch.save(state, args.save) 556 | prev_valloss = valloss 557 | cuda2cpu() 558 | total_epoch += 1 559 | -------------------------------------------------------------------------------- /models/variational_template_machine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from models import feedback 6 | from models import attention 7 | import numpy as np 8 | 9 | 10 | def logsumexp1(X): 11 | """ X - B x K 12 | returns: B x 1 13 | """ 14 | maxes, _ = torch.max(X, 1, True) 15 | lse = maxes + torch.log(torch.sum(torch.exp(X - maxes.expand_as(X)), 1, True)) 16 | return lse 17 | 18 | 19 | def gaussian_kld(recog_mu, recog_logvar, prior_mu, prior_logvar): 20 | kld = -0.5 * torch.sum(1 + (recog_logvar - prior_logvar) - 21 | torch.div(torch.pow(prior_mu - recog_mu, 2), torch.exp(prior_logvar)) - 22 | torch.div(torch.exp(recog_logvar), torch.exp(prior_logvar)), 1) 23 | return kld 24 | 25 | 26 | def norm_log_liklihood(x, mu, logvar): 27 | return -0.5*torch.sum(logvar + np.log(2*np.pi) + torch.div(torch.pow((x-mu), 2), torch.exp(logvar)), 1) 28 | 29 | 30 | def sample_from_gaussian(mu, logvar, seed=None): 31 | if seed is None: 32 | epsilon = logvar.new_empty(logvar.size()).normal_() # train 33 | else: # during generation set different random seed 34 | if not logvar.is_cuda: # if it is not on gpu 35 | epsilon = logvar.new_empty(logvar.size()).normal_(generator=torch.manual_seed(seed)) 36 | else: # if tensor is on gpu 37 | epsilon = logvar.new_empty(logvar.size()).normal_(generator=torch.cuda.manual_seed_all(seed)) 38 | std = torch.exp(0.5 * logvar) 39 | z = mu + std * epsilon 40 | return z 41 | 42 | 43 | class VariationalTemplateMachine(nn.Module): 44 | def __init__(self, corpus, opt): 45 | super(VariationalTemplateMachine, self).__init__() 46 | self.use_cuda = opt.cuda 47 | self.corpus = corpus 48 | self.unk_idx = corpus.dictionary.word2idx[corpus.dictionary.unk_word] 49 | self.pad_idx = opt.pad_idx 50 | self.ent_idx = corpus.dictionary.word2idx[""] 51 | self.vocab_size = len(corpus.dictionary) 52 | self.gen_vocab_size = len(corpus.gen_vocab) # skelton vocab 53 | self.drop_emb = opt.drop_emb 54 | self.value_dict_size = len(corpus.value_dict) 55 | # self.w2i = corpus.dictionary.word2idx 56 | 57 | self.max_seqlen = opt.max_seqlen 58 | self.initrange = opt.initrange 59 | 
self.inp_feats = 3 # key, pos, value 60 | self.emb_size = opt.emb_size 61 | self.drop = nn.Dropout(opt.dropout) 62 | 63 | # encode table 64 | self.table_hid_size = opt.table_hid_size 65 | self.word_emb = nn.Embedding(self.vocab_size, self.emb_size, padding_idx=self.pad_idx) 66 | self.table_hidden_out = nn.Linear(self.emb_size * self.inp_feats, self.table_hid_size, bias=True) 67 | self.pool_type = opt.pool_type 68 | 69 | self.zeros = torch.zeros(1, 1) 70 | if opt.cuda: 71 | self.zeros = self.zeros.cuda() 72 | 73 | # encoder sentence x 74 | # self.birnn = opt.birnn 75 | self.hid_size = opt.hid_size 76 | self.layers = opt.layers 77 | self.rnn_encode = nn.GRU(self.emb_size, self.hid_size, self.layers, dropout=opt.dropout, bidirectional=True) 78 | self.sent_represent = opt.sent_represent 79 | 80 | # prior net: p(z) follows N(0,I) 81 | self.z_latent_size = opt.z_latent_size 82 | self.c_latent_size = self.table_hid_size # latent size of c should be equal to table hidden size 83 | 84 | # posterior net(inference): q(z|x) 85 | inf_input_size = opt.layers * opt.hid_size * 2 # b x (2*emb+layer*hid) 86 | self.z_posterior = nn.Linear(inf_input_size, self.z_latent_size * 2) # mu_z & logvar_z 87 | self.c_posterior = nn.Linear(inf_input_size, self.c_latent_size * 2) # mu_c & logvat_c 88 | 89 | # generator p(X|z,h_tab) 90 | self.z2dec = nn.Linear(self.z_latent_size + self.table_hid_size, self.hid_size) 91 | 92 | self.word_rnn = nn.LSTM(self.z_latent_size + self.emb_size + self.table_hid_size, 93 | self.hid_size, self.layers, dropout=opt.dropout) 94 | self.use_dec_attention = opt.dec_attention 95 | if self.use_dec_attention: 96 | self.attn_table_hidden = attention.GeneralAttention(query_dim=self.hid_size, key_dim=self.table_hid_size) 97 | self.generator_out = nn.Linear(opt.hid_size + self.table_hid_size, self.vocab_size + 1) 98 | else: 99 | self.generator_out = nn.Linear(opt.hid_size, self.vocab_size + 1) 100 | 101 | self.set_decode_method(opt) 102 | 103 | # setup loss weights etc 104 | # self.add_vbow = opt.add_vbow 105 | # if self.add_vbow: 106 | # self.value_bow_z_proj = nn.Linear(self.z_latent_size, self.value_dict_size) # max H 107 | # self.value_bow_c_proj = nn.Linear(self.c_latent_size, self.value_dict_size) # min CE 108 | # self.vbow_weight_z = opt.vbow_weight_z 109 | # self.vbow_weight_c = opt.vbow_weight_c 110 | 111 | self.add_mi_z = opt.add_mi_z 112 | self.mi_z_weight = opt.mi_z_weight 113 | self.add_mi_c = opt.add_mi_c 114 | self.mi_c_weight = opt.mi_c_weight 115 | 116 | # self.add_genbow = opt.add_genbow 117 | # if self.add_genbow: 118 | # self.gen_bow_z_proj = nn.Linear(self.z_latent_size, self.gen_vocab_size) # min CE 119 | # self.gen_bow_c_proj = nn.Linear(self.c_latent_size, self.gen_vocab_size) # max H 120 | # self.genbow_weight_z = opt.genbow_weight_z 121 | # self.genbow_weight_c = opt.genbow_weight_c 122 | 123 | self.add_preserving_template_loss = opt.add_preserving_template_loss 124 | if opt.add_preserving_template_loss: 125 | self.pt_weight = opt.pt_weight 126 | self.template_rnn = nn.LSTM(self.z_latent_size + self.emb_size, self.hid_size, self.layers, dropout=opt.dropout) 127 | self.template_out = nn.Linear(self.hid_size, self.vocab_size+1) 128 | 129 | # self.add_skeleton = opt.add_skeleton 130 | # if self.add_skeleton: 131 | # self.sk_weight = opt.sk_weight 132 | # # decoder for x_skeleton 133 | # self.sk_rnn = nn.LSTM(self.z_latent_size + self.emb_size, self.hid_size, self.layers, dropout=opt.dropout) 134 | # self.sk_generator_out = nn.Linear(self.hid_size, self.vocab_size + 1) 135 | 
136 | # setup annealing tricks for latent variable 137 | self.anneal_function_z = opt.anneal_function_z 138 | self.anneal_k_z = opt.anneal_k_z 139 | self.anneal_x0_z = opt.anneal_x0_z 140 | self.kl_weight_z = 0.5 141 | 142 | self.anneal_function_c = opt.anneal_function_c 143 | self.anneal_k_c = opt.anneal_k_c 144 | self.anneal_x0_c = opt.anneal_x0_c 145 | self.kl_weight_c = 0.5 146 | 147 | self.add_preserving_content_loss = opt.add_preserving_content_loss 148 | self.pc_weight = opt.pc_weight 149 | self.rawloss_weight = opt.rawloss_weight 150 | 151 | self.init_weight() 152 | 153 | def init_weight(self): 154 | initrange = self.initrange 155 | self.word_emb.weight.data.uniform_(-initrange, initrange) 156 | self.word_emb.weight.data[self.pad_idx].zero_() 157 | self.word_emb.weight.data[self.corpus.dictionary.word2idx[""]].zero_() 158 | self.word_emb.weight.data[self.corpus.dictionary.word2idx[""]].zero_() 159 | self.word_emb.weight.data[self.corpus.dictionary.word2idx['']].zero_() 160 | 161 | # init generator rnn 162 | for thing in self.word_rnn.parameters(): 163 | thing.data.uniform_(-initrange, initrange) 164 | for thing in self.rnn_encode.parameters(): 165 | thing.data.uniform_(-initrange, initrange) 166 | 167 | def set_kl_weight(self, step): 168 | # set kl weight for z 169 | if self.anneal_function_z == 'logistic': 170 | self.kl_weight_z = float(1 / (1 + np.exp(-self.anneal_k_z * (step - self.anneal_x0_z)))) 171 | elif self.anneal_function_z == 'linear': 172 | self.kl_weight_z = min(1, step / self.anneal_x0_z) 173 | elif self.anneal_function_z == 'const': 174 | self.kl_weight_z = self.anneal_k_z 175 | else: 176 | self.kl_weight_z = 0.5 177 | 178 | # set kl weight of c 179 | if self.anneal_function_c == 'logistic': 180 | self.kl_weight_c = float(1 / (1 + np.exp(-self.anneal_k_c * (step - self.anneal_x0_c)))) 181 | elif self.anneal_function_c == 'linear': 182 | self.kl_weight_c = min(1, step / self.anneal_x0_c) 183 | elif self.anneal_function_z == 'const': 184 | self.kl_weight_c = self.anneal_k_c 185 | else: 186 | self.kl_weight_c = 0.5 187 | 188 | def set_decode_method(self, opt): 189 | self.beamsz = 1 190 | self.decode_method = opt.decode_method.lower() 191 | if "beam_search" in self.decode_method: 192 | self.beamsz = opt.beamsz 193 | if self.beamsz == 1: 194 | self.feedback_x = feedback.GreedyFeedBack(self.word_emb, self.unk_idx) 195 | else: 196 | self.feedback_x = feedback.BeamFeedBack(self.word_emb, self.beamsz, self.unk_idx) 197 | elif self.decode_method == "temp_sample": 198 | self.sample_temperature = opt.sample_temperature 199 | if self.sample_temperature == 1: 200 | self.feedback_x = feedback.SampleFeedBack(self.word_emb, self.unk_idx) 201 | elif self.sample_temperature < 0.001: # if too small, then it is a greedy method 202 | self.feedback_x = feedback.GreedyFeedBack(self.word_emb, self.unk_idx) 203 | else: 204 | self.feedback_x = feedback.SampleFeedBackWithTemperature(self.word_emb, self.unk_idx, 205 | temperature=self.sample_temperature) 206 | elif self.decode_method == "topk_sample": 207 | self.topk = opt.topk 208 | if self.topk == 1: 209 | self.feedback_x = feedback.GreedyFeedBack(self.word_emb, self.unk_idx) 210 | else: 211 | self.feedback_x = feedback.TopkSampleFeedBack(self.word_emb, self.unk_idx, self.topk) 212 | 213 | elif self.decode_method == "nucleus_sample": 214 | self.topp = opt.topp 215 | if self.topp == 1: 216 | self.feedback_x = feedback.SampleFeedBack(self.word_emb, self.unk_idx) 217 | else: 218 | self.feedback_x = feedback.NucleusSampleFeedBack(self.word_emb, 
self.unk_idx, self.topp) 219 | 220 | else: 221 | self.feedback_x = feedback.GreedyFeedBack(self.word_emb, self.unk_idx) 222 | 223 | def _get_posterior_input(self, y_out, h_yt): 224 | bsz = h_yt.size(1) 225 | if self.sent_represent == "last_hid": 226 | h_yt = torch.transpose(h_yt, 0, 1).contiguous().view(bsz, -1) # b x layer*hid 227 | posterior_input = h_yt 228 | elif self.sent_represent == "seq_avg": 229 | y_out = y_out.mean(0) # b x layer*hid 230 | posterior_input = y_out 231 | else: 232 | raise NotImplementedError 233 | return posterior_input 234 | 235 | def encode_table(self, src, fieldmask): 236 | """ 237 | :param src: b x nfield x 3 238 | :param fieldmask: b x nfield 0/-inf 239 | :return: key_emb: b x nfield x 2 x emb 240 | masked_key_emb: b x nfield x 2 x emb 241 | value_emb: b x nfield x emb 242 | h_table_field: b x nfield x hidden 243 | h_table: b x hidden 244 | """ 245 | bsz, nfields, nfeats = src.size() 246 | emb_size = self.word_emb.embedding_dim 247 | # src_key, src_value = src[:, :, :2], src[:, :, 2:] # src_key: b x nfield x 2, src_value: b x nfield x 1 248 | 249 | embs = self.word_emb(src.view(-1, nfeats)) # bsz*nfields x nfeats x emb_size 250 | 251 | if self.drop_emb: 252 | embs = self.drop(embs) # b*nfield x 3 x emb 253 | 254 | key_emb = embs[:, :2, :].view(bsz, nfields, -1, emb_size) # b x nfield x 2 x emb 255 | value_emb = embs[:, 2:, :].view(bsz, nfields, -1, emb_size).squeeze(2) # b x nfield x emb 256 | 257 | h_table_field = F.tanh(self.table_hidden_out(embs.view(-1, nfeats*emb_size))).view(bsz, nfields, -1) # b x nfield x hidden 258 | 259 | fieldmask[fieldmask == 0] = 1 260 | fieldmask = fieldmask.unsqueeze(2).expand(bsz, nfields, emb_size).unsqueeze(2).expand(bsz, nfields, 2, emb_size) # b x nfield x 2 x emb 261 | masked_key_emb = key_emb * fieldmask # b x nfield x 2 x emb 262 | 263 | if self.pool_type == "max": 264 | masked_key_emb = masked_key_emb.view(bsz, nfields, -1).transpose(1,2) # b x 2*emb x nfield 265 | masked_key_emb = F.max_pool1d(masked_key_emb, nfields).squeeze(2) # b x 2*emb 266 | 267 | h_table = F.max_pool1d(h_table_field.transpose(1, 2), nfields).squeeze(2) # b x emb 268 | 269 | elif self.pool_type == "mean": 270 | masked_key_emb = masked_key_emb.mean(1).view(bsz, -1) # b x 2*dim 271 | h_table = h_table_field.mean(1).view(bsz, -1) # b x emb 272 | else: 273 | raise NotImplementedError 274 | # masked_key_emb = masked_emb[:, :2, :].view(bsz, -1) # b*nfield x 2 x emb -> b x 2*emb 275 | 276 | return key_emb, masked_key_emb, value_emb, h_table_field, h_table 277 | 278 | def decode_pair(self, table_encodes, sentence, template_sent, fieldmask, sentence_mask, valid=False): 279 | """ training process 280 | Args: 281 | table_encodes: outputs of table encoders - key_emb, masked_key_emb, value_emb, h_table 282 | key_emb: b x nfield x 2 x emb 283 | masked_key_emb: b x 2*emb 284 | value_emb: b x nfield x emb 285 | h_table_field: b x nfield x tabhid 286 | h_table: b x tabhid 287 | sentence: seq x b 288 | template_sent: seq x b (mask for values and pad to the same length) 289 | fieldmask: b x nfield 290 | sentence_mask: seq x b 291 | """ 292 | key_emb, masked_key_emb, value_emb, h_table_field, h_table = table_encodes 293 | 294 | bsz, nfield, _ = value_emb.size() 295 | seqlen = sentence.size(0) 296 | 297 | sent_emb = self.word_emb(sentence) # seq x b x emb 298 | if self.drop_emb: 299 | sent_emb = self.drop(sent_emb) # seq x b x emb 300 | emb_size = sent_emb.size(2) 301 | 302 | # posterior q(z|x) 303 | h_y0 = torch.zeros(self.layers*2, bsz, self.hid_size).contiguous() # 
default bi-rnn, so, 2 layers 304 | if self.use_cuda: 305 | h_y0 = h_y0.cuda() 306 | 307 | y_out, h_yt = self.rnn_encode(sent_emb, h_y0) # y_out: seq x b x layer*hid, h_yt: layer x b x hid 308 | 309 | posterior_input = self._get_posterior_input(y_out, h_yt) # b x layer*hid 310 | posterior_out_z = self.z_posterior(posterior_input) # b x latent_z*2 311 | mu_post_z, logvar_post_z = torch.chunk(posterior_out_z, 2, 1) # both has size b x latent_z 312 | # sample z from the posterior 313 | z_sample = sample_from_gaussian(mu_post_z, logvar_post_z) # b x latent_z 314 | 315 | # prior of z: p(z) = N(0,I) 316 | mu_prior_z = self.zeros.expand(z_sample.size()) 317 | logvar_prior_z = self.zeros.expand(z_sample.size()) 318 | # posterior of c q(c|x) 319 | posterior_out_c = self.c_posterior(posterior_input) 320 | mu_post_c, logvar_post_c = torch.chunk(posterior_out_c, 2, 1) 321 | c_sample = sample_from_gaussian(mu_post_c, logvar_post_c) 322 | 323 | # prior of c p(c) = N(0,I) 324 | mu_prior_c = self.zeros.expand(c_sample.size()) 325 | logvar_prior_c = self.zeros.expand(c_sample.size()) 326 | # generate p(x|z,h_tab) for paired data 327 | ar_embs = torch.cat([self.word_emb.weight[2].view(1, 1, emb_size).expand(1, bsz, emb_size), 328 | sent_emb[:-1, :, :]], 0) # seqlen x bsz x emb_size 329 | 330 | ar_embs = torch.cat([ar_embs, z_sample.expand(seqlen, bsz, -1), h_table.expand(seqlen, bsz, -1)], dim=-1) # seq x b x (emb_size + latent_z + tabhid) 331 | states, (h, c) = self.word_rnn(ar_embs) # (h0, c0) states: seq x b x hid 332 | 333 | if self.use_dec_attention: 334 | assert fieldmask is not None 335 | attn_score_dec, attn_ctx_dec, attn_logits_dec = self.attn_table_hidden.forward(states, h_table_field, h_table_field, fieldmask, return_logits=True) 336 | # attn_score_dec, attn_logits_dec: b x seq x nfield 337 | # attn_ctx_dec: b x seq x tabhid 338 | dec_outs = self.generator_out(torch.cat([states, attn_ctx_dec.transpose(0, 1)], dim=-1)) # seq x b x vocab 339 | else: 340 | dec_outs = self.generator_out(states) 341 | 342 | seq_prob = F.softmax(dec_outs, dim=-1) # seq x b x vocab 343 | seq_prob = torch.cat([seq_prob, Variable(self.zeros.expand(seq_prob.size(0), seq_prob.size(1), 1))], dim=-1) 344 | 345 | crossEntropy = -torch.log(torch.gather(seq_prob, -1, sentence.view(seqlen, bsz, 1)) + 1e-15) # seqlen x b x 1 346 | # print(crossEntropy.squeeze(2).mean()) 347 | # sentence_mask = torch.ByteTensor(sentence.cpu() != self.pad_idx).transpose(0, 1) 348 | # if self.use_cuda: 349 | # sentence_mask = sentence_mask.cuda() 350 | nll_loss = crossEntropy.masked_select(sentence_mask).mean() 351 | 352 | # KL between prior and posterior KL(q(z|x)||p(z)) and kl(q(c|x)||p(c)) 353 | KL_z = torch.mean(gaussian_kld(mu_post_z, logvar_post_z, mu_prior_z, logvar_prior_z), dim=0) # dim=b -> mean 354 | KL_c = torch.mean(gaussian_kld(mu_post_c, logvar_post_c, mu_prior_c, logvar_prior_c), dim=0) # dim=b -> mean 355 | if not valid: 356 | total_loss = nll_loss + self.kl_weight_z * KL_z 357 | else: 358 | total_loss = nll_loss + KL_z 359 | 360 | loss_dict = {"pair_nll": nll_loss, "pair_KLz": KL_z, "kl_weight_z": torch.full((1,), self.kl_weight_z)} 361 | 362 | if self.add_preserving_content_loss: 363 | mse_loss = F.mse_loss(c_sample, h_table) # L_committment 364 | if not valid: 365 | total_loss += self.pc_weight * mse_loss + self.kl_weight_c * KL_c 366 | loss_dict['pair_mse'] = mse_loss 367 | loss_dict['pair_KLc'] = KL_c 368 | 369 | if self.add_preserving_template_loss: 370 | template_emb = self.word_emb(template_sent) 371 | if self.drop_emb: 372 | 
template_emb = self.drop(template_emb) 373 | ar_embs = torch.cat([self.word_emb.weight[2].view(1, 1, emb_size).expand(1, bsz, emb_size), 374 | template_emb[:-1 ,: ,:]], 0) # temp_seqlen x bsz x emb_size, init 375 | temp_seqlen = template_sent.size(0) 376 | ar_embs = torch.cat([ar_embs, z_sample.expand(temp_seqlen, bsz, -1)], dim=-1) 377 | temp_states, (h, c) = self.template_rnn(ar_embs) 378 | temp_outs = self.template_out(temp_states) 379 | temp_seq_prob = F.softmax(temp_outs, dim=-1) 380 | 381 | crossEntropy = -torch.log( 382 | torch.gather(temp_seq_prob, -1, template_sent.view(temp_seqlen, bsz, 1)) + 1e-15) # seqlen x b x 1 383 | temp_mask = torch.ByteTensor(template_sent.cpu() != self.pad_idx).transpose(0, 1) 384 | if self.use_cuda: 385 | temp_mask = temp_mask.cuda() 386 | pt_loss = crossEntropy.masked_select(temp_mask).mean() 387 | 388 | if not valid: 389 | total_loss += self.pt_weight * pt_loss 390 | loss_dict["pt_loss"] = pt_loss 391 | 392 | # if self.add_skeleton: 393 | # sk_sent_emb = self.word_emb(sent_skelton) # skseq x b x emb 394 | # if self.drop_emb: 395 | # sk_sent_emb = self.drop(sk_sent_emb) 396 | # 397 | # ar_embs = torch.cat([self.word_emb.weight[2].view(1, 1, emb_size).expand(1, bsz, emb_size), sk_sent_emb[:-1 ,: ,:]], 0) # skseqlen x bsz x emb_size, init 398 | # sk_seqlen = sent_skelton.size(0) 399 | # ar_embs = torch.cat([ar_embs, z_sample.expand(sk_seqlen, bsz, -1)], dim=-1) # skseq x b x (emb_size + latent_z) 400 | # sk_states, (h, c) = self.sk_rnn(ar_embs) # (h0, c0) states: skseq x b x hid 401 | # sk_dec_outs = self.sk_generator_out(sk_states) # skseq x b x vocab 402 | # sk_seq_prob = F.softmax(sk_dec_outs, dim=-1) # skseq x b x vocab 403 | # # nll loss 404 | # crossEntropy = -torch.log(torch.gather(sk_seq_prob, -1, sent_skelton.view(sk_seqlen, bsz, 1)) + 1e-15) # seqlen x b x 1 405 | # sent_sk_mask = torch.ByteTensor(sent_skelton.cpu() != self.pad_idx).transpose(0, 1) 406 | # if self.use_cuda: 407 | # sent_sk_mask = sent_sk_mask.cuda() 408 | # skeleton_nllloss = crossEntropy.masked_select(sent_sk_mask).mean() 409 | # 410 | # if not valid: 411 | # total_loss += self.sk_weight * skeleton_nllloss 412 | # loss_dict['pair_sk_nll'] = skeleton_nllloss 413 | 414 | if self.add_mi_z: 415 | logqz = norm_log_liklihood(z_sample, self.zeros.expand(z_sample.size()), self.zeros.expand(z_sample.size())) # dim=b 416 | logqz_Cx = norm_log_liklihood(z_sample, mu_post_z, logvar_post_z) # dim=b 417 | mutual_info_z = (logqz_Cx - logqz).mean() # b -> 1x1 418 | if not valid: 419 | loss_dict['pair_mi_z'] = mutual_info_z 420 | total_loss += self.mi_z_weight * mutual_info_z 421 | 422 | if self.add_mi_c: 423 | logqc = norm_log_liklihood(c_sample, self.zeros.expand(c_sample.size()), self.zeros.expand(c_sample.size())) 424 | logqc_Cx = norm_log_liklihood(c_sample, mu_post_c, logvar_post_c) 425 | mutual_info_c = (logqc_Cx - logqc).mean() 426 | if not valid: 427 | loss_dict['pair_mi_c'] = mutual_info_c 428 | total_loss += self.mi_c_weight * mutual_info_c 429 | 430 | loss_dict['pair_loss'] = total_loss 431 | return loss_dict 432 | 433 | def decode_raw(self, sentence, sentence_mask, valid=False): 434 | # sentence: seq x b 435 | seqlen, bsz = sentence.size() 436 | sent_emb = self.word_emb(sentence) # seq x b x emb 437 | if self.drop_emb: 438 | sent_emb = self.drop(sent_emb) # seq x b x emb 439 | emb_size = sent_emb.size(2) 440 | 441 | # posterior q(z|x) 442 | h_y0 = torch.zeros(self.layers * 2, bsz, self.hid_size).contiguous() # default bi-rnn, so, 2 layers 443 | if self.use_cuda: 444 | h_y0 = 
h_y0.cuda() 445 | 446 | y_out, h_yt = self.rnn_encode(sent_emb, h_y0) # y_out: seq x b x layer*hid, h_yt: layer x b x hid 447 | 448 | # posterior of z q(z|x) 449 | posterior_input = self._get_posterior_input(y_out, h_yt) # b x layer*hid 450 | posterior_out_z = self.z_posterior(posterior_input) # b x latent_z*2 451 | mu_post_z, logvar_post_z = torch.chunk(posterior_out_z, 2, 1) # both has size b x latent_z 452 | # sample z from the posterior 453 | z_sample = sample_from_gaussian(mu_post_z, logvar_post_z) # b x latent_z 454 | 455 | # prior of z: p(z) = N(0,I) 456 | mu_prior_z = self.zeros.expand(z_sample.size()) 457 | logvar_prior_z = self.zeros.expand(z_sample.size()) 458 | # posterior of c q(c|x) 459 | posterior_out_c = self.c_posterior(posterior_input) 460 | mu_post_c, logvar_post_c = torch.chunk(posterior_out_c, 2, 1) 461 | c_sample = sample_from_gaussian(mu_post_c, logvar_post_c) 462 | 463 | # prior of c p(c) = N(0,I) 464 | mu_prior_c = self.zeros.expand(c_sample.size()) 465 | logvar_prior_c = self.zeros.expand(c_sample.size()) 466 | # generate p(x|z,h_tab) for paired data 467 | # wlps_k_collect = [] 468 | ar_embs = torch.cat([self.word_emb.weight[2].view(1, 1, emb_size).expand(1, bsz, emb_size), 469 | sent_emb[:-1, :, :]], 0) # seqlen x bsz x emb_size 470 | ar_embs = torch.cat([ar_embs, z_sample.expand(seqlen, bsz, -1), c_sample.expand(seqlen, bsz, -1)], 471 | dim=-1) # seq x b x (emb_size + latent_z + tabhid) 472 | states, (h, c) = self.word_rnn(ar_embs) # (h0, c0) states: seq x b x hid 473 | 474 | if self.use_dec_attention: 475 | raw_attn = torch.zeros(seqlen, bsz, self.c_latent_size).contiguous() 476 | if self.use_cuda: 477 | raw_attn = raw_attn.cuda() 478 | dec_outs = self.generator_out(torch.cat([states, raw_attn], dim=-1)) # seq x b x vocab 479 | else: 480 | dec_outs = self.generator_out(states) 481 | 482 | seq_prob = F.softmax(dec_outs) # seq x b x vocab 483 | 484 | # nll loss 485 | crossEntropy = -torch.log(torch.gather(seq_prob, -1, sentence.view(seqlen, bsz, 1)) + 1e-15) # seqlen x b x 1 486 | nll_loss = crossEntropy.masked_select(sentence_mask).mean() 487 | 488 | # KL between prior and posterior KL(q(z|x)||p(z)) and kl(q(c|x)||p(c)) 489 | KL_z = torch.mean(gaussian_kld(mu_post_z, logvar_post_z, mu_prior_z, logvar_prior_z), dim=0) # dim=b -> mean 490 | KL_c = torch.mean(gaussian_kld(mu_post_c, logvar_post_c, mu_prior_c, logvar_prior_c), dim=0) # dim=b -> mean 491 | 492 | if not valid: # train 493 | total_loss = nll_loss + self.kl_weight_z * KL_z + self.kl_weight_c * KL_c 494 | else: 495 | total_loss = nll_loss + KL_c + KL_z 496 | 497 | loss_dict = {"raw_nll": nll_loss, "raw_KLz": KL_z, "raw_KLc": KL_c, 498 | "kl_weight_z": torch.full((1,), self.kl_weight_z), "kl_weight_c": torch.full((1,), self.kl_weight_c)} 499 | 500 | # if self.add_genbow: 501 | # p_genbow_fromz = F.softmax(self.gen_bow_z_proj(z_sample), dim=1) # b x gen_voc 502 | # p_genbow_fromc = F.softmax(self.gen_bow_c_proj(c_sample), dim=1) # b x gen_voc 503 | # sent_genbow = sent_genbow.float() 504 | # sent_genbow_p_targ = torch.div((sent_genbow+1).transpose(0, 1), (sent_genbow+1).sum(1)).transpose(0, 1) # normalize b x gen_voc 505 | # # we can predict what's in skelton if noly know z 506 | # CE_genbow_z = - torch.mean(torch.sum(sent_genbow_p_targ * torch.log(p_genbow_fromz + 1e-15), dim=1), dim=0) # dim=b -> mean 507 | # # we cannot predict whats in skelton of only know c 508 | # negH_genbow_c = -torch.distributions.categorical.Categorical(probs=p_genbow_fromc).entropy() 509 | # negH_genbow_c = 
torch.mean(negH_genbow_c) # dim=b ->mean 510 | # if not valid: 511 | # total_loss += self.genbow_weight_z * CE_genbow_z + self.genbow_weight_c * negH_genbow_c 512 | # loss_dict['raw_CE_genbow_z'] = CE_genbow_z 513 | # loss_dict['raw_H_genbow_c'] = - negH_genbow_c 514 | 515 | # if self.add_vbow: 516 | # p_vbow_fromz = F.softmax(self.value_bow_z_proj(z_sample), dim=1) 517 | # p_vbow_fromc = F.softmax(self.value_bow_c_proj(c_sample), dim=1) 518 | # sent_vbow = sent_vbow.float() 519 | # sent_vbow_p_targ = torch.div((sent_vbow+1).transpose(0, 1), (sent_vbow+1).sum(1)).transpose(0, 1) 520 | # negH_vbow_z = - torch.distributions.categorical.Categorical(probs=p_vbow_fromz).entropy() 521 | # negH_vbow_z = torch.mean(negH_vbow_z, dim=0) 522 | # CE_vbow_c = torch.sum(sent_vbow_p_targ * torch.log(p_vbow_fromc + 1e-15), dim=1) 523 | # CE_vbow_c = - torch.mean(CE_vbow_c, dim=0) 524 | # if not valid: 525 | # total_loss += self.vbow_weight_z * negH_vbow_z + self.vbow_weight_c * CE_vbow_c 526 | # loss_dict['raw_H_vbow_z'] = - negH_vbow_z 527 | # loss_dict['raw_CE_vbow_c'] = CE_vbow_c 528 | 529 | if self.add_mi_z: 530 | logqz = norm_log_liklihood(z_sample, self.zeros.expand(z_sample.size()), self.zeros.expand(z_sample.size())) # dim=b 531 | logqz_Cx = norm_log_liklihood(z_sample, mu_post_z, logvar_post_z) # dim=b 532 | mutual_info_z = (logqz_Cx - logqz).mean() # b -> 1x1 533 | if not valid: 534 | loss_dict['pair_mi_z'] = mutual_info_z 535 | total_loss += mutual_info_z 536 | 537 | if self.add_mi_c: 538 | logqc = norm_log_liklihood(c_sample, self.zeros.expand(c_sample.size()), self.zeros.expand(c_sample.size())) 539 | logqc_Cx = norm_log_liklihood(c_sample, mu_post_c, logvar_post_c) 540 | mutual_info_c = (logqc_Cx - logqc).mean() 541 | if not valid: 542 | loss_dict['pair_mi_c'] = mutual_info_c 543 | total_loss += mutual_info_c 544 | 545 | loss_dict['raw_loss'] = total_loss 546 | return loss_dict 547 | 548 | def forward(self, pair_src, pair_mask, pair_sentence, pair_sent_skeleton, pair_sent_mask, 549 | raw_sentence, raw_sent_mask, valid=False): 550 | # for paired data 551 | paired_table_enc = self.encode_table(pair_src, pair_mask) 552 | paired_loss_dict = self.decode_pair(paired_table_enc, pair_sentence, pair_sent_skeleton, 553 | pair_mask, pair_sent_mask, valid=valid) 554 | # for raw data 555 | raw_loss_dict = self.decode_raw(raw_sentence, raw_sent_mask, valid=valid) 556 | 557 | all_loss_dict = {} 558 | for k, v in paired_loss_dict.items(): 559 | all_loss_dict[k] = v 560 | for k,v in raw_loss_dict.items(): 561 | all_loss_dict[k] = v 562 | 563 | # total loss dict 564 | if not valid: # train 565 | total_loss = paired_loss_dict['pair_loss'] + self.rawloss_weight * raw_loss_dict['raw_loss'] 566 | else: # test loss on valid 567 | total_loss = paired_loss_dict['pair_loss'] + raw_loss_dict['raw_loss'] 568 | 569 | return total_loss, all_loss_dict 570 | 571 | 572 | def predict(self, paired_src, paired_mask, beam_size=None): 573 | bsz, _, _ = paired_src.size() 574 | 575 | real_beam_flag = False # default is false 576 | if beam_size is not None: 577 | if beam_size == 1: 578 | self.feedback_x = feedback.GreedyFeedBack(self.word_emb, self.unk_idx) 579 | else: 580 | real_beam_flag = True 581 | self.feedback_x = feedback.BeamFeedBack(self.word_emb, beam_size, self.unk_idx) 582 | 583 | key_emb, masked_key_emb, value_emb, h_table_field, h_table = self.encode_table(paired_src, paired_mask) 584 | 585 | embsz = value_emb.size(-1) 586 | 587 | mu_prior, logvar_prior = torch.zeros(bsz, self.z_latent_size).contiguous(), 
torch.zeros(bsz, self.z_latent_size).contiguous() # b x latent 588 | if self.use_cuda: 589 | mu_prior, logvar_prior = mu_prior.cuda(), logvar_prior.cuda() 590 | 591 | z_sample = sample_from_gaussian(mu_prior, logvar_prior, seed=None) # b x latent 592 | 593 | # mu_prior_c = self.zeros.expand(bsz, self.c_latent_size) 594 | # logvar_prior_c = self.zeros.expand(bsz, self.c_latent_size) 595 | 596 | h = self.zeros.expand(1, bsz, self.hid_size).contiguous() 597 | c = self.zeros.expand(1, bsz, self.hid_size).contiguous() # to push x makes the use of information from c 598 | ar_embs = self.word_emb.weight[2].view(1, 1, embsz).expand(1, bsz, embsz) # 1 x b x emb 599 | 600 | if beam_size is not None: 601 | past_p = self.zeros.expand(bsz * beam_size, 1) # init past_p for beam search 602 | else: 603 | past_p = self.zeros.expand(bsz * self.beamsz, 1) 604 | 605 | for t in range(self.max_seqlen): 606 | ar_embs = torch.cat([ar_embs, z_sample.unsqueeze(0), h_table.unsqueeze(0)], dim=-1) # 1 x b x (emb+latent+tabhid) 607 | ar_state, (h, c) = self.word_rnn(ar_embs, (h, c)) 608 | 609 | if self.use_dec_attention: 610 | attn_score_dec, attn_ctx_dec, attn_logits_dec = self.attn_table_hidden.forward(ar_state, h_table_field, 611 | h_table_field, paired_mask, 612 | return_logits=True) 613 | # attn_score_dec, attn_logits_dec: b x seq x nfield 614 | # attn_ctx_dec: b x seq x tabhid 615 | dec_outs = self.generator_out(torch.cat([ar_state, attn_ctx_dec.transpose(0, 1)], dim=-1)) # seq x b x vocab 616 | else: 617 | dec_outs = self.generator_out(ar_state) # seq x b x vocab 618 | 619 | if not real_beam_flag: 620 | next_inp = self.feedback_x(dec_outs)[0][0].item() 621 | else: 622 | word_prob = F.softmax(dec_outs, dim=-1) 623 | word_prob[:, :, self.unk_idx].fill_(0) # disallow generating unk word 624 | 625 | cur_p = self.feedback_x.repeat(word_prob).squeeze(0) 626 | past_p, symbol = self.feedback_x(past_p, cur_p, bsz, t) 627 | next_inp = symbol[0][0].item() 628 | 629 | next_inp = torch.tensor(next_inp).cuda() if self.use_cuda else torch.tensor(next_inp) 630 | ar_embs = self.word_emb(next_inp.unsqueeze(0).unsqueeze(1)).expand(1, bsz, embsz) 631 | 632 | if not real_beam_flag: 633 | sentences_ids = self.feedback_x.collect() 634 | else: 635 | sentences_ids = self.feedback_x.collect(past_p, bsz) 636 | return sentences_ids 637 | 638 | 639 | # def predict_from_temp(self, paired_src, paired_mask, temp_sentence): 640 | # # sentence: dim = seq 641 | # w2i = self.corpus.dictionary.word2idx 642 | # bsz, _, _ = paired_src.size() 643 | # key_emb, masked_key_emb, value_emb, h_table_field, h_table = self.encode(paired_src, paired_mask) 644 | # 645 | # temp_sentence = temp_sentence.unsqueeze(1) # seq x 1 646 | # sent_emb = self.word_emb(temp_sentence) # seq x 1 x emb 647 | # if self.drop_emb: 648 | # sent_emb = self.drop(sent_emb) # seq x 1 x emb 649 | # 650 | # # posterior q(z|x,eK*c) 651 | # h_y0 = torch.zeros(self.layers * 2, bsz, self.hid_size).contiguous() # default bi-rnn, so, 2 layers 652 | # if self.use_cuda: 653 | # h_y0 = h_y0.cuda() 654 | # 655 | # y_out, h_yt = self.rnn_encode(sent_emb, h_y0) # y_out: seq x b x layer*hid, h_yt: layer x 1 x hid 656 | # 657 | # # posterior of z q(z|x) 658 | # posterior_input = self._get_posterior_input(y_out, h_yt) # b x layer*hid 659 | # posterior_out_z = self.z_posterior(posterior_input) # b x latent_z*2 660 | # mu_post_z, logvar_post_z = torch.chunk(posterior_out_z, 2, 1) # both has size b x latent_z 661 | # # sample z from the posterior 662 | # z_sample = sample_from_gaussian(mu_post_z, 
663 |     #
664 |     # mu_prior_c = self.zeros.expand(bsz, self.c_latent_size)
665 |     # logvar_prior_c = self.zeros.expand(bsz, self.c_latent_size)
666 |     # c_sample = sample_from_gaussian(mu_prior_c, logvar_prior_c) # b x latent_c
667 |     #
668 |     # embsz = sent_emb.size(-1)
669 |     #
670 |     # h = self.zeros.expand(1, bsz, self.hid_size).contiguous()
671 |     # c = self.zeros.expand(1, bsz, self.hid_size).contiguous() # to push x to make use of the information from c
672 |     # ar_embs = self.word_emb.weight[2].view(1, 1, embsz).expand(1, bsz, embsz) # 1 x b x emb
673 |     #
674 |     # for t in range(self.max_seqlen):
675 |     # # print(ar_embs.size(), z_sample.size())
676 |     # ar_embs = torch.cat([ar_embs, z_sample.unsqueeze(0), h_table.unsqueeze(0)], dim=-1) # 1 x b x (emb+latent+tabhid)
677 |     #
678 |     # # ar_embs = torch.cat([ar_embs, z_sample.unsqueeze(0), c_sample.unsqueeze(0)], dim=-1) # 1 x b x (emb+latent+tabhid)
679 |     # ar_state, (h, c) = self.word_rnn(ar_embs, (h, c))
680 |     #
681 |     # if self.use_dec_attention:
682 |     # attn_score_dec, attn_ctx_dec, attn_logits_dec = self.attn_table_hidden.forward(ar_state, h_table_field,
683 |     # h_table_field,
684 |     # paired_mask,
685 |     # return_logits=True)
686 |     # # attn_score_dec, attn_logits_dec: b x seq x nfield
687 |     # # attn_ctx_dec: b x seq x tabhid
688 |     # dec_outs = self.generator_out(
689 |     # torch.cat([ar_state, attn_ctx_dec.transpose(0, 1)], dim=-1)) # seq x b x vocab
690 |     # else:
691 |     # dec_outs = self.generator_out(ar_state) # seq x b x vocab
692 |     #
693 |     # next_inp = self.feedback_x(dec_outs)[0][0].item()
694 |     # # word_prob = F.softmax(dec_outs, dim=-1)
695 |     # # word_prob[:, :, self.unk_idx].fill_(0) # disallow generating unk word
696 |     # # next_inp = self.feedback_x(word_prob)[0][0].item() # 1x1
697 |     # next_inp = torch.tensor(next_inp).cuda() if self.use_cuda else torch.tensor(next_inp)
698 |     # ar_embs = self.word_emb(next_inp.unsqueeze(0).unsqueeze(1)).expand(1, bsz, embsz)
699 |     # # ar_embs = torch.mean(self.word_emb(next_inp), dim=0).unsqueeze(0).unsqueeze(1).expand(1, bsz, embsz)
700 |     # sentences_ids = self.feedback_x.collect()
701 |     # return sentences_ids
702 | 
703 |     # def inference(self, sentence, return_skeleton=False):
704 |     # # sentence: seq x 1
705 |     # # sentence = sentence.unsqueeze(1) # seq x 1
706 |     # template_feedback = feedback.SampleFeedBack(self.word_emb, self.unk_idx)
707 |     # sent_emb = self.word_emb(sentence) # seq x 1 x emb
708 |     # if self.drop_emb:
709 |     # sent_emb = self.drop(sent_emb) # seq x 1 x emb
710 |     # seqlen, bsz, emb_size = sent_emb.size() # bsz=1
711 |     # # posterior q(z|x,eK*c)
712 |     # h_y0 = torch.zeros(self.layers * 2, bsz, self.hid_size).contiguous() # default bi-rnn, so, 2 layers
713 |     # if self.use_cuda:
714 |     # h_y0 = h_y0.cuda()
715 |     #
716 |     # y_out, h_yt = self.rnn_encode(sent_emb, h_y0) # y_out: seq x 1 x layer*hid, h_yt: layer x 1 x hid
717 |     #
718 |     # # posterior of z q(z|x)
719 |     # posterior_input = self._get_posterior_input(y_out, h_yt) # 1 x layer*hid
720 |     # posterior_out_z = self.z_posterior(posterior_input) # 1 x latent_z*2
721 |     # mu_post_z, logvar_post_z = torch.chunk(posterior_out_z, 2, 1) # both have size 1 x latent_z
722 |     # # sample z from the posterior
723 |     # z_sample = sample_from_gaussian(mu_post_z, logvar_post_z) # 1 x latent_z
724 |     # z_outs = (posterior_out_z, z_sample)
725 |     #
726 |     # posterior_out_c = self.c_posterior(posterior_input) # 1 x latent_c*2
727 |     # mu_post_c, logvar_post_c = torch.chunk(posterior_out_c, 2, 1)
728 |     # c_sample = sample_from_gaussian(mu_post_c, logvar_post_c) # 1 x latent_c
729 |     # c_outs = (posterior_out_c, c_sample)
730 |     #
731 |     # if return_skeleton and self.add_skeleton:
732 |     # h = self.zeros.expand(1, bsz, self.hid_size).contiguous()
733 |     # c = self.zeros.expand(1, bsz, self.hid_size).contiguous() # to push x to make use of the information from c
734 |     # ar_embs = self.word_emb.weight[2].view(1, 1, self.emb_size).expand(1, bsz, self.emb_size) # 1 x 1 x emb
735 |     # for t in range(self.max_seqlen):
736 |     # # print(ar_embs.size(), z_sample.size())
737 |     # ar_embs = torch.cat([ar_embs, z_sample.unsqueeze(0)],dim=-1) # 1 x 1 x (emb+latent_z)
738 |     # ar_state, (h, c) = self.sk_rnn(ar_embs, (h, c))
739 |     # dec_outs = self.sk_generator_out(ar_state) # 1 x 1 x vocab
740 |     #
741 |     # next_inp = template_feedback(dec_outs)[0][0].item()
742 |     #
743 |     # # sk_word_prob = F.softmax(dec_outs, dim=-1)
744 |     # # sk_word_prob[:, :, self.unk_idx].fill_(0) # disallow generating unk word
745 |     # # next_inp = self.feedback_x(sk_word_prob)[0][0].item() # 1x1 scalar
746 |     # next_inp = torch.tensor(next_inp).cuda() if self.use_cuda else torch.tensor(next_inp)
747 |     # ar_embs = self.word_emb(next_inp.unsqueeze(0).unsqueeze(1)).expand(1, bsz, self.emb_size)
748 |     # skeleton_ids = template_feedback.collect()
749 |     #
750 |     # return z_outs, c_outs, skeleton_ids
751 |     # else:
752 |     # return z_outs, c_outs
753 | 
--------------------------------------------------------------------------------
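Note on the `add_mi_z` / `add_mi_c` branches above: with a single latent sample they estimate the average gap between the posterior log-density log q(z|x) and the standard-normal prior log-density log p(z), and the resulting scalar is accumulated into the loss. The standalone sketch below mirrors that computation. `norm_log_likelihood` here is an illustrative stand-in (the repository's own `norm_log_liklihood` helper is defined earlier in the file and not shown in this excerpt), and the batch/latent sizes are arbitrary toy values.

```
import math
import torch

def norm_log_likelihood(x, mu, logvar):
    # log-density of x under a diagonal Gaussian N(mu, exp(logvar)),
    # summed over the latent dimension -> one value per batch row
    return -0.5 * (math.log(2 * math.pi) + logvar + (x - mu) ** 2 / logvar.exp()).sum(dim=1)

# toy posterior parameters: batch of 4, 8-dimensional latent
mu_post = torch.randn(4, 8)
logvar_post = torch.zeros(4, 8)

# reparameterized sample z ~ q(z|x)
z_sample = mu_post + torch.randn_like(mu_post) * (0.5 * logvar_post).exp()

log_q_z_given_x = norm_log_likelihood(z_sample, mu_post, logvar_post)  # log q(z|x), dim=b
log_p_z = norm_log_likelihood(z_sample, torch.zeros_like(z_sample),
                              torch.zeros_like(z_sample))              # log N(z; 0, I), dim=b

mutual_info_term = (log_q_z_given_x - log_p_z).mean()  # scalar, as in `mutual_info_z` above
```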