├── train_hred.sh ├── train_vhred.sh ├── .gitignore ├── test_hred.sh ├── test_vhred.sh ├── layers ├── __init__.py ├── feedforward.py ├── loss.py ├── rnncells.py ├── beam_search.py ├── encoder.py └── decoder.py ├── utils ├── __init__.py ├── pad.py ├── probability.py ├── bow.py ├── time_track.py ├── mask.py ├── embedding_metric.py ├── convert.py ├── bleu.py └── vocab.py ├── test.py ├── eval_embed.py ├── eval.py ├── train.py ├── README.md ├── data_loader.py ├── metrics.py ├── prepare_data.py ├── configs.py ├── solver.py └── models.py /train_hred.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=5 python -u train.py --model=HRED 2 | -------------------------------------------------------------------------------- /train_vhred.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=5 python -u train.py --model=VHRED 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.log 3 | ckpt_ 4 | /pred*/* 5 | /data*/* 6 | /model*/* 7 | /ckpt*/* 8 | -------------------------------------------------------------------------------- /test_hred.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python -u test.py --model=HRED --checkpoint=./ckpt/data/HRED/2019-12-04_22\:48\:33/40.pkl 2 | -------------------------------------------------------------------------------- /test_vhred.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python -u test.py --model=VHRED --checkpoint=./ckpt/data/VHRED/2019-12-04_22\:48\:54/40.pkl 2 | -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .encoder import * 2 | from .decoder import * 3 | from .rnncells import StackedLSTMCell, StackedGRUCell 4 | from .loss import * 5 | from .feedforward import * 6 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .convert import * 2 | from .time_track import time_desc_decorator 3 | from .vocab import * 4 | from .mask import * 5 | from .probability import * 6 | from .pad import * 7 | from .bow import * 8 | -------------------------------------------------------------------------------- /utils/pad.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | from .convert import to_var 4 | 5 | 6 | def pad(tensor, length): 7 | if isinstance(tensor, Variable): 8 | var = tensor 9 | if length > var.size(0): 10 | return torch.cat([var, 11 | torch.zeros(length - var.size(0), *var.size()[1:]).cuda()]) 12 | else: 13 | return var 14 | else: 15 | if length > tensor.size(0): 16 | return torch.cat([tensor, 17 | torch.zeros(length - tensor.size(0), *tensor.size()[1:]).cuda()]) 18 | else: 19 | return tensor 20 | 21 | 22 | def pad_and_pack(tensor_list): 23 | length_list = ([t.size(0) for t in tensor_list]) 24 | max_len = max(length_list) 25 | padded = [pad(t, max_len) for t in tensor_list] 26 | packed = torch.stack(padded, 0) 27 | return packed, length_list 28 | 
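A quick usage sketch for pad_and_pack (editorial example, not repo code). Note that pad() allocates its zero padding with .cuda(), so the inputs are expected to be CUDA float tensors:

```python
import torch
from utils.pad import pad_and_pack

# Hypothetical inputs: two variable-length 1-D CUDA tensors.
a = torch.tensor([1., 2., 3.]).cuda()
b = torch.tensor([4., 5.]).cuda()

packed, lengths = pad_and_pack([a, b])
# packed  -> [[1., 2., 3.],
#             [4., 5., 0.]]   shape [2, 3]: zero-padded to the max length, stacked on dim 0
# lengths -> [3, 2]           the original lengths, useful for later masking
```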
-------------------------------------------------------------------------------- /utils/probability.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .convert import to_var 4 | 5 | 6 | def normal_logpdf(x, mean, var): 7 | """ 8 | Args: 9 | x: (Variable, FloatTensor) [batch_size, dim] 10 | mean: (Variable, FloatTensor) [batch_size, dim] or [batch_size] or [1] 11 | var: (Variable, FloatTensor) [batch_size, dim]: positive value 12 | Return: 13 | log_p: (Variable, FloatTensor) [batch_size] 14 | """ 15 | 16 | pi = to_var(torch.FloatTensor([np.pi])) 17 | return 0.5 * torch.sum(-torch.log(2.0 * pi) - torch.log(var) - ((x - mean).pow(2) / var), dim=1) 18 | 19 | 20 | def normal_kl_div(mu1, var1, 21 | mu2=to_var(torch.FloatTensor([0.0])), 22 | var2=to_var(torch.FloatTensor([1.0]))): 23 | one = to_var(torch.FloatTensor([1.0])) 24 | return torch.sum(0.5 * (torch.log(var2) - torch.log(var1) 25 | + (var1 + (mu1 - mu2).pow(2)) / var2 - one), 1) 26 | -------------------------------------------------------------------------------- /layers/feedforward.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class FeedForward(nn.Module): 6 | def __init__(self, input_size, output_size, num_layers=1, hidden_size=None, 7 | activation="Tanh", bias=True): 8 | super(FeedForward, self).__init__() 9 | self.input_size = input_size 10 | self.output_size = output_size 11 | self.hidden_size = hidden_size 12 | self.num_layers = num_layers 13 | self.activation = getattr(nn, activation)() 14 | n_inputs = [input_size] + [hidden_size] * (num_layers - 1) 15 | n_outputs = [hidden_size] * (num_layers - 1) + [output_size] 16 | self.linears = nn.ModuleList([nn.Linear(n_in, n_out, bias=bias) 17 | for n_in, n_out in zip(n_inputs, n_outputs)]) 18 | 19 | def forward(self, input): 20 | x = input 21 | for linear in self.linears: 22 | x = linear(x) 23 | x = self.activation(x) 24 | 25 | return x 26 | -------------------------------------------------------------------------------- /utils/bow.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import Counter 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | import torch 6 | from math import isnan 7 | from .vocab import PAD_ID, EOS_ID 8 | 9 | 10 | def to_bow(sentence, vocab_size): 11 | ''' Convert a sentence into a bag of words representation 12 | Args 13 | - sentence: a list of token ids 14 | - vocab_size: V 15 | Returns 16 | - bow: a integer vector of size V 17 | ''' 18 | bow = Counter(sentence) 19 | # Remove EOS tokens 20 | bow[PAD_ID] = 0 21 | bow[EOS_ID] = 0 22 | 23 | x = np.zeros(vocab_size, dtype=np.int64) 24 | x[list(bow.keys())] = list(bow.values()) 25 | 26 | return x 27 | 28 | 29 | def bag_of_words_loss(bow_logits, target_bow, weight=None): 30 | ''' Calculate bag of words representation loss 31 | Args 32 | - bow_logits: [num_sentences, vocab_size] 33 | - target_bow: [num_sentences] 34 | ''' 35 | log_probs = F.log_softmax(bow_logits, dim=1) 36 | target_distribution = target_bow / (target_bow.sum(1).view(-1, 1) + 1e-23) + 1e-23 37 | entropy = -(torch.log(target_distribution) * target_bow).sum() 38 | loss = -(log_probs * target_bow).sum() - entropy 39 | 40 | return loss 41 | -------------------------------------------------------------------------------- /test.py: 
-------------------------------------------------------------------------------- 1 | from solver import Solver, VariationalSolver 2 | from data_loader import get_loader 3 | from configs import get_config 4 | from utils import Vocab 5 | import os 6 | import pickle 7 | from models import VariationalModels 8 | import re 9 | 10 | 11 | def load_pickle(path): 12 | with open(path, 'rb') as f: 13 | return pickle.load(f) 14 | 15 | 16 | if __name__ == '__main__': 17 | config = get_config(mode='test') 18 | 19 | print('Loading Vocabulary...') 20 | vocab = Vocab(lang="zh") 21 | vocab.load(config.word2id_path, config.id2word_path) 22 | print(f'Vocabulary size: {vocab.vocab_size}') 23 | 24 | config.vocab_size = vocab.vocab_size 25 | data_loader = get_loader( 26 | sentences=load_pickle(config.sentences_path), 27 | conversation_length=load_pickle(config.conversation_length_path), 28 | sentence_length=load_pickle(config.sentence_length_path), 29 | vocab=vocab, 30 | batch_size=1, 31 | shuffle=False) 32 | 33 | if config.model in VariationalModels: 34 | solver = VariationalSolver(config, None, data_loader, vocab=vocab, is_train=False) 35 | else: 36 | solver = Solver(config, None, data_loader, vocab=vocab, is_train=False) 37 | 38 | solver.build() 39 | solver.generate_for_evaluation() 40 | -------------------------------------------------------------------------------- /eval_embed.py: -------------------------------------------------------------------------------- 1 | from solver import Solver, VariationalSolver 2 | from data_loader import get_loader 3 | from configs import get_config 4 | from utils import Vocab  # NOTE: Tokenizer is not defined in utils/vocab.py, so importing it would fail 5 | import os 6 | import pickle 7 | from models import VariationalModels 8 | import re 9 | 10 | 11 | def load_pickle(path): 12 | with open(path, 'rb') as f: 13 | return pickle.load(f) 14 | 15 | 16 | if __name__ == '__main__': 17 | config = get_config(mode='test') 18 | 19 | print('Loading Vocabulary...') 20 | vocab = Vocab() 21 | vocab.load(config.word2id_path, config.id2word_path) 22 | print(f'Vocabulary size: {vocab.vocab_size}') 23 | 24 | config.vocab_size = vocab.vocab_size 25 | 26 | data_loader = get_loader( 27 | sentences=load_pickle(config.sentences_path), 28 | conversation_length=load_pickle(config.conversation_length_path), 29 | sentence_length=load_pickle(config.sentence_length_path), 30 | vocab=vocab, 31 | batch_size=config.batch_size, 32 | shuffle=False) 33 | 34 | if config.model in VariationalModels: 35 | solver = VariationalSolver(config, None, data_loader, vocab=vocab, is_train=False) 36 | else: 37 | solver = Solver(config, None, data_loader, vocab=vocab, is_train=False) 38 | 39 | solver.build() 40 | solver.embedding_metric() 41 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from solver import Solver, VariationalSolver 2 | from data_loader import get_loader 3 | from configs import get_config 4 | from utils import Vocab  # NOTE: Tokenizer is not defined in utils/vocab.py, so importing it would fail 5 | import os 6 | import pickle 7 | from models import VariationalModels 8 | 9 | 10 | def load_pickle(path): 11 | with open(path, 'rb') as f: 12 | return pickle.load(f) 13 | 14 | 15 | if __name__ == '__main__': 16 | config = get_config(mode='test') 17 | 18 | print('Loading Vocabulary...') 19 | vocab = Vocab() 20 | vocab.load(config.word2id_path, config.id2word_path) 21 | print(f'Vocabulary size: {vocab.vocab_size}') 22 | 23 | config.vocab_size = vocab.vocab_size 24 | 25 | data_loader = get_loader( 
sentences=load_pickle(config.sentences_path), 27 | conversation_length=load_pickle(config.conversation_length_path), 28 | sentence_length=load_pickle(config.sentence_length_path), 29 | vocab=vocab, 30 | batch_size=config.batch_size) 31 | 32 | if config.model in VariationalModels: 33 | solver = VariationalSolver(config, None, data_loader, vocab=vocab, is_train=False) 34 | solver.build() 35 | solver.importance_sample() 36 | else: 37 | solver = Solver(config, None, data_loader, vocab=vocab, is_train=False) 38 | solver.build() 39 | solver.test() 40 | -------------------------------------------------------------------------------- /utils/time_track.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | 5 | def base_time_desc_decorator(method, desc='test_description'): 6 | def timed(*args, **kwargs): 7 | 8 | # Print Description 9 | # print('#' * 50) 10 | print(desc) 11 | # print('#' * 50 + '\n') 12 | 13 | # Calculation Runtime 14 | start = time.time() 15 | 16 | # Run Method 17 | try: 18 | result = method(*args, **kwargs) 19 | except TypeError: 20 | result = method(**kwargs) 21 | 22 | # Print Runtime 23 | print('Done! It took {:.2} secs\n'.format(time.time() - start)) 24 | 25 | if result is not None: 26 | return result 27 | 28 | return timed 29 | 30 | 31 | def time_desc_decorator(desc): return partial(base_time_desc_decorator, desc=desc) 32 | 33 | 34 | @time_desc_decorator('this is description') 35 | def time_test(arg, kwarg='this is kwarg'): 36 | time.sleep(3) 37 | print('Inside of time_test') 38 | print('printing arg: ', arg) 39 | print('printing kwarg: ', kwarg) 40 | 41 | 42 | @time_desc_decorator('this is second description') 43 | def no_arg_method(): 44 | print('this method has no argument') 45 | 46 | 47 | if __name__ == '__main__': 48 | time_test('hello', kwarg=3) 49 | time_test(3) 50 | no_arg_method() 51 | -------------------------------------------------------------------------------- /utils/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .convert import to_var 3 | 4 | 5 | def sequence_mask(sequence_length, max_len=None): 6 | """ 7 | Args: 8 | sequence_length (Variable, LongTensor) [batch_size] 9 | - list of sequence length of each batch 10 | max_len (int) 11 | Return: 12 | masks (bool): [batch_size, max_len] 13 | - True if current sequence is valid (not padded), False otherwise 14 | 15 | Ex. 16 | sequence length: [3, 2, 1] 17 | 18 | seq_length_expand 19 | [[3, 3, 3], 20 | [2, 2, 2] 21 | [1, 1, 1]] 22 | 23 | seq_range_expand 24 | [[0, 1, 2] 25 | [0, 1, 2], 26 | [0, 1, 2]] 27 | 28 | masks 29 | [[True, True, True], 30 | [True, True, False], 31 | [True, False, False]] 32 | """ 33 | if max_len is None: 34 | max_len = sequence_length.max() 35 | batch_size = sequence_length.size(0) 36 | 37 | # [max_len] 38 | seq_range = torch.arange(0, max_len).long() # [0, 1, ... 
max_len-1] 39 | 40 | # [batch_size, max_len] 41 | seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) 42 | seq_range_expand = to_var(seq_range_expand) 43 | 44 | # [batch_size, max_len] 45 | seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand) 46 | 47 | # [batch_size, max_len] 48 | masks = seq_range_expand < seq_length_expand 49 | 50 | return masks 51 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from solver import * 2 | from data_loader import get_loader 3 | from configs import get_config 4 | from utils import Vocab 5 | import os 6 | import pickle 7 | from models import VariationalModels 8 | 9 | def load_pickle(path): 10 | with open(path, 'rb') as f: 11 | return pickle.load(f) 12 | 13 | 14 | if __name__ == '__main__': 15 | config = get_config(mode='train') 16 | val_config = get_config(mode='valid') 17 | print(config) 18 | with open(os.path.join(config.save_path, 'config.txt'), 'w') as f: 19 | print(config, file=f) 20 | 21 | print('Loading Vocabulary...') 22 | vocab = Vocab() 23 | vocab.load(config.word2id_path, config.id2word_path) 24 | print(f'Vocabulary size: {vocab.vocab_size}') 25 | 26 | config.vocab_size = vocab.vocab_size 27 | 28 | train_data_loader = get_loader( 29 | sentences=load_pickle(config.sentences_path), 30 | conversation_length=load_pickle(config.conversation_length_path), 31 | sentence_length=load_pickle(config.sentence_length_path), 32 | vocab=vocab, 33 | batch_size=config.batch_size) 34 | 35 | eval_data_loader = get_loader( 36 | sentences=load_pickle(val_config.sentences_path), 37 | conversation_length=load_pickle(val_config.conversation_length_path), 38 | sentence_length=load_pickle(val_config.sentence_length_path), 39 | vocab=vocab, 40 | batch_size=val_config.eval_batch_size, 41 | shuffle=False) 42 | 43 | # for testing 44 | # train_data_loader = eval_data_loader 45 | if config.model in VariationalModels: 46 | solver = VariationalSolver 47 | else: 48 | solver = Solver 49 | 50 | solver = solver(config, train_data_loader, eval_data_loader, vocab=vocab, is_train=True) 51 | 52 | solver.build() 53 | solver.train() 54 | -------------------------------------------------------------------------------- /layers/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | import torch.nn as nn 4 | from utils import to_var, sequence_mask 5 | 6 | 7 | # https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 8 | def masked_cross_entropy(logits, target, length, per_example=False): 9 | """ 10 | Args: 11 | logits (Variable, FloatTensor): [batch, max_len, num_classes] 12 | - unnormalized probability for each class 13 | target (Variable, LongTensor): [batch, max_len] 14 | - index of true class for each corresponding step 15 | length (Variable, LongTensor): [batch] 16 | - length of each data in a batch 17 | Returns: 18 | loss (Variable): [] 19 | - An average loss value masked by the length 20 | """ 21 | batch_size, max_len, num_classes = logits.size() 22 | 23 | # [batch_size * max_len, num_classes] 24 | logits_flat = logits.view(-1, num_classes) 25 | 26 | # [batch_size * max_len, num_classes] 27 | log_probs_flat = F.log_softmax(logits_flat, dim=1) 28 | 29 | # [batch_size * max_len, 1] 30 | target_flat = target.view(-1, 1) 31 | 32 | # Negative Log-likelihood: -sum { 1* log P(target) + 0 log P(non-target)} = -sum( log 
P(target) ) 33 | # [batch_size * max_len, 1] 34 | losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat) 35 | 36 | # [batch_size, max_len] 37 | losses = losses_flat.view(batch_size, max_len) 38 | 39 | # [batch_size, max_len] 40 | mask = sequence_mask(sequence_length=length, max_len=max_len) 41 | 42 | # Apply masking on loss 43 | losses = losses * mask.float() 44 | 45 | # word-wise cross entropy 46 | # loss = losses.sum() / length.float().sum() 47 | 48 | if per_example: 49 | # loss: [batch_size] 50 | return losses.sum(1) 51 | else: 52 | loss = losses.sum() 53 | return loss, length.float().sum() 54 | -------------------------------------------------------------------------------- /utils/embedding_metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def cosine_similarity(s, g): 5 | similarity = np.sum(s * g, axis=1) / np.sqrt((np.sum(s * s, axis=1) * np.sum(g * g, axis=1))) 6 | 7 | # return np.sum(similarity) 8 | return similarity 9 | 10 | 11 | def embedding_metric(samples, ground_truth, word2vec, method='average'): 12 | 13 | if method == 'average': 14 | # s, g: [n_samples, word_dim] 15 | s = [np.mean(sample, axis=0) for sample in samples] 16 | g = [np.mean(gt, axis=0) for gt in ground_truth] 17 | return cosine_similarity(np.array(s), np.array(g)) 18 | elif method == 'extrema': 19 | s_list = [] 20 | g_list = [] 21 | for sample, gt in zip(samples, ground_truth): 22 | s_max = np.max(sample, axis=0) 23 | s_min = np.min(sample, axis=0) 24 | s_plus = np.absolute(s_min) <= s_max 25 | s_abs = np.max(np.absolute(sample), axis=0) 26 | s = s_max * s_plus + s_min * np.logical_not(s_plus) 27 | s_list.append(s) 28 | 29 | g_max = np.max(gt, axis=0) 30 | g_min = np.min(gt, axis=0) 31 | g_plus = np.absolute(g_min) <= g_max 32 | g_abs = np.max(np.absolute(gt), axis=0) 33 | g = g_max * g_plus + g_min * np.logical_not(g_plus) 34 | g_list.append(g) 35 | 36 | return cosine_similarity(np.array(s_list), np.array(g_list)) 37 | elif method == 'greedy': 38 | sim_list = [] 39 | for s, g in zip(samples, ground_truth): 40 | s = np.array(s) 41 | g = np.array(g).T 42 | sim = (np.matmul(s, g) 43 | / np.sqrt(np.matmul(np.sum(s * s, axis=1, keepdims=True), np.sum(g * g, axis=0, keepdims=True)))) 44 | sim = np.max(sim, axis=0) 45 | sim_list.append(np.mean(sim)) 46 | 47 | # return np.sum(sim_list) 48 | return np.array(sim_list) 49 | else: 50 | raise NotImplementedError 51 | -------------------------------------------------------------------------------- /utils/convert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | 5 | def to_var(x, on_cpu=False, gpu_id=None): 6 | """Tensor => Variable""" 7 | if torch.cuda.is_available() and not on_cpu: 8 | x = x.cuda(gpu_id) 9 | #x = Variable(x) 10 | return x 11 | 12 | 13 | def to_tensor(x): 14 | """Variable => Tensor""" 15 | if torch.cuda.is_available(): 16 | x = x.cpu() 17 | return x.data 18 | 19 | def reverse_order(tensor, dim=0): 20 | """Reverse Tensor or Variable""" 21 | if isinstance(tensor, torch.Tensor) or isinstance(tensor, torch.LongTensor): 22 | idx = [i for i in range(tensor.size(dim)-1, -1, -1)] 23 | idx = torch.LongTensor(idx) 24 | inverted_tensor = tensor.index_select(dim, idx) 25 | if isinstance(tensor, torch.cuda.FloatTensor) or isinstance(tensor, torch.cuda.LongTensor): 26 | idx = [i for i in range(tensor.size(dim)-1, -1, -1)] 27 | idx = torch.cuda.LongTensor(idx) 28 | inverted_tensor = 
tensor.index_select(dim, idx) 29 | return inverted_tensor 30 | elif isinstance(tensor, Variable): 31 | variable = tensor 32 | variable.data = reverse_order(variable.data, dim) 33 | return variable 34 | 35 | def reverse_order_valid(tensor, length_list, dim=0): 36 | """ 37 | Reverse a Tensor or Variable only within the given valid lengths 38 | Ex) 39 | Args: 40 | - tensor (Tensor or Variable) 41 | 1 2 3 4 5 6 42 | 6 7 8 9 0 0 43 | 11 12 13 0 0 0 44 | 16 17 0 0 0 0 45 | 21 22 23 24 25 26 46 | 47 | - length_list (list) 48 | [6, 4, 3, 2, 6] 49 | 50 | Return: 51 | tensor (Tensor or Variable; in-place) 52 | 6 5 4 3 2 1 53 | 0 0 9 8 7 6 54 | 0 0 0 13 12 11 55 | 0 0 0 0 17 16 56 | 26 25 24 23 22 21 57 | """ 58 | for row, length in zip(tensor, length_list): 59 | valid_row = row[:length] 60 | reversed_valid_row = reverse_order(valid_row, dim=dim) 61 | row[:length] = reversed_valid_row 62 | return tensor 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HRED / VHRED / VHCR for Multi-Turn Dialogue Systems 2 | - Modified from: https://github.com/ctr4si/A-Hierarchical-Latent-Structure-for-Variational-Conversation-Modeling 3 | 4 | 5 | 6 | ## Preprocess data 7 | Place train.txt, dev.txt, and test.txt under ./data/. 8 | 9 | Format (context utterances, then a tab, then the response): 10 | ``` 11 | u1 u2 u3 \t response 12 | ``` 13 | Example: 14 | ``` 15 | w11 w12 w13 w21 w22 w31 w32 w33 w34 \t w1 w2 w3 16 | ``` 17 | 18 | Then run: 19 | ``` 20 | python prepare_data.py 21 | ``` 22 | 23 | 24 | 25 | ## Training 26 | Set save_dir in configs.py (this is where model checkpoints will be saved). 27 | 28 | We provide our implementation of VHCR, as well as our reference implementations for [HRED](https://arxiv.org/abs/1507.04808) and [VHRED](https://arxiv.org/abs/1605.06069). 29 | 30 | To run training: 31 | ``` 32 | python train.py --model=<model> --batch_size=<batch_size> 33 | ``` 34 | 35 | For example: 36 | 1. Train HRED: 37 | ``` 38 | python train.py --model=HRED 39 | ``` 40 | 41 | 2. Train VHRED with a word-drop ratio of 0.25 and 250000 KL-annealing iterations: 42 | ``` 43 | python train.py --model=VHRED --batch_size=40 --word_drop=0.25 --kl_annealing_iter=250000 44 | ``` 45 | 46 | 3. Train VHCR with an utterance-drop ratio of 0.25: 47 | ``` 48 | python train.py --model=VHCR --batch_size=40 --sentence_drop=0.25 --kl_annealing_iter=250000 49 | ``` 50 | 51 | 52 | 53 | 54 | ## Evaluation 55 | To evaluate word perplexity: 56 | ``` 57 | python eval.py --model=<model> --checkpoint=<path_to_checkpoint> 58 | ``` 59 | 60 | For embedding-based metrics, download the [Google News word vectors](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing), unzip them, and put them under the datasets folder. 
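A minimal sanity check that the downloaded vectors load correctly (an editorial sketch that assumes gensim is installed; adjust the path to wherever you unzipped the .bin file):
```
from gensim.models import KeyedVectors

word2vec = KeyedVectors.load_word2vec_format(
    'datasets/GoogleNews-vectors-negative300.bin', binary=True)
print(word2vec['hello'].shape)  # (300,)
```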
61 | Then run: 62 | ``` 63 | python eval_embed.py --model=<model> --checkpoint=<path_to_checkpoint> 64 | ``` 65 | 66 | ## Generation 67 | To generate responses for the test set: 68 | ``` 69 | python test.py --model=<model> --checkpoint=<path_to_checkpoint> 70 | ``` 71 | 72 | ## BLEU and DIST 73 | ``` 74 | python metrics.py 75 | ``` 76 | 77 | 78 | ## Reference 79 | - https://github.com/ctr4si/A-Hierarchical-Latent-Structure-for-Variational-Conversation-Modeling 80 | -------------------------------------------------------------------------------- /layers/rnncells.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenNMT.py, Z-forcing 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | #from torch.nn._functions.thnn.rnnFusedPointwise import LSTMFused, GRUFused 8 | 9 | 10 | class StackedLSTMCell(nn.Module): 11 | 12 | def __init__(self, num_layers, input_size, rnn_size, dropout): 13 | super(StackedLSTMCell, self).__init__() 14 | self.dropout = nn.Dropout(dropout) 15 | self.num_layers = num_layers 16 | 17 | self.layers = nn.ModuleList() 18 | for i in range(num_layers): 19 | self.layers.append(nn.LSTMCell(input_size, rnn_size)) 20 | input_size = rnn_size 21 | 22 | def forward(self, x, h_c): 23 | """ 24 | Args: 25 | x: [batch_size, input_size] 26 | h_c: [2, num_layers, batch_size, hidden_size] 27 | Return: 28 | last_h_c: [2, batch_size, hidden_size] (h from last layer) 29 | h_c_list: [2, num_layers, batch_size, hidden_size] (h and c from all layers) 30 | """ 31 | h_0, c_0 = h_c 32 | h_list, c_list = [], [] 33 | for i, layer in enumerate(self.layers): 34 | # h of i-th layer 35 | h_i, c_i = layer(x, (h_0[i], c_0[i])) 36 | 37 | # x for next layer 38 | x = h_i 39 | if i + 1 != self.num_layers: 40 | x = self.dropout(x) 41 | h_list += [h_i] 42 | c_list += [c_i] 43 | 44 | last_h_c = (h_list[-1], c_list[-1]) 45 | h_list = torch.stack(h_list) 46 | c_list = torch.stack(c_list) 47 | h_c_list = (h_list, c_list) 48 | 49 | return last_h_c, h_c_list 50 | 51 | 52 | class StackedGRUCell(nn.Module): 53 | 54 | def __init__(self, num_layers, input_size, rnn_size, dropout): 55 | super(StackedGRUCell, self).__init__() 56 | self.dropout = nn.Dropout(dropout) 57 | self.num_layers = num_layers 58 | 59 | self.layers = nn.ModuleList() 60 | for i in range(num_layers): 61 | self.layers.append(nn.GRUCell(input_size, rnn_size)) 62 | input_size = rnn_size 63 | 64 | def forward(self, x, h): 65 | """ 66 | Args: 67 | x: [batch_size, input_size] 68 | h: [num_layers, batch_size, hidden_size] 69 | Return: 70 | last_h: [batch_size, hidden_size] (h from last layer) 71 | h_list: [num_layers, batch_size, hidden_size] (h from all layers) 72 | """ 73 | # h of all layers 74 | h_list = [] 75 | for i, layer in enumerate(self.layers): 76 | # h of i-th layer 77 | h_i = layer(x, h[i]) 78 | 79 | # x for next layer 80 | x = h_i 81 | if i + 1 != self.num_layers: 82 | x = self.dropout(x) 83 | h_list.append(h_i) 84 | 85 | last_h = h_list[-1] 86 | h_list = torch.stack(h_list) 87 | 88 | return last_h, h_list 89 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import defaultdict 3 | from torch.utils.data import Dataset, DataLoader 4 | from utils import PAD_ID, UNK_ID, SOS_ID, EOS_ID 5 | import numpy as np 6 | 7 | 8 | class DialogDataset(Dataset): 9 | def __init__(self, sentences, conversation_length, sentence_length, 
vocab, data=None): 10 | 11 | # [total_data_size, max_conversation_length, max_sentence_length] 12 | # tokenized raw text of sentences 13 | self.sentences = sentences 14 | self.vocab = vocab 15 | 16 | # conversation length of each batch 17 | # [total_data_size] 18 | self.conversation_length = conversation_length 19 | 20 | # list of length of sentences 21 | # [total_data_size, max_conversation_length] 22 | self.sentence_length = sentence_length 23 | self.data = data 24 | self.len = len(sentences) 25 | 26 | def __getitem__(self, index): 27 | """Return Single data sentence""" 28 | # [max_conversation_length, max_sentence_length] 29 | sentence = self.sentences[index] 30 | conversation_length = self.conversation_length[index] 31 | sentence_length = self.sentence_length[index] 32 | 33 | # word => word_ids 34 | sentence = self.sent2id(sentence) 35 | 36 | return sentence, conversation_length, sentence_length 37 | 38 | def __len__(self): 39 | return self.len 40 | 41 | def sent2id(self, sentences): 42 | """word => word id""" 43 | # [max_conversation_length, max_sentence_length] 44 | return [self.vocab.sent2id(sentence) for sentence in sentences] 45 | 46 | 47 | def get_loader(sentences, conversation_length, sentence_length, vocab, batch_size=100, data=None, shuffle=True): 48 | """Load DataLoader of given DialogDataset""" 49 | 50 | def collate_fn(data): 51 | """ 52 | Collate list of data in to batch 53 | 54 | Args: 55 | data: list of tuple(source, target, conversation_length, source_length, target_length) 56 | Return: 57 | Batch of each feature 58 | - source (LongTensor): [batch_size, max_conversation_length, max_source_length] 59 | - target (LongTensor): [batch_size, max_conversation_length, max_source_length] 60 | - conversation_length (np.array): [batch_size] 61 | - source_length (LongTensor): [batch_size, max_conversation_length] 62 | """ 63 | # Sort by conversation length (descending order) to use 'pack_padded_sequence' 64 | data.sort(key=lambda x: x[1], reverse=True) 65 | 66 | # Separate 67 | sentences, conversation_length, sentence_length = zip(*data) 68 | 69 | # return sentences, conversation_length, sentence_length.tolist() 70 | return sentences, conversation_length, sentence_length 71 | 72 | dataset = DialogDataset(sentences, conversation_length, 73 | sentence_length, vocab, data=data) 74 | 75 | data_loader = DataLoader( 76 | dataset=dataset, 77 | batch_size=batch_size, 78 | shuffle=shuffle, 79 | collate_fn=collate_fn) 80 | 81 | return data_loader 82 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from utils.bleu import * 4 | 5 | pred = [] 6 | 7 | with open("./pred/res.txt", "r", encoding='utf-8', errors='ignore') as f: 8 | for line in f: 9 | line = line.strip() 10 | if not line: 11 | continue 12 | fs = line.split("\t") 13 | if len(fs) != 3: 14 | print("error", line) 15 | q, g = fs 16 | r = "" 17 | else: 18 | q, g, r = fs 19 | g = [w for w in g] 20 | r = [w for w in r] 21 | pred.append((g, r)) 22 | 23 | 24 | def get_bleu(): 25 | ma_bleu = 0. 26 | ma_bleu1 = 0. 27 | ma_bleu2 = 0. 28 | ma_bleu3 = 0. 29 | ma_bleu4 = 0. 
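    # ma_* accumulates per-example ("macro") BLEU, which is averaged over all
    # examples below; mi_* ("micro") is a single corpus-level BLEU computed in
    # one compute_bleu call over ref_lst/hyp_lst after the loop.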
30 | ref_lst = [] 31 | hyp_lst = [] 32 | for g, r in pred: 33 | references = [g] 34 | hypothesis = r 35 | ref_lst.append(references) 36 | hyp_lst.append(hypothesis) 37 | bleu, precisions, _, _, _, _ = compute_bleu([references], [hypothesis], smooth=False) 38 | ma_bleu += bleu 39 | ma_bleu1 += precisions[0] 40 | ma_bleu2 += precisions[1] 41 | ma_bleu3 += precisions[2] 42 | ma_bleu4 += precisions[3] 43 | n = len(pred) 44 | ma_bleu /= n 45 | ma_bleu1 /= n 46 | ma_bleu2 /= n 47 | ma_bleu3 /= n 48 | ma_bleu4 /= n 49 | 50 | mi_bleu, precisions, _, _, _, _ = compute_bleu(ref_lst, hyp_lst, smooth=False) 51 | mi_bleu1, mi_bleu2, mi_bleu3, mi_bleu4 = precisions[0], precisions[1], precisions[2], precisions[3] 52 | return ma_bleu, ma_bleu1, ma_bleu2, ma_bleu3, ma_bleu4,\ 53 | mi_bleu, mi_bleu1, mi_bleu2, mi_bleu3, mi_bleu4 54 | 55 | 56 | def get_dist(): 57 | unigrams = [] 58 | bigrams = [] 59 | ma_dist1, ma_dist2 = 0., 0. 60 | avg_len = 0. 61 | for g, r in pred: 62 | ugs = r 63 | bgs = [] 64 | i = 0 65 | while i < len(ugs) - 1: 66 | bgs.append(ugs[i] + ugs[i+1]) 67 | i += 1 68 | unigrams += ugs 69 | bigrams += bgs 70 | ma_dist1 += len(set(ugs)) / (float)(len(ugs) + 1e-16) 71 | ma_dist2 += len(set(bgs)) / (float)(len(bgs) + 1e-16) 72 | avg_len += len(ugs) 73 | n = len(pred) 74 | ma_dist1 /= n 75 | ma_dist2 /= n 76 | mi_dist1 = len(set(unigrams)) / (float)(len(unigrams)) 77 | mi_dist2 = len(set(bigrams)) / (float)(len(bigrams)) 78 | avg_len /= n 79 | return ma_dist1, ma_dist2, mi_dist1, mi_dist2, avg_len 80 | 81 | 82 | if True: 83 | ma_bleu, ma_bleu1, ma_bleu2, ma_bleu3, ma_bleu4,\ 84 | mi_bleu, mi_bleu1, mi_bleu2, mi_bleu3, mi_bleu4 = get_bleu() 85 | 86 | ma_dist1, ma_dist2, mi_dist1, mi_dist2, avg_len = get_dist() 87 | 88 | print("ma_bleu", ma_bleu) 89 | print("ma_bleu1", ma_bleu1) 90 | print("ma_bleu2", ma_bleu2) 91 | print("ma_bleu3", ma_bleu3) 92 | print("ma_bleu4", ma_bleu4) 93 | print("mi_bleu", mi_bleu) 94 | print("mi_bleu1", mi_bleu1) 95 | print("mi_bleu2", mi_bleu2) 96 | print("mi_bleu3", mi_bleu3) 97 | print("mi_bleu4", mi_bleu4) 98 | print("ma_dist1", ma_dist1) 99 | print("ma_dist2", ma_dist2) 100 | print("mi_dist1", mi_dist1) 101 | print("mi_dist2", mi_dist2) 102 | print("avg_len", avg_len) 103 | 104 | # to tex format 105 | print("& %.2f & %.2f & %.2f & %.2f & %.2f & %.2f & %.2f & %.2f & %.2f & %.2f" \ 106 | % (ma_bleu*100, ma_bleu1*100, ma_bleu2*100, \ 107 | ma_bleu3*100, ma_bleu4*100, ma_dist1*100, \ 108 | ma_dist2*100, mi_dist1*100, mi_dist2*100, avg_len)) 109 | 110 | print() 111 | 112 | -------------------------------------------------------------------------------- /utils/bleu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Python implementation of BLEU and smooth-BLEU. 
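Concretely, compute_bleu returns BLEU = BP * exp(sum_n (1/N) * log p_n), where
p_n is the modified n-gram precision for order n up to N = max_order, and BP is
the brevity penalty exp(1 - 1/ratio) applied when translations are shorter than
their references.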
17 | 18 | This module provides a Python implementation of BLEU and smooth-BLEU. 19 | Smooth BLEU is computed following the method outlined in the paper: 20 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 21 | evaluation metrics for machine translation. COLING 2004. 22 | """ 23 | 24 | import collections 25 | import math 26 | 27 | 28 | def _get_ngrams(segment, max_order): 29 | """Extracts all n-grams upto a given maximum order from an input segment. 30 | 31 | Args: 32 | segment: text segment from which n-grams will be extracted. 33 | max_order: maximum length in tokens of the n-grams returned by this 34 | methods. 35 | 36 | Returns: 37 | The Counter containing all n-grams upto max_order in segment 38 | with a count of how many times each n-gram occurred. 39 | """ 40 | ngram_counts = collections.Counter() 41 | for order in range(1, max_order + 1): 42 | for i in range(0, len(segment) - order + 1): 43 | ngram = tuple(segment[i:i+order]) 44 | ngram_counts[ngram] += 1 45 | return ngram_counts 46 | 47 | 48 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, 49 | smooth=False): 50 | """Computes BLEU score of translated segments against one or more references. 51 | 52 | Args: 53 | reference_corpus: list of lists of references for each translation. Each 54 | reference should be tokenized into a list of tokens. 55 | translation_corpus: list of translations to score. Each translation 56 | should be tokenized into a list of tokens. 57 | max_order: Maximum n-gram order to use when computing BLEU score. 58 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 59 | 60 | Returns: 61 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 62 | precisions and brevity penalty. 63 | """ 64 | matches_by_order = [0] * max_order 65 | possible_matches_by_order = [0] * max_order 66 | reference_length = 0 67 | translation_length = 0 68 | for (references, translation) in zip(reference_corpus, 69 | translation_corpus): 70 | reference_length += min(len(r) for r in references) 71 | translation_length += len(translation) 72 | 73 | merged_ref_ngram_counts = collections.Counter() 74 | for reference in references: 75 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 76 | translation_ngram_counts = _get_ngrams(translation, max_order) 77 | overlap = translation_ngram_counts & merged_ref_ngram_counts 78 | for ngram in overlap: 79 | matches_by_order[len(ngram)-1] += overlap[ngram] 80 | for order in range(1, max_order+1): 81 | possible_matches = len(translation) - order + 1 82 | if possible_matches > 0: 83 | possible_matches_by_order[order-1] += possible_matches 84 | 85 | precisions = [0] * max_order 86 | for i in range(0, max_order): 87 | if smooth: 88 | precisions[i] = ((matches_by_order[i] + 1.) / 89 | (possible_matches_by_order[i] + 1.)) 90 | else: 91 | if possible_matches_by_order[i] > 0: 92 | precisions[i] = (float(matches_by_order[i]) / 93 | possible_matches_by_order[i]) 94 | else: 95 | precisions[i] = 0.0 96 | 97 | if min(precisions) > 0: 98 | p_log_sum = sum((1. / max_order) * math.log(p) for p in precisions) 99 | geo_mean = math.exp(p_log_sum) 100 | else: 101 | geo_mean = 0 102 | 103 | ratio = float(translation_length) / reference_length 104 | 105 | if ratio > 1.0: 106 | bp = 1. 107 | else: 108 | bp = math.exp(1 - 1. 
/ (ratio + 1e-16)) 109 | 110 | bleu = geo_mean * bp 111 | 112 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 113 | -------------------------------------------------------------------------------- /utils/vocab.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import pickle 3 | import torch 4 | from torch import Tensor 5 | from torch.autograd import Variable 6 | from nltk import FreqDist 7 | from .convert import to_tensor, to_var 8 | 9 | PAD_TOKEN = '<pad>' 10 | UNK_TOKEN = '<unk>' 11 | SOS_TOKEN = '<sos>' 12 | EOS_TOKEN = '<eos>' 13 | 14 | PAD_ID, UNK_ID, SOS_ID, EOS_ID = [0, 1, 2, 3] 15 | 16 | 17 | class Vocab(object): 18 | def __init__(self, lang="en", max_size=None, min_freq=1): 19 | """Basic Vocabulary object""" 20 | self.lang = lang 21 | self.vocab_size = 0 22 | self.freqdist = FreqDist() 23 | 24 | def update(self, max_size=None, min_freq=1): 25 | """ 26 | Initialize id2word & word2id based on self.freqdist 27 | max_size includes the 4 special tokens 28 | """ 29 | 30 | # {0: '<pad>', 1: '<unk>', 2: '<sos>', 3: '<eos>'} 31 | self.id2word = { 32 | PAD_ID: PAD_TOKEN, UNK_ID: UNK_TOKEN, 33 | SOS_ID: SOS_TOKEN, EOS_ID: EOS_TOKEN 34 | } 35 | # {'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3} 36 | self.word2id = defaultdict(lambda: UNK_ID) # Not in vocab => return UNK 37 | self.word2id.update({ 38 | PAD_TOKEN: PAD_ID, UNK_TOKEN: UNK_ID, 39 | SOS_TOKEN: SOS_ID, EOS_TOKEN: EOS_ID 40 | }) 41 | # self.word2id = { 42 | # PAD_TOKEN: PAD_ID, UNK_TOKEN: UNK_ID, 43 | # SOS_TOKEN: SOS_ID, EOS_TOKEN: EOS_ID 44 | # } 45 | 46 | vocab_size = 4 47 | min_freq = max(min_freq, 1) 48 | 49 | # Reset frequencies of special tokens 50 | # [...('<pad>', 0), ('<unk>', 0), ('<sos>', 0), ('<eos>', 0)] 51 | freqdist = self.freqdist.copy() 52 | special_freqdist = {token: freqdist[token] 53 | for token in [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN]} 54 | freqdist.subtract(special_freqdist) 55 | 56 | # Sort: by frequency, then alphabetically 57 | # Ex) freqdist = { 'a': 4, 'b': 5, 'c': 3 } 58 | # => sorted = [('b', 5), ('a', 4), ('c', 3)] 59 | sorted_frequency_counter = sorted(freqdist.items(), key=lambda k_v: k_v[0]) 60 | sorted_frequency_counter.sort(key=lambda k_v: k_v[1], reverse=True) 61 | 62 | for word, freq in sorted_frequency_counter: 63 | 64 | if freq < min_freq or vocab_size == max_size: 65 | break 66 | self.id2word[vocab_size] = word 67 | self.word2id[word] = vocab_size 68 | vocab_size += 1 69 | 70 | self.vocab_size = vocab_size 71 | 72 | def __len__(self): 73 | return len(self.id2word) 74 | 75 | def load(self, word2id_path=None, id2word_path=None): 76 | if word2id_path: 77 | with open(word2id_path, 'rb') as f: 78 | word2id = pickle.load(f) 79 | # Can't pickle lambda function 80 | self.word2id = defaultdict(lambda: UNK_ID) 81 | self.word2id.update(word2id) 82 | self.vocab_size = len(self.word2id) 83 | 84 | if id2word_path: 85 | with open(id2word_path, 'rb') as f: 86 | id2word = pickle.load(f) 87 | self.id2word = id2word 88 | 89 | def add_word(self, word): 90 | assert isinstance(word, str), 'Input should be str' 91 | self.freqdist.update([word]) 92 | 93 | def add_sentence(self, sentence): 94 | for word in sentence: 95 | self.add_word(word) 96 | 97 | def add_dataframe(self, conversation_df): 98 | for conversation in conversation_df: 99 | for sentence in conversation: 100 | self.add_sentence(sentence) 101 | 102 | def pickle(self, word2id_path, id2word_path): 103 | with open(word2id_path, 'wb') as f: 104 | pickle.dump(dict(self.word2id), f) 105 | 106 | with open(id2word_path, 'wb') as f: 107 
| pickle.dump(self.id2word, f) 108 | 109 | def to_list(self, list_like): 110 | """Convert list-like containers to list""" 111 | if isinstance(list_like, list): 112 | return list_like 113 | 114 | if isinstance(list_like, Variable): 115 | return list(to_tensor(list_like).numpy()) 116 | elif isinstance(list_like, Tensor): 117 | return list(list_like.numpy()) 118 | 119 | def id2sent(self, id_list): 120 | """list of id => list of tokens (Single sentence)""" 121 | id_list = self.to_list(id_list) 122 | sentence = [] 123 | for id in id_list: 124 | word = self.id2word[id] 125 | if word not in [EOS_TOKEN, SOS_TOKEN, PAD_TOKEN]: 126 | sentence.append(word) 127 | if word == EOS_TOKEN: 128 | break 129 | return sentence 130 | 131 | def sent2id(self, sentence, var=False): 132 | """list of tokens => list of id (Single sentence)""" 133 | id_list = [self.word2id[word] for word in sentence] 134 | if var: 135 | id_list = to_var(torch.LongTensor(id_list), eval=True) 136 | return id_list 137 | 138 | def decode(self, id_list): 139 | sentence = self.id2sent(id_list) 140 | if self.lang == "zh": 141 | return ''.join(sentence) 142 | return ' '.join(sentence) 143 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | # Preprocess cornell movie dialogs dataset 2 | import os 3 | from multiprocessing import Pool 4 | import argparse 5 | import pickle 6 | import random 7 | from pathlib import Path 8 | from tqdm import tqdm 9 | from utils import Vocab, PAD_TOKEN, SOS_TOKEN, EOS_TOKEN 10 | 11 | project_dir = Path(__file__).resolve().parent 12 | datasets_dir = project_dir.joinpath('./data/') 13 | 14 | def load_conversations(fileName, spliter=""): 15 | conversations = [] 16 | with open(fileName, 'r') as f: 17 | for line in f: 18 | line = line.strip() 19 | if not line: 20 | continue 21 | fs = line.split("\t") 22 | if len(fs) != 2: 23 | print("error line", line) 24 | context, response = fs[0].strip(), fs[1].strip() 25 | utterances = context.split(spliter) 26 | conversation = [] 27 | for utterance in utterances: 28 | conversation.append(utterance.split()) 29 | conversation.append(response.split()) 30 | conversations.append(conversation) 31 | return conversations 32 | 33 | def tokenize_conversation(lines): 34 | sentence_list = [tokenizer(line['text']) for line in lines] 35 | return sentence_list 36 | 37 | def pad_sentences(conversations, max_sentence_length=40, max_conversation_length=10): 38 | def pad_tokens(tokens, max_sentence_length=max_sentence_length): 39 | n_valid_tokens = len(tokens) 40 | if n_valid_tokens > max_sentence_length - 1: 41 | tokens = tokens[:max_sentence_length - 1] 42 | n_pad = max_sentence_length - n_valid_tokens - 1 43 | tokens = tokens + [EOS_TOKEN] + [PAD_TOKEN] * n_pad 44 | return tokens 45 | 46 | def pad_conversation(conversation): 47 | conversation = [pad_tokens(sentence) for sentence in conversation] 48 | return conversation 49 | 50 | all_padded_sentences = [] 51 | all_sentence_length = [] 52 | 53 | for conversation in conversations: 54 | if len(conversation) > max_conversation_length: 55 | conversation.reverse() 56 | conversation = conversation[:max_conversation_length] 57 | conversation.reverse() # the last n utterances 58 | sentence_length = [min(len(sentence) + 1, max_sentence_length) # +1 for EOS token 59 | for sentence in conversation] 60 | all_sentence_length.append(sentence_length) 61 | 62 | sentences = pad_conversation(conversation) 63 | 
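        # At this point each sentence has been truncated to at most
        # max_sentence_length - 1 tokens, terminated with EOS_TOKEN, and
        # right-padded with PAD_TOKEN up to max_sentence_length (see pad_tokens).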
all_padded_sentences.append(sentences) 64 | 65 | sentences = all_padded_sentences 66 | sentence_length = all_sentence_length 67 | return sentences, sentence_length 68 | 69 | 70 | if __name__ == '__main__': 71 | 72 | parser = argparse.ArgumentParser() 73 | 74 | # Maximum valid length of sentence 75 | # => SOS/EOS will surround sentence (EOS for source / SOS for target) 76 | # => maximum length of tensor = max_sentence_length + 1 77 | parser.add_argument('-s', '--max_sentence_length', type=int, default=40) 78 | parser.add_argument('-c', '--max_conversation_length', type=int, default=10) 79 | 80 | # Vocabulary 81 | parser.add_argument('--max_vocab_size', type=int, default=50000) 82 | parser.add_argument('--min_vocab_frequency', type=int, default=1) 83 | 84 | args = parser.parse_args() 85 | 86 | max_sent_len = args.max_sentence_length 87 | max_conv_len = args.max_conversation_length 88 | max_vocab_size = args.max_vocab_size 89 | min_freq = args.min_vocab_frequency 90 | 91 | print("Loading conversations...") 92 | train = load_conversations(datasets_dir.joinpath("train.txt")) 93 | valid = load_conversations(datasets_dir.joinpath("dev.txt")) 94 | test = load_conversations(datasets_dir.joinpath("test.txt")) 95 | 96 | print("#train=%d, #val=%d, #test=%d"%(len(train), len(valid), len(test))) 97 | def to_pickle(obj, path): 98 | with open(path, 'wb') as f: 99 | pickle.dump(obj, f) 100 | 101 | vocab = Vocab(lang="zh") 102 | for split_type, conversations in [('train', train), ('valid', valid), ('test', test)]: 103 | print(f'Processing {split_type} dataset...') 104 | split_data_dir = datasets_dir.joinpath(split_type) 105 | split_data_dir.mkdir(exist_ok=True) 106 | conversation_length = [min(len(conv), max_conv_len) 107 | for conv in conversations] 108 | 109 | sentences, sentence_length = pad_sentences( 110 | conversations, 111 | max_sentence_length=max_sent_len, 112 | max_conversation_length=max_conv_len) 113 | 114 | print('Saving preprocessed data at', split_data_dir) 115 | to_pickle(conversation_length, split_data_dir.joinpath('conversation_length.pkl')) 116 | to_pickle(sentences, split_data_dir.joinpath('sentences.pkl')) 117 | to_pickle(sentence_length, split_data_dir.joinpath('sentence_length.pkl')) 118 | 119 | if split_type != 'test': 120 | print('Save Vocabulary...') 121 | vocab.add_dataframe(conversations) 122 | vocab.update(max_size=max_vocab_size, min_freq=min_freq) 123 | 124 | print('Vocabulary size: ', len(vocab)) 125 | vocab.pickle(datasets_dir.joinpath('word2id.pkl'), datasets_dir.joinpath('id2word.pkl')) 126 | 127 | print('Done!') 128 | -------------------------------------------------------------------------------- /layers/beam_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import EOS_ID 3 | 4 | 5 | class Beam(object): 6 | def __init__(self, batch_size, hidden_size, vocab_size, beam_size, max_unroll, batch_position): 7 | """Beam class for beam search""" 8 | self.batch_size = batch_size 9 | self.hidden_size = hidden_size 10 | self.vocab_size = vocab_size 11 | self.beam_size = beam_size 12 | self.max_unroll = max_unroll 13 | 14 | # batch_position [batch_size] 15 | # [0, beam_size, beam_size * 2, .., beam_size * (batch_size-1)] 16 | # Points where batch starts in [batch_size x beam_size] tensors 17 | # Ex. 
position_idx[5]: when 5-th batch starts 18 | self.batch_position = batch_position 19 | 20 | self.log_probs = list() # [(batch*k, vocab_size)] * sequence_length 21 | self.scores = list() # [(batch*k)] * sequence_length 22 | self.back_pointers = list() # [(batch*k)] * sequence_length 23 | self.token_ids = list() # [(batch*k)] * sequence_length 24 | # self.hidden = list() # [(num_layers, batch*k, hidden_size)] * sequence_length 25 | 26 | self.metadata = { 27 | 'inputs': None, 28 | 'output': None, 29 | 'scores': None, 30 | 'length': None, 31 | 'sequence': None, 32 | } 33 | 34 | def update(self, score, back_pointer, token_id): # , h): 35 | """Append intermediate top-k candidates to beam at each step""" 36 | 37 | # self.log_probs.append(log_prob) 38 | self.scores.append(score) 39 | self.back_pointers.append(back_pointer) 40 | self.token_ids.append(token_id) 41 | # self.hidden.append(h) 42 | 43 | def backtrack(self): 44 | """Backtracks over batch to generate optimal k-sequences 45 | 46 | Returns: 47 | prediction ([batch, k, max_unroll]) 48 | A list of Tensors containing predicted sequence 49 | final_score [batch, k] 50 | A list containing the final scores for all top-k sequences 51 | length [batch, k] 52 | A list specifying the length of each sequence in the top-k candidates 53 | """ 54 | prediction = list() 55 | 56 | # import ipdb 57 | # ipdb.set_trace() 58 | # Initialize for length of top-k sequences 59 | length = [[self.max_unroll] * self.beam_size for _ in range(self.batch_size)] 60 | 61 | # Last step output of the beam are not sorted => sort here! 62 | # Size not changed [batch size, beam_size] 63 | top_k_score, top_k_idx = self.scores[-1].topk(self.beam_size, dim=1) 64 | 65 | # Initialize sequence scores 66 | top_k_score = top_k_score.clone() 67 | 68 | n_eos_in_batch = [0] * self.batch_size 69 | 70 | # Initialize Back-pointer from the last step 71 | # Add self.position_idx for indexing variable with batch x beam as the first dimension 72 | # [batch x beam] 73 | back_pointer = (top_k_idx + self.batch_position.unsqueeze(1)).view(-1) 74 | 75 | for t in reversed(range(self.max_unroll)): 76 | # Reorder variables with the Back-pointer 77 | # [batch x beam] 78 | token_id = self.token_ids[t].index_select(0, back_pointer) 79 | 80 | # Reorder the Back-pointer 81 | # [batch x beam] 82 | back_pointer = self.back_pointers[t].index_select(0, back_pointer) 83 | 84 | # Indices of ended sequences 85 | # [< batch x beam] 86 | eos_indices = self.token_ids[t].data.eq(EOS_ID).nonzero() 87 | 88 | # For each batch, every time we see an EOS in the backtracking process, 89 | # If not all sequences are ended 90 | # lowest scored survived sequence <- detected ended sequence 91 | # if all sequences are ended 92 | # lowest scored ended sequence <- detected ended sequence 93 | if eos_indices.dim() > 0: 94 | # Loop over all EOS at current step 95 | for i in range(eos_indices.size(0) - 1, -1, -1): 96 | # absolute index of detected ended sequence 97 | eos_idx = eos_indices[i, 0].item() 98 | 99 | # At which batch EOS is located 100 | batch_idx = eos_idx // self.beam_size 101 | batch_start_idx = batch_idx * self.beam_size 102 | 103 | # if n_eos_in_batch[batch_idx] > self.beam_size: 104 | 105 | # Index of sequence with lowest score 106 | _n_eos_in_batch = n_eos_in_batch[batch_idx] % self.beam_size 107 | beam_idx_to_be_replaced = self.beam_size - _n_eos_in_batch - 1 108 | idx_to_be_replaced = batch_start_idx + beam_idx_to_be_replaced 109 | 110 | # Replace old information with new sequence information 111 | 
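                    # The finished hypothesis takes over the lowest-scoring
                    # surviving beam slot of its batch: its back-pointer, token
                    # and score are copied in, and its true length (t + 1) is
                    # recorded.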
back_pointer[idx_to_be_replaced] = self.back_pointers[t][eos_idx].item() 112 | token_id[idx_to_be_replaced] = self.token_ids[t][eos_idx].item() 113 | top_k_score[batch_idx, 114 | beam_idx_to_be_replaced] = self.scores[t].view(-1)[eos_idx].item() 115 | length[batch_idx][beam_idx_to_be_replaced] = t + 1 116 | 117 | n_eos_in_batch[batch_idx] += 1 118 | 119 | # max_unroll * [batch x beam] 120 | prediction.append(token_id) 121 | 122 | # Sort and re-order again as the added ended sequences may change the order 123 | # [batch, beam] 124 | top_k_score, top_k_idx = top_k_score.topk(self.beam_size, dim=1) 125 | final_score = top_k_score.data 126 | 127 | for batch_idx in range(self.batch_size): 128 | length[batch_idx] = [length[batch_idx][beam_idx.item()] 129 | for beam_idx in top_k_idx[batch_idx]] 130 | 131 | # [batch x beam] 132 | top_k_idx = (top_k_idx + self.batch_position.unsqueeze(1)).view(-1) 133 | 134 | # Reverse the sequences and re-order at the same time 135 | # It is reversed because the backtracking happens in the reverse order 136 | # [batch, beam] 137 | 138 | prediction = [step.index_select(0, top_k_idx).view( 139 | self.batch_size, self.beam_size) for step in reversed(prediction)] 140 | 141 | # [batch, beam, max_unroll] 142 | prediction = torch.stack(prediction, 2) 143 | 144 | return prediction, final_score, length 145 | -------------------------------------------------------------------------------- /configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from datetime import datetime 4 | from collections import defaultdict 5 | from pathlib import Path 6 | import pprint 7 | from torch import optim 8 | import torch.nn as nn 9 | from layers.rnncells import StackedLSTMCell, StackedGRUCell 10 | 11 | project_dir = Path(__file__).resolve().parent 12 | data_dir = project_dir.joinpath('./data/') 13 | optimizer_dict = {'RMSprop': optim.RMSprop, 'Adam': optim.Adam} 14 | rnn_dict = {'lstm': nn.LSTM, 'gru': nn.GRU} 15 | rnncell_dict = {'lstm': StackedLSTMCell, 'gru': StackedGRUCell} 16 | save_dir = project_dir.joinpath('./ckpt/') 17 | pred_dir = project_dir.joinpath('./pred/') 18 | 19 | def str2bool(v): 20 | """string to boolean""" 21 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 22 | return True 23 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 24 | return False 25 | else: 26 | raise argparse.ArgumentTypeError('Boolean value expected.') 27 | 28 | 29 | class Config(object): 30 | def __init__(self, **kwargs): 31 | """Configuration Class: set kwargs as class attributes with setattr""" 32 | if kwargs is not None: 33 | for key, value in kwargs.items(): 34 | if key == 'optimizer': 35 | value = optimizer_dict[value] 36 | if key == 'rnn': 37 | value = rnn_dict[value] 38 | if key == 'rnncell': 39 | value = rnncell_dict[value] 40 | setattr(self, key, value) 41 | 42 | # Dataset directory: ex) ./datasets/cornell/ 43 | self.dataset_dir = project_dir.joinpath(self.data.lower()) 44 | 45 | # Data Split ex) 'train', 'valid', 'test' 46 | self.data_dir = self.dataset_dir.joinpath(self.mode) 47 | # Pickled Vocabulary 48 | self.word2id_path = self.dataset_dir.joinpath('word2id.pkl') 49 | self.id2word_path = self.dataset_dir.joinpath('id2word.pkl') 50 | 51 | # Pickled Dataframes 52 | self.sentences_path = self.data_dir.joinpath('sentences.pkl') 53 | self.sentence_length_path = self.data_dir.joinpath('sentence_length.pkl') 54 | self.conversation_length_path = self.data_dir.joinpath('conversation_length.pkl') 55 | 56 | 
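        # pred_path points at ./pred/res.txt, the file that metrics.py reads
        # for BLEU and DIST scoring.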
os.makedirs(pred_dir, exist_ok=True) 57 | self.pred_path = pred_dir.joinpath('res.txt') 58 | 59 | # Save path 60 | if self.mode == 'train' and self.checkpoint is None: 61 | time_now = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') 62 | self.save_path = save_dir.joinpath(self.data, self.model, time_now) 63 | self.logdir = self.save_path 64 | os.makedirs(self.save_path, exist_ok=True) 65 | elif self.checkpoint is not None: 66 | assert os.path.exists(self.checkpoint) 67 | self.save_path = os.path.dirname(self.checkpoint) 68 | self.logdir = self.save_path 69 | 70 | def __str__(self): 71 | """Pretty-print configurations in alphabetical order""" 72 | config_str = 'Configurations\n' 73 | config_str += pprint.pformat(self.__dict__) 74 | return config_str 75 | 76 | 77 | def get_config(parse=True, **optional_kwargs): 78 | """ 79 | Get configurations as attributes of class 80 | 1. Parse configurations with argparse. 81 | 2. Create Config class initilized with parsed kwargs. 82 | 3. Return Config class. 83 | """ 84 | parser = argparse.ArgumentParser() 85 | 86 | # Mode 87 | parser.add_argument('--mode', type=str, default='train') 88 | 89 | # Train 90 | parser.add_argument('--batch_size', type=int, default=128) 91 | parser.add_argument('--eval_batch_size', type=int, default=50) 92 | parser.add_argument('--n_epoch', type=int, default=40) 93 | parser.add_argument('--learning_rate', type=float, default=1e-4) 94 | parser.add_argument('--optimizer', type=str, default='Adam') 95 | parser.add_argument('--clip', type=float, default=1.0) 96 | parser.add_argument('--checkpoint', type=str, default=None) 97 | 98 | # Generation 99 | parser.add_argument('--max_unroll', type=int, default=50) 100 | parser.add_argument('--sample', type=str2bool, default=False, 101 | help='if false, use beam search for decoding') 102 | parser.add_argument('--temperature', type=float, default=1.0) 103 | parser.add_argument('--beam_size', type=int, default=5) 104 | 105 | # Model 106 | parser.add_argument('--model', type=str, default='VHCR', 107 | help='one of {HRED, VHRED, VHCR}') 108 | # Currently does not support lstm 109 | parser.add_argument('--rnn', type=str, default='gru') 110 | parser.add_argument('--rnncell', type=str, default='gru') 111 | parser.add_argument('--num_layers', type=int, default=1) 112 | parser.add_argument('--embedding_size', type=int, default=300) 113 | parser.add_argument('--tie_embedding', type=str2bool, default=True) 114 | parser.add_argument('--encoder_hidden_size', type=int, default=512) 115 | parser.add_argument('--bidirectional', type=str2bool, default=True) 116 | parser.add_argument('--decoder_hidden_size', type=int, default=512) 117 | parser.add_argument('--dropout', type=float, default=0.2) 118 | parser.add_argument('--context_size', type=int, default=512) 119 | parser.add_argument('--feedforward', type=str, default='FeedForward') 120 | parser.add_argument('--activation', type=str, default='Tanh') 121 | 122 | # VAE model 123 | parser.add_argument('--z_sent_size', type=int, default=100) 124 | parser.add_argument('--z_conv_size', type=int, default=100) 125 | parser.add_argument('--word_drop', type=float, default=0.0, 126 | help='only applied to variational models') 127 | parser.add_argument('--kl_threshold', type=float, default=0.0) 128 | parser.add_argument('--kl_annealing_iter', type=int, default=25000) 129 | parser.add_argument('--importance_sample', type=int, default=100) 130 | parser.add_argument('--sentence_drop', type=float, default=0.0) 131 | 132 | # Generation 133 | parser.add_argument('--n_context', 
type=int, default=1) 134 | parser.add_argument('--n_sample_step', type=int, default=1) 135 | 136 | # BOW 137 | parser.add_argument('--bow', type=str2bool, default=False) 138 | 139 | # Utility 140 | parser.add_argument('--print_every', type=int, default=100) 141 | parser.add_argument('--plot_every_epoch', type=int, default=1) 142 | parser.add_argument('--save_every_epoch', type=int, default=1) 143 | 144 | # Data 145 | parser.add_argument('--data', type=str, default='./data/') 146 | 147 | # Parse arguments 148 | if parse: 149 | kwargs = parser.parse_args() 150 | else: 151 | kwargs = parser.parse_known_args()[0] 152 | 153 | # Namespace => Dictionary 154 | kwargs = vars(kwargs) 155 | kwargs.update(optional_kwargs) 156 | 157 | return Config(**kwargs) 158 | -------------------------------------------------------------------------------- /layers/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence, PackedSequence 5 | from utils import to_var, reverse_order_valid, PAD_ID 6 | from .rnncells import StackedGRUCell, StackedLSTMCell 7 | 8 | import copy 9 | 10 | class BaseRNNEncoder(nn.Module): 11 | def __init__(self): 12 | """Base RNN Encoder Class""" 13 | super(BaseRNNEncoder, self).__init__() 14 | 15 | @property 16 | def use_lstm(self): 17 | if hasattr(self, 'rnn'): 18 | return isinstance(self.rnn, nn.LSTM) 19 | else: 20 | raise AttributeError('no rnn selected') 21 | 22 | def init_h(self, batch_size=None, hidden=None): 23 | """Return RNN initial state""" 24 | if hidden is not None: 25 | return hidden 26 | 27 | if self.use_lstm: 28 | return (to_var(torch.zeros(self.num_layers*self.num_directions, 29 | batch_size, 30 | self.hidden_size)), 31 | to_var(torch.zeros(self.num_layers*self.num_directions, 32 | batch_size, 33 | self.hidden_size))) 34 | else: 35 | return to_var(torch.zeros(self.num_layers*self.num_directions, 36 | batch_size, 37 | self.hidden_size)) 38 | 39 | def batch_size(self, inputs=None, h=None): 40 | """ 41 | inputs: [batch_size, seq_len] 42 | h: [num_layers, batch_size, hidden_size] (RNN/GRU) 43 | h_c: [2, num_layers, batch_size, hidden_size] (LSTM) 44 | """ 45 | if inputs is not None: 46 | batch_size = inputs.size(0) 47 | return batch_size 48 | 49 | else: 50 | if self.use_lstm: 51 | batch_size = h[0].size(1) 52 | else: 53 | batch_size = h.size(1) 54 | return batch_size 55 | 56 | def forward(self): 57 | raise NotImplementedError 58 | 59 | 60 | class EncoderRNN(BaseRNNEncoder): 61 | def __init__(self, vocab_size, embedding_size, 62 | hidden_size, rnn=nn.GRU, num_layers=1, bidirectional=False, 63 | dropout=0.0, bias=True, batch_first=True): 64 | """Sentence-level Encoder""" 65 | super(EncoderRNN, self).__init__() 66 | 67 | self.vocab_size = vocab_size 68 | self.embedding_size = embedding_size 69 | self.hidden_size = hidden_size 70 | self.num_layers = num_layers 71 | self.dropout = dropout 72 | self.batch_first = batch_first 73 | self.bidirectional = bidirectional 74 | 75 | if bidirectional: 76 | self.num_directions = 2 77 | else: 78 | self.num_directions = 1 79 | 80 | # word embedding 81 | self.embedding = nn.Embedding(vocab_size, embedding_size, padding_idx=PAD_ID) 82 | 83 | self.rnn = rnn(input_size=embedding_size, 84 | hidden_size=hidden_size, 85 | num_layers=num_layers, 86 | bias=bias, 87 | batch_first=batch_first, 88 | dropout=dropout, 89 | bidirectional=bidirectional) 90 | 91 | def 
forward(self, inputs, input_length, hidden=None): 92 | """ 93 | Args: 94 | inputs (Variable, LongTensor): [num_setences, max_seq_len] 95 | input_length (Variable, LongTensor): [num_sentences] 96 | Return: 97 | outputs (Variable): [max_source_length, batch_size, hidden_size] 98 | - list of all hidden states 99 | hidden ((tuple of) Variable): [num_layers*num_directions, batch_size, hidden_size] 100 | - last hidden state 101 | - (h, c) or h 102 | """ 103 | batch_size, seq_len = inputs.size() 104 | 105 | # Sort in decreasing order of length for pack_padded_sequence() 106 | input_length_sorted, indices = input_length.sort(descending=True) 107 | 108 | input_length_sorted = input_length_sorted.data.tolist() 109 | 110 | # [num_sentences, max_source_length] 111 | inputs_sorted = inputs.index_select(0, indices) 112 | 113 | # [num_sentences, max_source_length, embedding_dim] 114 | embedded = self.embedding(inputs_sorted) 115 | 116 | # batch_first=True 117 | rnn_input = pack_padded_sequence(embedded, input_length_sorted, 118 | batch_first=self.batch_first) 119 | 120 | hidden = self.init_h(batch_size, hidden=hidden) 121 | 122 | # outputs: [batch, seq_len, hidden_size * num_directions] 123 | # hidden: [num_layers * num_directions, batch, hidden_size] 124 | self.rnn.flatten_parameters() 125 | outputs, hidden = self.rnn(rnn_input, hidden) 126 | outputs, outputs_lengths = pad_packed_sequence(outputs, batch_first=self.batch_first) 127 | 128 | # Reorder outputs and hidden 129 | _, inverse_indices = indices.sort() 130 | outputs = outputs.index_select(0, inverse_indices) 131 | 132 | if self.use_lstm: 133 | hidden = (hidden[0].index_select(1, inverse_indices), 134 | hidden[1].index_select(1, inverse_indices)) 135 | else: 136 | hidden = hidden.index_select(1, inverse_indices) 137 | 138 | return outputs, hidden 139 | 140 | class ContextRNN(BaseRNNEncoder): 141 | def __init__(self, input_size, context_size, rnn=nn.GRU, num_layers=1, dropout=0.0, 142 | bidirectional=False, bias=True, batch_first=True): 143 | """Context-level Encoder""" 144 | super(ContextRNN, self).__init__() 145 | 146 | self.input_size = input_size 147 | self.context_size = context_size 148 | self.hidden_size = self.context_size 149 | self.num_layers = num_layers 150 | self.dropout = dropout 151 | self.bidirectional = bidirectional 152 | self.batch_first = batch_first 153 | 154 | if bidirectional: 155 | self.num_directions = 2 156 | else: 157 | self.num_directions = 1 158 | 159 | self.rnn = rnn(input_size=input_size, 160 | hidden_size=context_size, 161 | num_layers=num_layers, 162 | bias=bias, 163 | batch_first=batch_first, 164 | dropout=dropout, 165 | bidirectional=bidirectional) 166 | 167 | def forward(self, encoder_hidden, conversation_length, hidden=None): 168 | """ 169 | Args: 170 | encoder_hidden (Variable, FloatTensor): [batch_size, max_len, num_layers * direction * hidden_size] 171 | conversation_length (Variable, LongTensor): [batch_size] 172 | Return: 173 | outputs (Variable): [batch_size, max_seq_len, hidden_size] 174 | - list of all hidden states 175 | hidden ((tuple of) Variable): [num_layers*num_directions, batch_size, hidden_size] 176 | - last hidden state 177 | - (h, c) or h 178 | """ 179 | batch_size, seq_len, _ = encoder_hidden.size() 180 | 181 | # Sort for PackedSequence 182 | conv_length_sorted, indices = conversation_length.sort(descending=True) 183 | conv_length_sorted = conv_length_sorted.data.tolist() 184 | encoder_hidden_sorted = encoder_hidden.index_select(0, indices) 185 | 186 | rnn_input = 
pack_padded_sequence(encoder_hidden_sorted, conv_length_sorted, batch_first=True) 187 | 188 | hidden = self.init_h(batch_size, hidden=hidden) 189 | 190 | self.rnn.flatten_parameters() 191 | outputs, hidden = self.rnn(rnn_input, hidden) 192 | 193 | # outputs: [batch_size, max_conversation_length, context_size] 194 | outputs, outputs_length = pad_packed_sequence(outputs, batch_first=True) 195 | 196 | # reorder outputs and hidden 197 | _, inverse_indices = indices.sort() 198 | outputs = outputs.index_select(0, inverse_indices) 199 | 200 | if self.use_lstm: 201 | hidden = (hidden[0].index_select(1, inverse_indices), 202 | hidden[1].index_select(1, inverse_indices)) 203 | else: 204 | hidden = hidden.index_select(1, inverse_indices) 205 | 206 | # outputs: [batch, seq_len, hidden_size * num_directions] 207 | # hidden: [num_layers * num_directions, batch, hidden_size] 208 | return outputs, hidden 209 | 210 | def step(self, encoder_hidden, hidden): 211 | 212 | batch_size = encoder_hidden.size(0) 213 | # encoder_hidden: [1, batch_size, hidden_size] 214 | encoder_hidden = torch.unsqueeze(encoder_hidden, 1) 215 | 216 | if hidden is None: 217 | hidden = self.init_h(batch_size, hidden=None) 218 | 219 | outputs, hidden = self.rnn(encoder_hidden, hidden) 220 | return outputs, hidden 221 | -------------------------------------------------------------------------------- /layers/decoder.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from .rnncells import StackedLSTMCell, StackedGRUCell 6 | from .beam_search import Beam 7 | from .feedforward import FeedForward 8 | from utils import to_var, SOS_ID, UNK_ID, EOS_ID 9 | import math 10 | 11 | 12 | class BaseRNNDecoder(nn.Module): 13 | def __init__(self): 14 | """Base Decoder Class""" 15 | super(BaseRNNDecoder, self).__init__() 16 | 17 | @property 18 | def use_lstm(self): 19 | return isinstance(self.rnncell, StackedLSTMCell) 20 | 21 | def init_token(self, batch_size, SOS_ID=SOS_ID): 22 | """Get Variable of Index (batch_size)""" 23 | x = to_var(torch.LongTensor([SOS_ID] * batch_size)) 24 | return x 25 | 26 | def init_h(self, batch_size=None, zero=True, hidden=None): 27 | """Return RNN initial state""" 28 | if hidden is not None: 29 | return hidden 30 | 31 | if self.use_lstm: 32 | # (h, c) 33 | return (to_var(torch.zeros(self.num_layers, 34 | batch_size, 35 | self.hidden_size)), 36 | to_var(torch.zeros(self.num_layers, 37 | batch_size, 38 | self.hidden_size))) 39 | else: 40 | # h 41 | return to_var(torch.zeros(self.num_layers, 42 | batch_size, 43 | self.hidden_size)) 44 | 45 | def batch_size(self, inputs=None, h=None): 46 | """ 47 | inputs: [batch_size, seq_len] 48 | h: [num_layers, batch_size, hidden_size] (RNN/GRU) 49 | h_c: [2, num_layers, batch_size, hidden_size] (LSTMCell) 50 | """ 51 | if inputs is not None: 52 | batch_size = inputs.size(0) 53 | return batch_size 54 | 55 | else: 56 | if self.use_lstm: 57 | batch_size = h[0].size(1) 58 | else: 59 | batch_size = h.size(1) 60 | return batch_size 61 | 62 | def decode(self, out): 63 | """ 64 | Args: 65 | out: unnormalized word distribution [batch_size, vocab_size] 66 | Return: 67 | x: word_index [batch_size] 68 | """ 69 | 70 | # Sample next word from multinomial word distribution 71 | if self.sample: 72 | # x: [batch_size] - word index (next input) 73 | x = torch.multinomial(self.softmax(out / self.temperature), 1).view(-1) 74 | 75 | # Greedy sampling 76 | else: 77 | # x: 
[batch_size] - word index (next input) 78 | _, x = out.max(dim=1) 79 | return x 80 | 81 | def forward(self): 82 | """Base forward function to inherit""" 83 | raise NotImplementedError 84 | 85 | def forward_step(self): 86 | """Run RNN single step""" 87 | raise NotImplementedError 88 | 89 | def embed(self, x): 90 | """word index: [batch_size] => word vectors: [batch_size, hidden_size]""" 91 | 92 | if self.training and self.word_drop > 0.0: 93 | if random.random() < self.word_drop: 94 | embed = self.embedding(to_var(x.data.new([UNK_ID] * x.size(0)))) 95 | else: 96 | embed = self.embedding(x) 97 | else: 98 | embed = self.embedding(x) 99 | 100 | return embed 101 | 102 | def beam_decode(self, 103 | init_h=None, 104 | encoder_outputs=None, input_valid_length=None, 105 | decode=False): 106 | """ 107 | Args: 108 | encoder_outputs (Variable, FloatTensor): [batch_size, source_length, hidden_size] 109 | input_valid_length (Variable, LongTensor): [batch_size] (optional) 110 | init_h (variable, FloatTensor): [batch_size, hidden_size] (optional) 111 | Return: 112 | out : [batch_size, seq_len] 113 | """ 114 | batch_size = self.batch_size(h=init_h) 115 | 116 | # [batch_size x beam_size] 117 | x = self.init_token(batch_size * self.beam_size, SOS_ID) 118 | 119 | # [num_layers, batch_size x beam_size, hidden_size] 120 | h = self.init_h(batch_size, hidden=init_h).repeat(1, self.beam_size, 1) 121 | 122 | # batch_position [batch_size] 123 | # [0, beam_size, beam_size * 2, .., beam_size * (batch_size-1)] 124 | # Points where batch starts in [batch_size x beam_size] tensors 125 | # Ex. position_idx[5]: when 5-th batch starts 126 | batch_position = to_var(torch.arange(0, batch_size).long() * self.beam_size) 127 | 128 | # Initialize scores of sequence 129 | # [batch_size x beam_size] 130 | # Ex. batch_size: 5, beam_size: 3 131 | # [0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf, 0, -inf, -inf] 132 | score = torch.ones(batch_size * self.beam_size) * -float('inf') 133 | score.index_fill_(0, torch.arange(0, batch_size).long() * self.beam_size, 0.0) 134 | score = to_var(score) 135 | 136 | # Initialize Beam that stores decisions for backtracking 137 | beam = Beam( 138 | batch_size, 139 | self.hidden_size, 140 | self.vocab_size, 141 | self.beam_size, 142 | self.max_unroll, 143 | batch_position) 144 | 145 | for i in range(self.max_unroll): 146 | 147 | # x: [batch_size x beam_size]; (token index) 148 | # => 149 | # out: [batch_size x beam_size, vocab_size] 150 | # h: [num_layers, batch_size x beam_size, hidden_size] 151 | out, h = self.forward_step(x, h, 152 | encoder_outputs=encoder_outputs, 153 | input_valid_length=input_valid_length) 154 | # log_prob: [batch_size x beam_size, vocab_size] 155 | log_prob = F.log_softmax(out, dim=1) 156 | 157 | # [batch_size x beam_size] 158 | # => [batch_size x beam_size, vocab_size] 159 | score = score.view(-1, 1) + log_prob 160 | 161 | # Select `beam size` transitions out of `vocab size` combinations 162 | 163 | # [batch_size x beam_size, vocab_size] 164 | # => [batch_size, beam_size x vocab_size] 165 | # Cutoff and retain candidates with top-k scores 166 | # score: [batch_size, beam_size] 167 | # top_k_idx: [batch_size, beam_size] 168 | # each element of top_k_idx [0 ~ beam x vocab) 169 | 170 | score, top_k_idx = score.view(batch_size, -1).topk(self.beam_size, dim=1) 171 | 172 | # Get token ids with remainder after dividing by top_k_idx 173 | # Each element is among [0, vocab_size) 174 | # Ex. 
Index of token 3 in beam 4:
175 |             #     flat index = (4 * vocab_size) + 3  =>  token id 3
176 |             # x: [batch_size x beam_size]
177 |             x = (top_k_idx % self.vocab_size).view(-1)
178 | 
179 |             # top-k-pointer [batch_size x beam_size]
180 |             # Points to the top-k beam that scored best at the current step
181 |             # Later used as a back-pointer when backtracking
182 |             # Each element is beam index: 0 ~ beam_size
183 |             # + position index: 0 ~ beam_size x (batch_size-1)
                  # Floor division (`//`) recovers the beam index; plain `/` is
                  # true division on newer PyTorch and would yield float indices.
184 |             beam_idx = top_k_idx // self.vocab_size  # [batch_size, beam_size]
185 |             top_k_pointer = (beam_idx + batch_position.unsqueeze(1)).view(-1)
186 | 
187 |             # Select next h (size doesn't change)
188 |             # [num_layers, batch_size * beam_size, hidden_size]
189 |             h = h.index_select(1, top_k_pointer)
190 | 
191 |             # Update sequence scores at beam
192 |             beam.update(score.clone(), top_k_pointer, x)  # , h)
193 | 
194 |             # Erase scores for EOS so that they are not expanded
195 |             # [batch_size, beam_size]
196 |             eos_idx = x.data.eq(EOS_ID).view(batch_size, self.beam_size)
197 |             if eos_idx.any():  # (the old `eos_idx.nonzero().dim() > 0` check was always true)
198 |                 score.data.masked_fill_(eos_idx, -float('inf'))
199 | 
200 |         # prediction ([batch, k, max_unroll])
201 |         #     A list of Tensors containing the predicted sequences
202 |         # final_score [batch, k]
203 |         #     A list containing the final scores for all top-k sequences
204 |         # length [batch, k]
205 |         #     A list specifying the length of each sequence in the top-k candidates
206 |         # prediction, final_score, length = beam.backtrack()
207 |         prediction, final_score, length = beam.backtrack()
208 | 
209 |         return prediction, final_score, length
210 | 
211 | 
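# Usage sketch (editorial, mirroring how models.HRED.generate consumes
# beam_decode): the decoder returns the full top-k set per batch element,
# and callers typically keep only the best-scoring hypothesis:
#
#     prediction, final_score, length = decoder.beam_decode(init_h=decoder_init)
#     best = prediction[:, 0, :]          # [batch_size, max_unroll]
#     best_len = [l[0] for l in length]   # lengths of the best hypotheses
#
212 | class DecoderRNN(BaseRNNDecoder):
213 |     def __init__(self, vocab_size, embedding_size,
214 |                  hidden_size, rnncell=StackedGRUCell, num_layers=1,
215 |                  dropout=0.0, word_drop=0.0,
216 |                  max_unroll=30, sample=True, temperature=1.0, beam_size=1):
217 |         super(DecoderRNN, self).__init__()
218 | 
219 |         self.vocab_size = vocab_size
220 |         self.embedding_size = embedding_size
221 |         self.hidden_size = hidden_size
222 |         self.num_layers = num_layers
223 |         self.dropout = dropout
224 |         self.temperature = temperature
225 |         self.word_drop = word_drop
226 |         self.max_unroll = max_unroll
227 |         self.sample = sample
228 |         self.beam_size = beam_size
229 | 
230 |         self.embedding = nn.Embedding(vocab_size, embedding_size)
231 | 
232 |         self.rnncell = rnncell(num_layers,
233 |                                embedding_size,
234 |                                hidden_size,
235 |                                dropout)
236 |         self.out = nn.Linear(hidden_size, vocab_size)
237 |         self.softmax = nn.Softmax(dim=1)
238 | 
239 |     def forward_step(self, x, h,
240 |                      encoder_outputs=None,
241 |                      input_valid_length=None):
242 |         """
243 |         Single RNN Step
244 |         1. Input Embedding (vocab_size => hidden_size)
245 |         2. RNN Step (hidden_size => hidden_size)
246 |         3.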
Output Projection (hidden_size => vocab size) 247 | 248 | Args: 249 | x: [batch_size] 250 | h: [num_layers, batch_size, hidden_size] (h and c from all layers) 251 | 252 | Return: 253 | out: [batch_size,vocab_size] (Unnormalized word distribution) 254 | h: [num_layers, batch_size, hidden_size] (h and c from all layers) 255 | """ 256 | # x: [batch_size] => [batch_size, hidden_size] 257 | x = self.embed(x) 258 | 259 | # last_h: [batch_size, hidden_size] (h from Top RNN layer) 260 | # h: [num_layers, batch_size, hidden_size] (h and c from all layers) 261 | last_h, h = self.rnncell(x, h) 262 | 263 | if self.use_lstm: 264 | # last_h_c: [2, batch_size, hidden_size] (h from Top RNN layer) 265 | # h_c: [2, num_layers, batch_size, hidden_size] (h and c from all layers) 266 | last_h = last_h[0] 267 | 268 | # Unormalized word distribution 269 | # out: [batch_size, vocab_size] 270 | out = self.out(last_h) 271 | return out, h 272 | 273 | def forward(self, inputs, init_h=None, encoder_outputs=None, input_valid_length=None, 274 | decode=False): 275 | """ 276 | Train (decode=False) 277 | Args: 278 | inputs (Variable, LongTensor): [batch_size, seq_len] 279 | init_h: (Variable, FloatTensor): [num_layers, batch_size, hidden_size] 280 | Return: 281 | out : [batch_size, seq_len, vocab_size] 282 | Test (decode=True) 283 | Args: 284 | inputs: None 285 | init_h: (Variable, FloatTensor): [num_layers, batch_size, hidden_size] 286 | Return: 287 | out : [batch_size, seq_len] 288 | """ 289 | batch_size = self.batch_size(inputs, init_h) 290 | 291 | # x: [batch_size] 292 | x = self.init_token(batch_size, SOS_ID) 293 | 294 | # h: [num_layers, batch_size, hidden_size] 295 | h = self.init_h(batch_size, hidden=init_h) 296 | 297 | 298 | if not decode: 299 | out_list = [] 300 | seq_len = inputs.size(1) 301 | for i in range(seq_len): 302 | 303 | # x: [batch_size] 304 | # => 305 | # out: [batch_size, vocab_size] 306 | # h: [num_layers, batch_size, hidden_size] (h and c from all layers) 307 | out, h = self.forward_step(x, h) 308 | 309 | out_list.append(out) 310 | x = inputs[:, i] 311 | 312 | # [batch_size, max_target_len, vocab_size] 313 | return torch.stack(out_list, dim=1) 314 | else: 315 | x_list = [] 316 | for i in range(self.max_unroll): 317 | 318 | # x: [batch_size] 319 | # => 320 | # out: [batch_size, vocab_size] 321 | # h: [num_layers, batch_size, hidden_size] (h and c from all layers) 322 | out, h = self.forward_step(x, h) 323 | 324 | # out: [batch_size, vocab_size] 325 | # => x: [batch_size] 326 | x = self.decode(out) 327 | x_list.append(x) 328 | 329 | # [batch_size, max_target_len] 330 | return torch.stack(x_list, dim=1) 331 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | from itertools import cycle 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import models 6 | from layers import masked_cross_entropy 7 | from utils import to_var, time_desc_decorator, pad_and_pack, normal_kl_div, to_bow, bag_of_words_loss, normal_kl_div, embedding_metric 8 | import os 9 | from tqdm import tqdm 10 | from math import isnan 11 | import re 12 | import math 13 | import pickle 14 | #import gensim 15 | 16 | word2vec_path = "../datasets/GoogleNews-vectors-negative300.bin" 17 | 18 | class Solver(object): 19 | def __init__(self, config, train_data_loader, eval_data_loader, vocab, is_train=True, model=None): 20 | self.config = config 21 | self.epoch_i = 0 22 | self.train_data_loader = 
train_data_loader
23 |         self.eval_data_loader = eval_data_loader
24 |         self.vocab = vocab
25 |         self.is_train = is_train
26 |         self.model = model
27 | 
28 |     @time_desc_decorator('Build Graph')
29 |     def build(self, cuda=True):
30 | 
31 |         if self.model is None:
32 |             self.model = getattr(models, self.config.model)(self.config)
33 | 
34 |             # orthogonal initialization for hidden-to-hidden weights
35 |             # update gate bias for GRUs
36 |             if self.config.mode == 'train' and self.config.checkpoint is None:
37 |                 print('Parameter initialization')
38 |                 for name, param in self.model.named_parameters():
39 |                     if 'weight_hh' in name:
40 |                         print('\t' + name)
41 |                         nn.init.orthogonal_(param)
42 | 
43 |                     # bias_hh is the concatenation of the reset, update, and new gate biases
44 |                     # only set the update gate bias to 2.0 (sigmoid(2.0) ~ 0.88,
                        # which biases the GRU toward keeping its state early in training)
45 |                     if 'bias_hh' in name:
46 |                         print('\t' + name)
47 |                         dim = param.size(0) // 3
48 |                         param.data[dim:2 * dim].fill_(2.0)
49 | 
50 |         if torch.cuda.is_available() and cuda:
51 |             self.model.cuda()
52 | 
53 |         # Overview Parameters
54 |         print('Model Parameters')
55 |         for name, param in self.model.named_parameters():
56 |             print('\t' + name + '\t', list(param.size()))
57 | 
58 |         if self.config.checkpoint:
59 |             self.load_model(self.config.checkpoint)
60 | 
61 |         if self.is_train:
62 |             #self.writer = TensorboardWriter(self.config.logdir)
63 |             self.optimizer = self.config.optimizer(
64 |                 filter(lambda p: p.requires_grad, self.model.parameters()),
65 |                 lr=self.config.learning_rate)
66 | 
67 |     def save_model(self, epoch):
68 |         """Save parameters to checkpoint"""
69 |         ckpt_path = os.path.join(self.config.save_path, f'{epoch}.pkl')
70 |         print(f'Save parameters to {ckpt_path}')
71 |         torch.save(self.model.state_dict(), ckpt_path)
72 | 
73 |     def load_model(self, checkpoint):
74 |         """Load parameters from checkpoint"""
75 |         print(f'Load parameters from {checkpoint}')
76 |         epoch = re.match(r"[0-9]*", os.path.basename(checkpoint)).group(0)
77 |         self.epoch_i = int(epoch)
78 |         self.model.load_state_dict(torch.load(checkpoint))
79 | 
80 |     def write_summary(self, epoch_i):
81 |         epoch_loss = getattr(self, 'epoch_loss', None)
82 |         if epoch_loss is not None:
83 |             self.writer.update_loss(
84 |                 loss=epoch_loss,
85 |                 step_i=epoch_i + 1,
86 |                 name='train_loss')
87 | 
88 |         epoch_recon_loss = getattr(self, 'epoch_recon_loss', None)
89 |         if epoch_recon_loss is not None:
90 |             self.writer.update_loss(
91 |                 loss=epoch_recon_loss,
92 |                 step_i=epoch_i + 1,
93 |                 name='train_recon_loss')
94 | 
95 |         epoch_kl_div = getattr(self, 'epoch_kl_div', None)
96 |         if epoch_kl_div is not None:
97 |             self.writer.update_loss(
98 |                 loss=epoch_kl_div,
99 |                 step_i=epoch_i + 1,
100 |                 name='train_kl_div')
101 | 
102 |         kl_mult = getattr(self, 'kl_mult', None)
103 |         if kl_mult is not None:
104 |             self.writer.update_loss(
105 |                 loss=kl_mult,
106 |                 step_i=epoch_i + 1,
107 |                 name='kl_mult')
108 | 
109 |         epoch_bow_loss = getattr(self, 'epoch_bow_loss', None)
110 |         if epoch_bow_loss is not None:
111 |             self.writer.update_loss(
112 |                 loss=epoch_bow_loss,
113 |                 step_i=epoch_i + 1,
114 |                 name='bow_loss')
115 | 
116 |         validation_loss = getattr(self, 'validation_loss', None)
117 |         if validation_loss is not None:
118 |             self.writer.update_loss(
119 |                 loss=validation_loss,
120 |                 step_i=epoch_i + 1,
121 |                 name='validation_loss')
122 | 
123 |     @time_desc_decorator('Training Start!')
124 |     def train(self):
125 |         epoch_loss_history = []
126 |         for epoch_i in range(self.epoch_i, self.config.n_epoch):
127 |             self.epoch_i = epoch_i
128 |             batch_loss_history = []
129 |             self.model.train()
130 |             n_total_words = 0
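            # Editorial note: each conversation of T utterances yields T - 1
            # next-utterance prediction pairs. E.g. a conversation [s1, s2, s3]
            # is split below into inputs [s1, s2] and targets [s2, s3], and its
            # effective conversation length becomes 3 - 1 = 2.
131 |             for batch_i,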
(conversations, conversation_length, sentence_length) in enumerate(tqdm(self.train_data_loader, ncols=80)): 132 | # conversations: (batch_size) list of conversations 133 | # conversation: list of sentences 134 | # sentence: list of tokens 135 | # conversation_length: list of int 136 | # sentence_length: (batch_size) list of conversation list of sentence_lengths 137 | 138 | input_conversations = [conv[:-1] for conv in conversations] 139 | target_conversations = [conv[1:] for conv in conversations] 140 | 141 | # flatten input and target conversations 142 | input_sentences = [sent for conv in input_conversations for sent in conv] 143 | target_sentences = [sent for conv in target_conversations for sent in conv] 144 | input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]] 145 | target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]] 146 | input_conversation_length = [l - 1 for l in conversation_length] 147 | 148 | input_sentences = to_var(torch.LongTensor(input_sentences)) 149 | target_sentences = to_var(torch.LongTensor(target_sentences)) 150 | input_sentence_length = to_var(torch.LongTensor(input_sentence_length)) 151 | target_sentence_length = to_var(torch.LongTensor(target_sentence_length)) 152 | input_conversation_length = to_var(torch.LongTensor(input_conversation_length)) 153 | 154 | # reset gradient 155 | self.optimizer.zero_grad() 156 | 157 | sentence_logits = self.model( 158 | input_sentences, 159 | input_sentence_length, 160 | input_conversation_length, 161 | target_sentences, 162 | decode=False) 163 | 164 | batch_loss, n_words = masked_cross_entropy( 165 | sentence_logits, 166 | target_sentences, 167 | target_sentence_length) 168 | 169 | assert not isnan(batch_loss.item()) 170 | batch_loss_history.append(batch_loss.item()) 171 | n_total_words += n_words.item() 172 | 173 | if batch_i % self.config.print_every == 0: 174 | tqdm.write( 175 | f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item()/ n_words.item():.3f}') 176 | 177 | # Back-propagation 178 | batch_loss.backward() 179 | 180 | # Gradient cliping 181 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip) 182 | 183 | # Run optimizer 184 | self.optimizer.step() 185 | 186 | epoch_loss = np.sum(batch_loss_history) / n_total_words 187 | epoch_loss_history.append(epoch_loss) 188 | self.epoch_loss = epoch_loss 189 | 190 | print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}' 191 | print(print_str) 192 | 193 | if epoch_i % self.config.save_every_epoch == 0: 194 | self.save_model(epoch_i + 1) 195 | 196 | print('\n...') 197 | self.validation_loss = self.evaluate() 198 | 199 | #if epoch_i % self.config.plot_every_epoch == 0: 200 | # self.write_summary(epoch_i) 201 | 202 | self.save_model(self.config.n_epoch) 203 | 204 | return epoch_loss_history 205 | 206 | def generate_sentence(self, input_sentences, input_sentence_length, 207 | input_conversation_length, target_sentences): 208 | self.model.eval() 209 | 210 | # [batch_size, max_seq_len, vocab_size] 211 | generated_sentences = self.model( 212 | input_sentences, 213 | input_sentence_length, 214 | input_conversation_length, 215 | target_sentences, 216 | decode=True) 217 | 218 | # write output to file 219 | with open(os.path.join(self.config.save_path, 'samples.txt'), 'a') as f: 220 | f.write(f'\n\n') 221 | 222 | tqdm.write('\n') 223 | for input_sent, target_sent, output_sent in zip(input_sentences, target_sentences, generated_sentences): 224 | input_sent = self.vocab.decode(input_sent) 225 | 
target_sent = self.vocab.decode(target_sent) 226 | output_sent = '\n'.join([self.vocab.decode(sent) for sent in output_sent]) 227 | s = '\n'.join(['Input sentence: ' + input_sent, 228 | 'Ground truth: ' + target_sent, 229 | 'Generated response: ' + output_sent + '\n']) 230 | f.write(s + '\n') 231 | print(s) 232 | print('') 233 | 234 | def evaluate(self): 235 | self.model.eval() 236 | batch_loss_history = [] 237 | n_total_words = 0 238 | for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.eval_data_loader, ncols=80)): 239 | # conversations: (batch_size) list of conversations 240 | # conversation: list of sentences 241 | # sentence: list of tokens 242 | # conversation_length: list of int 243 | # sentence_length: (batch_size) list of conversation list of sentence_lengths 244 | 245 | input_conversations = [conv[:-1] for conv in conversations] 246 | target_conversations = [conv[1:] for conv in conversations] 247 | 248 | # flatten input and target conversations 249 | input_sentences = [sent for conv in input_conversations for sent in conv] 250 | target_sentences = [sent for conv in target_conversations for sent in conv] 251 | input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]] 252 | target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]] 253 | input_conversation_length = [l - 1 for l in conversation_length] 254 | 255 | with torch.no_grad(): 256 | input_sentences = to_var(torch.LongTensor(input_sentences)) 257 | target_sentences = to_var(torch.LongTensor(target_sentences)) 258 | input_sentence_length = to_var(torch.LongTensor(input_sentence_length)) 259 | target_sentence_length = to_var(torch.LongTensor(target_sentence_length)) 260 | input_conversation_length = to_var( 261 | torch.LongTensor(input_conversation_length)) 262 | 263 | if batch_i == 0: 264 | self.generate_sentence(input_sentences, 265 | input_sentence_length, 266 | input_conversation_length, 267 | target_sentences) 268 | 269 | sentence_logits = self.model( 270 | input_sentences, 271 | input_sentence_length, 272 | input_conversation_length, 273 | target_sentences) 274 | 275 | batch_loss, n_words = masked_cross_entropy( 276 | sentence_logits, 277 | target_sentences, 278 | target_sentence_length) 279 | 280 | assert not isnan(batch_loss.item()) 281 | batch_loss_history.append(batch_loss.item()) 282 | n_total_words += n_words.item() 283 | 284 | epoch_loss = np.sum(batch_loss_history) / n_total_words 285 | 286 | print_str = f'Validation loss: {epoch_loss:.3f}\n' 287 | print(print_str) 288 | 289 | return epoch_loss 290 | 291 | def test(self): 292 | self.model.eval() 293 | batch_loss_history = [] 294 | n_total_words = 0 295 | for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.eval_data_loader, ncols=80)): 296 | # conversations: (batch_size) list of conversations 297 | # conversation: list of sentences 298 | # sentence: list of tokens 299 | # conversation_length: list of int 300 | # sentence_length: (batch_size) list of conversation list of sentence_lengths 301 | 302 | input_conversations = [conv[:-1] for conv in conversations] 303 | target_conversations = [conv[1:] for conv in conversations] 304 | 305 | # flatten input and target conversations 306 | input_sentences = [sent for conv in input_conversations for sent in conv] 307 | target_sentences = [sent for conv in target_conversations for sent in conv] 308 | input_sentence_length = [l for len_list in sentence_length for l in len_list[:-1]] 309 | 
target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]]
310 |             input_conversation_length = [l - 1 for l in conversation_length]
311 | 
312 |             with torch.no_grad():
313 |                 input_sentences = to_var(torch.LongTensor(input_sentences))
314 |                 target_sentences = to_var(torch.LongTensor(target_sentences))
315 |                 input_sentence_length = to_var(torch.LongTensor(input_sentence_length))
316 |                 target_sentence_length = to_var(torch.LongTensor(target_sentence_length))
317 |                 input_conversation_length = to_var(torch.LongTensor(input_conversation_length))
318 | 
319 |             sentence_logits = self.model(
320 |                 input_sentences,
321 |                 input_sentence_length,
322 |                 input_conversation_length,
323 |                 target_sentences)
324 | 
325 |             batch_loss, n_words = masked_cross_entropy(
326 |                 sentence_logits,
327 |                 target_sentences,
328 |                 target_sentence_length)
329 | 
330 |             assert not isnan(batch_loss.item())
331 |             batch_loss_history.append(batch_loss.item())
332 |             n_total_words += n_words.item()
333 | 
334 |         epoch_loss = np.sum(batch_loss_history) / n_total_words
335 | 
336 |         print(f'Number of words: {n_total_words}')
337 |         print(f'Cross entropy per word (nats): {epoch_loss:.3f}')
338 |         word_perplexity = np.exp(epoch_loss)
339 | 
340 |         print_str = f'Word perplexity : {word_perplexity:.3f}\n'
341 |         print(print_str)
342 | 
343 |         return word_perplexity
344 | 
345 |     def embedding_metric(self):
            import gensim  # imported lazily: the module-level import is commented out
346 |         word2vec = getattr(self, 'word2vec', None)
347 |         if word2vec is None:
348 |             print('Loading word2vec model')
349 |             word2vec = gensim.models.KeyedVectors.load_word2vec_format(word2vec_path, binary=True)
350 |             self.word2vec = word2vec
351 |         keys = word2vec.vocab  # gensim < 4.0 API
352 |         self.model.eval()
353 |         n_context = self.config.n_context
354 |         n_sample_step = self.config.n_sample_step
355 |         metric_average_history = []
356 |         metric_extrema_history = []
357 |         metric_greedy_history = []
358 |         context_history = []
359 |         sample_history = []
360 |         n_sent = 0
361 |         n_conv = 0
362 |         for batch_i, (conversations, conversation_length, sentence_length) \
363 |                 in enumerate(tqdm(self.eval_data_loader, ncols=80)):
364 |             # conversations: (batch_size) list of conversations
365 |             # conversation: list of sentences
366 |             # sentence: list of tokens
367 |             # conversation_length: list of int
368 |             # sentence_length: (batch_size) list of conversation list of sentence_lengths
369 | 
370 |             conv_indices = [i for i in range(len(conversations)) if len(conversations[i]) >= n_context + n_sample_step]
371 |             context = [c for i in conv_indices for c in [conversations[i][:n_context]]]
372 |             ground_truth = [c for i in conv_indices for c in [conversations[i][n_context:n_context + n_sample_step]]]
373 |             sentence_length = [c for i in conv_indices for c in [sentence_length[i][:n_context]]]
374 | 
375 |             with torch.no_grad():
376 |                 context = to_var(torch.LongTensor(context))
377 |                 sentence_length = to_var(torch.LongTensor(sentence_length))
378 | 
379 |             samples = self.model.generate(context, sentence_length, n_context)
380 | 
381 |             context = context.data.cpu().numpy().tolist()
382 |             samples = samples.data.cpu().numpy().tolist()
383 |             context_history.append(context)
384 |             sample_history.append(samples)
385 | 
386 |             samples = [[self.vocab.decode(sent) for sent in c] for c in samples]
387 |             ground_truth = [[self.vocab.decode(sent) for sent in c] for c in ground_truth]
388 | 
389 |             samples = [sent for c in samples for sent in c]
390 |             ground_truth = [sent for c in ground_truth for sent in c]
391 | 
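            # Editorial sketch: the three metrics below are the standard
            # embedding-based response-similarity scores (implemented in
            # utils/embedding_metric.py, which is not shown in this dump):
            #   - average: cosine similarity between the mean word vectors of a
            #     generated sentence and its reference
            #   - extrema: cosine similarity between vectors built by taking the
            #     most extreme value per dimension across each sentence's words
            #   - greedy: each generated word is greedily matched to its most
            #     similar reference word, and the match scores are averaged
392 |             samples = [[word2vec[s] for s in sent.split() if s in keys] for sent in samples]
393 |             ground_truth =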
[[word2vec[s] for s in sent.split() if s in keys] for sent in ground_truth] 394 | 395 | indices = [i for i, s, g in zip(range(len(samples)), samples, ground_truth) if s != [] and g != []] 396 | samples = [samples[i] for i in indices] 397 | ground_truth = [ground_truth[i] for i in indices] 398 | n = len(samples) 399 | n_sent += n 400 | 401 | metric_average = embedding_metric(samples, ground_truth, word2vec, 'average') 402 | metric_extrema = embedding_metric(samples, ground_truth, word2vec, 'extrema') 403 | metric_greedy = embedding_metric(samples, ground_truth, word2vec, 'greedy') 404 | metric_average_history.append(metric_average) 405 | metric_extrema_history.append(metric_extrema) 406 | metric_greedy_history.append(metric_greedy) 407 | 408 | epoch_average = np.mean(np.concatenate(metric_average_history), axis=0) 409 | epoch_extrema = np.mean(np.concatenate(metric_extrema_history), axis=0) 410 | epoch_greedy = np.mean(np.concatenate(metric_greedy_history), axis=0) 411 | 412 | print('n_sentences:', n_sent) 413 | print_str = f'Metrics - Average: {epoch_average:.3f}, Extrema: {epoch_extrema:.3f}, Greedy: {epoch_greedy:.3f}' 414 | print(print_str) 415 | print('\n') 416 | 417 | return epoch_average, epoch_extrema, epoch_greedy 418 | 419 | def generate_for_evaluation(self): 420 | self.model.eval() 421 | n_sample_step = self.config.n_sample_step 422 | n_sent = 0 423 | fo = open(self.config.pred_path, "w") 424 | for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.eval_data_loader, ncols=80)): 425 | # conversations: (batch_size) list of conversations 426 | # conversation: list of sentences 427 | # sentence: list of tokens 428 | # conversation_length: list of int 429 | # sentence_length: (batch_size) list of conversation list of sentence_lengths 430 | assert len(conversations) == 1 431 | conversation = conversations[0] 432 | context = conversation[:-1] 433 | context_str = ' '.join([self.vocab.decode(sent) for sent in context]) 434 | ground_truth = conversation[-1] 435 | n_context = len(context) 436 | sentence_length = sentence_length[0][:-1] 437 | 438 | with torch.no_grad(): 439 | context = to_var(torch.LongTensor(context)) 440 | context = context.unsqueeze(0) 441 | sentence_length = to_var(torch.LongTensor(sentence_length)) 442 | sentence_length = sentence_length.unsqueeze(0) 443 | 444 | samples = self.model.generate(context, sentence_length, n_context) 445 | 446 | samples = samples.data.cpu().numpy().tolist() 447 | sent = samples[0][0] 448 | sample = self.vocab.decode(sent) 449 | ground_truth = self.vocab.decode(ground_truth) 450 | 451 | n_sent += 1 452 | 453 | #print("ground_truth: ", ground_truth) 454 | #print("gen: ", samples) 455 | 456 | fo.write(context_str + "\t" + ground_truth + "\t" + sample + "\n") 457 | #print(ground_truth + "\t" + sample) 458 | print('n_sentences:', n_sent) 459 | print('\n') 460 | fo.close() 461 | 462 | 463 | 464 | 465 | class VariationalSolver(Solver): 466 | 467 | def __init__(self, config, train_data_loader, eval_data_loader, vocab, is_train=True, model=None): 468 | self.config = config 469 | self.epoch_i = 0 470 | self.train_data_loader = train_data_loader 471 | self.eval_data_loader = eval_data_loader 472 | self.vocab = vocab 473 | self.is_train = is_train 474 | self.model = model 475 | 476 | @time_desc_decorator('Training Start!') 477 | def train(self): 478 | epoch_loss_history = [] 479 | kl_mult = 0.0 480 | conv_kl_mult = 0.0 481 | for epoch_i in range(self.epoch_i, self.config.n_epoch): 482 | self.epoch_i = epoch_i 483 | 
batch_loss_history = [] 484 | recon_loss_history = [] 485 | kl_div_history = [] 486 | kl_div_sent_history = [] 487 | kl_div_conv_history = [] 488 | bow_loss_history = [] 489 | self.model.train() 490 | n_total_words = 0 491 | 492 | # self.evaluate() 493 | 494 | for batch_i, (conversations, conversation_length, sentence_length) \ 495 | in enumerate(tqdm(self.train_data_loader, ncols=80)): 496 | # conversations: (batch_size) list of conversations 497 | # conversation: list of sentences 498 | # sentence: list of tokens 499 | # conversation_length: list of int 500 | # sentence_length: (batch_size) list of conversation list of sentence_lengths 501 | 502 | target_conversations = [conv[1:] for conv in conversations] 503 | 504 | # flatten input and target conversations 505 | sentences = [sent for conv in conversations for sent in conv] 506 | input_conversation_length = [l - 1 for l in conversation_length] 507 | target_sentences = [sent for conv in target_conversations for sent in conv] 508 | target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]] 509 | sentence_length = [l for len_list in sentence_length for l in len_list] 510 | 511 | sentences = to_var(torch.LongTensor(sentences)) 512 | sentence_length = to_var(torch.LongTensor(sentence_length)) 513 | input_conversation_length = to_var(torch.LongTensor(input_conversation_length)) 514 | target_sentences = to_var(torch.LongTensor(target_sentences)) 515 | target_sentence_length = to_var(torch.LongTensor(target_sentence_length)) 516 | 517 | # reset gradient 518 | self.optimizer.zero_grad() 519 | 520 | sentence_logits, kl_div, _, _ = self.model( 521 | sentences, 522 | sentence_length, 523 | input_conversation_length, 524 | target_sentences) 525 | 526 | recon_loss, n_words = masked_cross_entropy( 527 | sentence_logits, 528 | target_sentences, 529 | target_sentence_length) 530 | 531 | batch_loss = recon_loss + kl_mult * kl_div 532 | batch_loss_history.append(batch_loss.item()) 533 | recon_loss_history.append(recon_loss.item()) 534 | kl_div_history.append(kl_div.item()) 535 | n_total_words += n_words.item() 536 | 537 | if self.config.bow: 538 | bow_loss = self.model.compute_bow_loss(target_conversations) 539 | batch_loss += bow_loss 540 | bow_loss_history.append(bow_loss.item()) 541 | 542 | assert not isnan(batch_loss.item()) 543 | 544 | if batch_i % self.config.print_every == 0: 545 | print_str = f'Epoch: {epoch_i+1}, iter {batch_i}: loss = {batch_loss.item() / n_words.item():.3f}, recon = {recon_loss.item() / n_words.item():.3f}, kl_div = {kl_div.item() / n_words.item():.3f}' 546 | if self.config.bow: 547 | print_str += f', bow_loss = {bow_loss.item() / n_words.item():.3f}' 548 | tqdm.write(print_str) 549 | 550 | # Back-propagation 551 | batch_loss.backward() 552 | 553 | # Gradient cliping 554 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.clip) 555 | 556 | # Run optimizer 557 | self.optimizer.step() 558 | kl_mult = min(kl_mult + 1.0 / self.config.kl_annealing_iter, 1.0) 559 | 560 | epoch_loss = np.sum(batch_loss_history) / n_total_words 561 | epoch_loss_history.append(epoch_loss) 562 | 563 | epoch_recon_loss = np.sum(recon_loss_history) / n_total_words 564 | epoch_kl_div = np.sum(kl_div_history) / n_total_words 565 | 566 | self.kl_mult = kl_mult 567 | self.epoch_loss = epoch_loss 568 | self.epoch_recon_loss = epoch_recon_loss 569 | self.epoch_kl_div = epoch_kl_div 570 | 571 | print_str = f'Epoch {epoch_i+1} loss average: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}' 
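            # Editorial note: the three numbers above are per-word averages of
            # the two halves of the (negative) ELBO,
            #
            #     -ELBO = recon_loss + kl_div,
            #
            # while training itself optimizes recon_loss + kl_mult * kl_div,
            # with kl_mult annealed from 0 to 1 over config.kl_annealing_iter
            # steps so the decoder is not pushed to ignore z early on;
            # exp(epoch_recon_loss) gives a rough per-word perplexity.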
572 | if bow_loss_history: 573 | self.epoch_bow_loss = np.sum(bow_loss_history) / n_total_words 574 | print_str += f', bow_loss = {self.epoch_bow_loss:.3f}' 575 | print(print_str) 576 | 577 | if epoch_i % self.config.save_every_epoch == 0: 578 | self.save_model(epoch_i + 1) 579 | 580 | print('\n...') 581 | self.validation_loss = self.evaluate() 582 | 583 | #if epoch_i % self.config.plot_every_epoch == 0: 584 | # self.write_summary(epoch_i) 585 | 586 | return epoch_loss_history 587 | 588 | def generate_sentence(self, sentences, sentence_length, 589 | input_conversation_length, input_sentences, target_sentences): 590 | """Generate output of decoder (single batch)""" 591 | self.model.eval() 592 | 593 | # [batch_size, max_seq_len, vocab_size] 594 | generated_sentences, _, _, _ = self.model( 595 | sentences, 596 | sentence_length, 597 | input_conversation_length, 598 | target_sentences, 599 | decode=True) 600 | 601 | # write output to file 602 | with open(os.path.join(self.config.save_path, 'samples.txt'), 'a') as f: 603 | f.write(f'\n\n') 604 | 605 | tqdm.write('\n') 606 | for input_sent, target_sent, output_sent in zip(input_sentences, target_sentences, generated_sentences): 607 | input_sent = self.vocab.decode(input_sent) 608 | target_sent = self.vocab.decode(target_sent) 609 | output_sent = '\n'.join([self.vocab.decode(sent) for sent in output_sent]) 610 | s = '\n'.join(['Input sentence: ' + input_sent, 611 | 'Ground truth: ' + target_sent, 612 | 'Generated response: ' + output_sent + '\n']) 613 | f.write(s + '\n') 614 | print(s) 615 | print('') 616 | 617 | def evaluate(self): 618 | self.model.eval() 619 | batch_loss_history = [] 620 | recon_loss_history = [] 621 | kl_div_history = [] 622 | bow_loss_history = [] 623 | n_total_words = 0 624 | for batch_i, (conversations, conversation_length, sentence_length) \ 625 | in enumerate(tqdm(self.eval_data_loader, ncols=80)): 626 | # conversations: (batch_size) list of conversations 627 | # conversation: list of sentences 628 | # sentence: list of tokens 629 | # conversation_length: list of int 630 | # sentence_length: (batch_size) list of conversation list of sentence_lengths 631 | 632 | target_conversations = [conv[1:] for conv in conversations] 633 | 634 | # flatten input and target conversations 635 | sentences = [sent for conv in conversations for sent in conv] 636 | input_conversation_length = [l - 1 for l in conversation_length] 637 | target_sentences = [sent for conv in target_conversations for sent in conv] 638 | target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]] 639 | sentence_length = [l for len_list in sentence_length for l in len_list] 640 | 641 | with torch.no_grad(): 642 | sentences = to_var(torch.LongTensor(sentences)) 643 | sentence_length = to_var(torch.LongTensor(sentence_length)) 644 | input_conversation_length = to_var( 645 | torch.LongTensor(input_conversation_length)) 646 | target_sentences = to_var(torch.LongTensor(target_sentences)) 647 | target_sentence_length = to_var(torch.LongTensor(target_sentence_length)) 648 | 649 | if batch_i == 0: 650 | input_conversations = [conv[:-1] for conv in conversations] 651 | input_sentences = [sent for conv in input_conversations for sent in conv] 652 | with torch.no_grad(): 653 | input_sentences = to_var(torch.LongTensor(input_sentences)) 654 | self.generate_sentence(sentences, 655 | sentence_length, 656 | input_conversation_length, 657 | input_sentences, 658 | target_sentences) 659 | 660 | sentence_logits, kl_div, _, _ = self.model( 661 | sentences, 662 | 
sentence_length, 663 | input_conversation_length, 664 | target_sentences) 665 | 666 | recon_loss, n_words = masked_cross_entropy( 667 | sentence_logits, 668 | target_sentences, 669 | target_sentence_length) 670 | 671 | batch_loss = recon_loss + kl_div 672 | if self.config.bow: 673 | bow_loss = self.model.compute_bow_loss(target_conversations) 674 | bow_loss_history.append(bow_loss.item()) 675 | 676 | assert not isnan(batch_loss.item()) 677 | batch_loss_history.append(batch_loss.item()) 678 | recon_loss_history.append(recon_loss.item()) 679 | kl_div_history.append(kl_div.item()) 680 | n_total_words += n_words.item() 681 | 682 | epoch_loss = np.sum(batch_loss_history) / n_total_words 683 | epoch_recon_loss = np.sum(recon_loss_history) / n_total_words 684 | epoch_kl_div = np.sum(kl_div_history) / n_total_words 685 | 686 | print_str = f'Validation loss: {epoch_loss:.3f}, recon_loss: {epoch_recon_loss:.3f}, kl_div: {epoch_kl_div:.3f}' 687 | if bow_loss_history: 688 | epoch_bow_loss = np.sum(bow_loss_history) / n_total_words 689 | print_str += f', bow_loss = {epoch_bow_loss:.3f}' 690 | print(print_str) 691 | print('\n') 692 | 693 | return epoch_loss 694 | 695 | def generate_for_evaluation(self): 696 | self.model.eval() 697 | n_sample_step = self.config.n_sample_step 698 | n_sent = 0 699 | fo = open(self.config.pred_path, "w") 700 | for batch_i, (conversations, conversation_length, sentence_length) in enumerate(tqdm(self.eval_data_loader, ncols=80)): 701 | # conversations: (batch_size) list of conversations 702 | # conversation: list of sentences 703 | # sentence: list of tokens 704 | # conversation_length: list of int 705 | # sentence_length: (batch_size) list of conversation list of sentence_lengths 706 | assert len(conversations) == 1 707 | conversation = conversations[0] 708 | context = conversation[:-1] 709 | context_str = ' '.join([self.vocab.decode(sent) for sent in context]) 710 | ground_truth = conversation[-1] 711 | n_context = len(context) 712 | sentence_length = sentence_length[0][:-1] 713 | 714 | with torch.no_grad(): 715 | context = to_var(torch.LongTensor(context)) 716 | context = context.unsqueeze(0) 717 | sentence_length = to_var(torch.LongTensor(sentence_length)) 718 | sentence_length = sentence_length.unsqueeze(0) 719 | 720 | samples = self.model.generate(context, sentence_length, n_context) 721 | 722 | samples = samples.data.cpu().numpy().tolist() 723 | sent = samples[0][0] 724 | sample = self.vocab.decode(sent) 725 | ground_truth = self.vocab.decode(ground_truth) 726 | 727 | n_sent += 1 728 | 729 | #print("ground_truth: ", ground_truth) 730 | #print("gen: ", samples) 731 | 732 | fo.write(context_str + "\t" + ground_truth + "\t" + sample + "\n") 733 | #print(ground_truth + "\t" + sample) 734 | print('n_sentences:', n_sent) 735 | print('\n') 736 | fo.close() 737 | 738 | 739 | def importance_sample(self): 740 | ''' Perform importance sampling to get tighter bound 741 | ''' 742 | self.model.eval() 743 | weight_history = [] 744 | n_total_words = 0 745 | kl_div_history = [] 746 | for batch_i, (conversations, conversation_length, sentence_length) \ 747 | in enumerate(tqdm(self.eval_data_loader, ncols=80)): 748 | # conversations: (batch_size) list of conversations 749 | # conversation: list of sentences 750 | # sentence: list of tokens 751 | # conversation_length: list of int 752 | # sentence_length: (batch_size) list of conversation list of sentence_lengths 753 | 754 | target_conversations = [conv[1:] for conv in conversations] 755 | 756 | # flatten input and target conversations 
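            # Editorial note: each of the config.importance_sample passes below
            # draws a fresh z and computes an unnormalized importance weight
            #
            #     log w_j = log p(x|z_j) + log p(z_j) - log q(z_j|x),
            #
            # and the J weights are combined with a numerically stable
            # log-sum-exp (the m = floor(max) shift further down):
            #
            #     log p(x) ~ m + log(sum_j exp(log w_j - m)) - log J,
            #
            # which yields a tighter, IWAE-style bound than the one-sample ELBO.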
757 | sentences = [sent for conv in conversations for sent in conv] 758 | input_conversation_length = [l - 1 for l in conversation_length] 759 | target_sentences = [sent for conv in target_conversations for sent in conv] 760 | target_sentence_length = [l for len_list in sentence_length for l in len_list[1:]] 761 | sentence_length = [l for len_list in sentence_length for l in len_list] 762 | 763 | # n_words += sum([len([word for word in sent if word != PAD_ID]) for sent in target_sentences]) 764 | with torch.no_grad(): 765 | sentences = to_var(torch.LongTensor(sentences)) 766 | sentence_length = to_var(torch.LongTensor(sentence_length)) 767 | input_conversation_length = to_var( 768 | torch.LongTensor(input_conversation_length)) 769 | target_sentences = to_var(torch.LongTensor(target_sentences)) 770 | target_sentence_length = to_var(torch.LongTensor(target_sentence_length)) 771 | 772 | # treat whole batch as one data sample 773 | weights = [] 774 | for j in range(self.config.importance_sample): 775 | sentence_logits, kl_div, log_p_z, log_q_zx = self.model( 776 | sentences, 777 | sentence_length, 778 | input_conversation_length, 779 | target_sentences) 780 | 781 | recon_loss, n_words = masked_cross_entropy( 782 | sentence_logits, 783 | target_sentences, 784 | target_sentence_length) 785 | 786 | log_w = (-recon_loss.sum() + log_p_z - log_q_zx).data 787 | weights.append(log_w) 788 | if j == 0: 789 | n_total_words += n_words.item() 790 | kl_div_history.append(kl_div.item()) 791 | 792 | # weights: [n_samples] 793 | weights = torch.stack(weights, 0) 794 | m = np.floor(weights.max()) 795 | weights = np.log(torch.exp(weights - m).sum()) 796 | weights = m + weights - np.log(self.config.importance_sample) 797 | weight_history.append(weights) 798 | 799 | print(f'Number of words: {n_total_words}') 800 | bits_per_word = -np.sum(weight_history) / n_total_words 801 | print(f'Bits per word: {bits_per_word:.3f}') 802 | word_perplexity = np.exp(bits_per_word) 803 | 804 | epoch_kl_div = np.sum(kl_div_history) / n_total_words 805 | 806 | print_str = f'Word perplexity upperbound using {self.config.importance_sample} importance samples: {word_perplexity:.3f}, kl_div: {epoch_kl_div:.3f}\n' 807 | print(print_str) 808 | 809 | return word_perplexity 810 | 811 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from utils import to_var, pad, normal_kl_div, normal_logpdf, bag_of_words_loss, to_bow, EOS_ID 4 | import layers 5 | import numpy as np 6 | import random 7 | 8 | VariationalModels = ['VHRED', 'VHCR'] 9 | 10 | class HRED(nn.Module): 11 | def __init__(self, config): 12 | super(HRED, self).__init__() 13 | 14 | self.config = config 15 | self.encoder = layers.EncoderRNN(config.vocab_size, 16 | config.embedding_size, 17 | config.encoder_hidden_size, 18 | config.rnn, 19 | config.num_layers, 20 | config.bidirectional, 21 | config.dropout) 22 | 23 | context_input_size = (config.num_layers 24 | * config.encoder_hidden_size 25 | * self.encoder.num_directions) 26 | self.context_encoder = layers.ContextRNN(context_input_size, 27 | config.context_size, 28 | config.rnn, 29 | config.num_layers, 30 | config.dropout) 31 | 32 | self.decoder = layers.DecoderRNN(config.vocab_size, 33 | config.embedding_size, 34 | config.decoder_hidden_size, 35 | config.rnncell, 36 | config.num_layers, 37 | config.dropout, 38 | config.word_drop, 39 | config.max_unroll, 40 | 
config.sample, 41 | config.temperature, 42 | config.beam_size) 43 | 44 | self.context2decoder = layers.FeedForward(config.context_size, 45 | config.num_layers * config.decoder_hidden_size, 46 | num_layers=1, 47 | activation=config.activation) 48 | 49 | if config.tie_embedding: 50 | self.decoder.embedding = self.encoder.embedding 51 | 52 | def forward(self, input_sentences, input_sentence_length, 53 | input_conversation_length, target_sentences, decode=False): 54 | """ 55 | Args: 56 | input_sentences: (Variable, LongTensor) [num_sentences, seq_len] 57 | target_sentences: (Variable, LongTensor) [num_sentences, seq_len] 58 | Return: 59 | decoder_outputs: (Variable, FloatTensor) 60 | - train: [batch_size, seq_len, vocab_size] 61 | - eval: [batch_size, seq_len] 62 | """ 63 | num_sentences = input_sentences.size(0) 64 | max_len = input_conversation_length.data.max().item() 65 | 66 | # encoder_outputs: [num_sentences, max_source_length, hidden_size * direction] 67 | # encoder_hidden: [num_layers * direction, num_sentences, hidden_size] 68 | encoder_outputs, encoder_hidden = self.encoder(input_sentences, 69 | input_sentence_length) 70 | 71 | # encoder_hidden: [num_sentences, num_layers * direction * hidden_size] 72 | encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(num_sentences, -1) 73 | 74 | # pad and pack encoder_hidden 75 | start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()), 76 | input_conversation_length[:-1])), 0) 77 | 78 | # encoder_hidden: [batch_size, max_len, num_layers * direction * hidden_size] 79 | encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l), max_len) 80 | for s, l in zip(start.data.tolist(), 81 | input_conversation_length.data.tolist())], 0) 82 | 83 | # context_outputs: [batch_size, max_len, context_size] 84 | context_outputs, context_last_hidden = self.context_encoder(encoder_hidden, 85 | input_conversation_length) 86 | 87 | # flatten outputs 88 | # context_outputs: [num_sentences, context_size] 89 | context_outputs = torch.cat([context_outputs[i, :l, :] 90 | for i, l in enumerate(input_conversation_length.data)]) 91 | 92 | # project context_outputs to decoder init state 93 | decoder_init = self.context2decoder(context_outputs) 94 | 95 | # [num_layers, batch_size, hidden_size] 96 | decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size) 97 | 98 | # train: [batch_size, seq_len, vocab_size] 99 | # eval: [batch_size, seq_len] 100 | if not decode: 101 | 102 | decoder_outputs = self.decoder(target_sentences, 103 | init_h=decoder_init, 104 | decode=decode) 105 | return decoder_outputs 106 | 107 | else: 108 | # decoder_outputs = self.decoder(target_sentences, 109 | # init_h=decoder_init, 110 | # decode=decode) 111 | # return decoder_outputs.unsqueeze(1) 112 | # prediction: [batch_size, beam_size, max_unroll] 113 | prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init) 114 | 115 | # Get top prediction only 116 | # [batch_size, max_unroll] 117 | # prediction = prediction[:, 0] 118 | 119 | # [batch_size, beam_size, max_unroll] 120 | return prediction 121 | 122 | def generate(self, context, sentence_length, n_context): 123 | # context: [batch_size, n_context, seq_len] 124 | batch_size = context.size(0) 125 | # n_context = context.size(1) 126 | samples = [] 127 | 128 | # Run for context 129 | context_hidden=None 130 | for i in range(n_context): 131 | # encoder_outputs: [batch_size, seq_len, hidden_size * direction] 132 | # encoder_hidden: [num_layers * direction, 
batch_size, hidden_size] 133 | encoder_outputs, encoder_hidden = self.encoder(context[:, i, :], 134 | sentence_length[:, i]) 135 | 136 | encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1) 137 | # context_outputs: [batch_size, 1, context_hidden_size * direction] 138 | # context_hidden: [num_layers * direction, batch_size, context_hidden_size] 139 | context_outputs, context_hidden = self.context_encoder.step(encoder_hidden, 140 | context_hidden) 141 | 142 | # Run for generation 143 | for j in range(self.config.n_sample_step): 144 | # context_outputs: [batch_size, context_hidden_size * direction] 145 | context_outputs = context_outputs.squeeze(1) 146 | decoder_init = self.context2decoder(context_outputs) 147 | decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size) 148 | 149 | prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init) 150 | # prediction: [batch_size, seq_len] 151 | prediction = prediction[:, 0, :] 152 | # length: [batch_size] 153 | length = [l[0] for l in length] 154 | length = to_var(torch.LongTensor(length)) 155 | samples.append(prediction) 156 | 157 | encoder_outputs, encoder_hidden = self.encoder(prediction, 158 | length) 159 | 160 | encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1) 161 | 162 | context_outputs, context_hidden = self.context_encoder.step(encoder_hidden, 163 | context_hidden) 164 | 165 | samples = torch.stack(samples, 1) 166 | return samples 167 | 168 | 169 | class VHRED(nn.Module): 170 | def __init__(self, config): 171 | super(VHRED, self).__init__() 172 | 173 | self.config = config 174 | self.encoder = layers.EncoderRNN(config.vocab_size, 175 | config.embedding_size, 176 | config.encoder_hidden_size, 177 | config.rnn, 178 | config.num_layers, 179 | config.bidirectional, 180 | config.dropout) 181 | 182 | context_input_size = (config.num_layers 183 | * config.encoder_hidden_size 184 | * self.encoder.num_directions) 185 | self.context_encoder = layers.ContextRNN(context_input_size, 186 | config.context_size, 187 | config.rnn, 188 | config.num_layers, 189 | config.dropout) 190 | 191 | self.decoder = layers.DecoderRNN(config.vocab_size, 192 | config.embedding_size, 193 | config.decoder_hidden_size, 194 | config.rnncell, 195 | config.num_layers, 196 | config.dropout, 197 | config.word_drop, 198 | config.max_unroll, 199 | config.sample, 200 | config.temperature, 201 | config.beam_size) 202 | 203 | self.context2decoder = layers.FeedForward(config.context_size + config.z_sent_size, 204 | config.num_layers * config.decoder_hidden_size, 205 | num_layers=1, 206 | activation=config.activation) 207 | 208 | self.softplus = nn.Softplus() 209 | self.prior_h = layers.FeedForward(config.context_size, 210 | config.context_size, 211 | num_layers=2, 212 | hidden_size=config.context_size, 213 | activation=config.activation) 214 | self.prior_mu = nn.Linear(config.context_size, 215 | config.z_sent_size) 216 | self.prior_var = nn.Linear(config.context_size, 217 | config.z_sent_size) 218 | 219 | self.posterior_h = layers.FeedForward(config.encoder_hidden_size * self.encoder.num_directions * config.num_layers + config.context_size, 220 | config.context_size, 221 | num_layers=2, 222 | hidden_size=config.context_size, 223 | activation=config.activation) 224 | self.posterior_mu = nn.Linear(config.context_size, 225 | config.z_sent_size) 226 | self.posterior_var = nn.Linear(config.context_size, 227 | config.z_sent_size) 228 | if config.tie_embedding: 229 | 
self.decoder.embedding = self.encoder.embedding 230 | 231 | if config.bow: 232 | self.bow_h = layers.FeedForward(config.z_sent_size, 233 | config.decoder_hidden_size, 234 | num_layers=1, 235 | hidden_size=config.decoder_hidden_size, 236 | activation=config.activation) 237 | self.bow_predict = nn.Linear(config.decoder_hidden_size, config.vocab_size) 238 | 239 | def prior(self, context_outputs): 240 | # Context dependent prior 241 | h_prior = self.prior_h(context_outputs) 242 | mu_prior = self.prior_mu(h_prior) 243 | var_prior = self.softplus(self.prior_var(h_prior)) 244 | return mu_prior, var_prior 245 | 246 | def posterior(self, context_outputs, encoder_hidden): 247 | h_posterior = self.posterior_h(torch.cat([context_outputs, encoder_hidden], 1)) 248 | mu_posterior = self.posterior_mu(h_posterior) 249 | var_posterior = self.softplus(self.posterior_var(h_posterior)) 250 | return mu_posterior, var_posterior 251 | 252 | def compute_bow_loss(self, target_conversations): 253 | target_bow = np.stack([to_bow(sent, self.config.vocab_size) for conv in target_conversations for sent in conv], axis=0) 254 | target_bow = to_var(torch.FloatTensor(target_bow)) 255 | bow_logits = self.bow_predict(self.bow_h(self.z_sent)) 256 | bow_loss = bag_of_words_loss(bow_logits, target_bow) 257 | return bow_loss 258 | 259 | def forward(self, sentences, sentence_length, 260 | input_conversation_length, target_sentences, decode=False): 261 | """ 262 | Args: 263 | sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len] 264 | target_sentences: (Variable, LongTensor) [num_sentences, seq_len] 265 | Return: 266 | decoder_outputs: (Variable, FloatTensor) 267 | - train: [batch_size, seq_len, vocab_size] 268 | - eval: [batch_size, seq_len] 269 | """ 270 | batch_size = input_conversation_length.size(0) 271 | num_sentences = sentences.size(0) - batch_size 272 | max_len = input_conversation_length.data.max().item() 273 | 274 | # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size] 275 | # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size] 276 | encoder_outputs, encoder_hidden = self.encoder(sentences, 277 | sentence_length) 278 | 279 | # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size] 280 | encoder_hidden = encoder_hidden.transpose( 281 | 1, 0).contiguous().view(num_sentences + batch_size, -1) 282 | 283 | # pad and pack encoder_hidden 284 | start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()), 285 | input_conversation_length[:-1] + 1)), 0) 286 | # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size] 287 | encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1) 288 | for s, l in zip(start.data.tolist(), 289 | input_conversation_length.data.tolist())], 0) 290 | 291 | # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size] 292 | encoder_hidden_inference = encoder_hidden[:, 1:, :] 293 | encoder_hidden_inference_flat = torch.cat( 294 | [encoder_hidden_inference[i, :l, :] for i, l in enumerate(input_conversation_length.data)]) 295 | 296 | # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size] 297 | encoder_hidden_input = encoder_hidden[:, :-1, :] 298 | 299 | # context_outputs: [batch_size, max_len, context_size] 300 | context_outputs, context_last_hidden = self.context_encoder(encoder_hidden_input, 301 | input_conversation_length) 302 | # flatten outputs 303 | # context_outputs: 
[num_sentences, context_size]
304 |         context_outputs = torch.cat([context_outputs[i, :l, :]
305 |                                      for i, l in enumerate(input_conversation_length.data)])
306 | 
307 |         mu_prior, var_prior = self.prior(context_outputs)
308 |         eps = to_var(torch.randn((num_sentences, self.config.z_sent_size)))
309 |         if not decode:
310 |             mu_posterior, var_posterior = self.posterior(
311 |                 context_outputs, encoder_hidden_inference_flat)
312 |             z_sent = mu_posterior + torch.sqrt(var_posterior) * eps
313 |             log_q_zx = normal_logpdf(z_sent, mu_posterior, var_posterior).sum()
314 | 
315 |             log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
316 |             # kl_div: [num_sentences]
317 |             kl_div = normal_kl_div(mu_posterior, var_posterior,
318 |                                    mu_prior, var_prior)
319 |             kl_div = torch.sum(kl_div)
320 |         else:
321 |             z_sent = mu_prior + torch.sqrt(var_prior) * eps
322 |             kl_div = None
323 |             log_p_z = normal_logpdf(z_sent, mu_prior, var_prior).sum()
324 |             log_q_zx = None
325 | 
326 |         self.z_sent = z_sent
327 |         latent_context = torch.cat([context_outputs, z_sent], 1)
328 |         decoder_init = self.context2decoder(latent_context)
329 |         decoder_init = decoder_init.view(-1,
330 |                                          self.decoder.num_layers,
331 |                                          self.decoder.hidden_size)
332 |         decoder_init = decoder_init.transpose(1, 0).contiguous()
333 | 
334 |         # train: [batch_size, seq_len, vocab_size]
335 |         # eval: [batch_size, seq_len]
336 |         if not decode:
337 | 
338 |             decoder_outputs = self.decoder(target_sentences,
339 |                                            init_h=decoder_init,
340 |                                            decode=decode)
341 | 
342 |             return decoder_outputs, kl_div, log_p_z, log_q_zx
343 | 
344 |         else:
345 |             # prediction: [batch_size, beam_size, max_unroll]
346 |             prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)
347 | 
348 |             return prediction, kl_div, log_p_z, log_q_zx
349 | 
350 |     def generate(self, context, sentence_length, n_context):
351 |         # context: [batch_size, n_context, seq_len]
352 |         batch_size = context.size(0)
353 |         # n_context = context.size(1)
354 |         samples = []
355 | 
356 |         # Run for context
357 |         context_hidden = None
358 |         for i in range(n_context):
359 |             # encoder_outputs: [batch_size, seq_len, hidden_size * direction]
360 |             # encoder_hidden: [num_layers * direction, batch_size, hidden_size]
361 |             encoder_outputs, encoder_hidden = self.encoder(context[:, i, :],
362 |                                                            sentence_length[:, i])
363 | 
364 |             encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)
365 |             # context_outputs: [batch_size, 1, context_hidden_size * direction]
366 |             # context_hidden: [num_layers * direction, batch_size, context_hidden_size]
367 |             context_outputs, context_hidden = self.context_encoder.step(encoder_hidden,
368 |                                                                         context_hidden)
369 | 
370 |         # Run for generation
371 |         for j in range(self.config.n_sample_step):
372 |             # context_outputs: [batch_size, context_hidden_size * direction]
373 |             context_outputs = context_outputs.squeeze(1)
374 | 
375 |             mu_prior, var_prior = self.prior(context_outputs)
376 |             eps = to_var(torch.randn((batch_size, self.config.z_sent_size)))
377 |             z_sent = mu_prior + torch.sqrt(var_prior) * eps
378 | 
379 |             latent_context = torch.cat([context_outputs, z_sent], 1)
380 |             decoder_init = self.context2decoder(latent_context)
381 |             decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size)
382 | 
383 |             if self.config.sample:
384 |                 prediction = self.decoder(None, decoder_init, decode=True)
385 |                 p = prediction.data.cpu().numpy()
386 |                 length = torch.from_numpy(np.where(p == EOS_ID)[1])
387 |             else:
388 |                 prediction, final_score, length =
self.decoder.beam_decode(init_h=decoder_init) 389 | # prediction: [batch_size, seq_len] 390 | prediction = prediction[:, 0, :] 391 | # length: [batch_size] 392 | length = [l[0] for l in length] 393 | length = to_var(torch.LongTensor(length)) 394 | 395 | samples.append(prediction) 396 | 397 | encoder_outputs, encoder_hidden = self.encoder(prediction, 398 | length) 399 | 400 | encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1) 401 | 402 | context_outputs, context_hidden = self.context_encoder.step(encoder_hidden, 403 | context_hidden) 404 | 405 | samples = torch.stack(samples, 1) 406 | return samples 407 | 408 | 409 | class VHCR(nn.Module): 410 | def __init__(self, config): 411 | super(VHCR, self).__init__() 412 | 413 | self.config = config 414 | self.encoder = layers.EncoderRNN(config.vocab_size, 415 | config.embedding_size, 416 | config.encoder_hidden_size, 417 | config.rnn, 418 | config.num_layers, 419 | config.bidirectional, 420 | config.dropout) 421 | 422 | context_input_size = (config.num_layers 423 | * config.encoder_hidden_size 424 | * self.encoder.num_directions + config.z_conv_size) 425 | self.context_encoder = layers.ContextRNN(context_input_size, 426 | config.context_size, 427 | config.rnn, 428 | config.num_layers, 429 | config.dropout) 430 | 431 | self.unk_sent = nn.Parameter(torch.randn(context_input_size - config.z_conv_size)) 432 | 433 | self.z_conv2context = layers.FeedForward(config.z_conv_size, 434 | config.num_layers * config.context_size, 435 | num_layers=1, 436 | activation=config.activation) 437 | 438 | context_input_size = (config.num_layers 439 | * config.encoder_hidden_size 440 | * self.encoder.num_directions) 441 | self.context_inference = layers.ContextRNN(context_input_size, 442 | config.context_size, 443 | config.rnn, 444 | config.num_layers, 445 | config.dropout, 446 | bidirectional=True) 447 | 448 | self.decoder = layers.DecoderRNN(config.vocab_size, 449 | config.embedding_size, 450 | config.decoder_hidden_size, 451 | config.rnncell, 452 | config.num_layers, 453 | config.dropout, 454 | config.word_drop, 455 | config.max_unroll, 456 | config.sample, 457 | config.temperature, 458 | config.beam_size) 459 | 460 | self.context2decoder = layers.FeedForward(config.context_size + config.z_sent_size + config.z_conv_size, 461 | config.num_layers * config.decoder_hidden_size, 462 | num_layers=1, 463 | activation=config.activation) 464 | 465 | self.softplus = nn.Softplus() 466 | 467 | self.conv_posterior_h = layers.FeedForward(config.num_layers * self.context_inference.num_directions * config.context_size, 468 | config.context_size, 469 | num_layers=2, 470 | hidden_size=config.context_size, 471 | activation=config.activation) 472 | self.conv_posterior_mu = nn.Linear(config.context_size, 473 | config.z_conv_size) 474 | self.conv_posterior_var = nn.Linear(config.context_size, 475 | config.z_conv_size) 476 | 477 | self.sent_prior_h = layers.FeedForward(config.context_size + config.z_conv_size, 478 | config.context_size, 479 | num_layers=1, 480 | hidden_size=config.z_sent_size, 481 | activation=config.activation) 482 | self.sent_prior_mu = nn.Linear(config.context_size, 483 | config.z_sent_size) 484 | self.sent_prior_var = nn.Linear(config.context_size, 485 | config.z_sent_size) 486 | 487 | self.sent_posterior_h = layers.FeedForward(config.z_conv_size + config.encoder_hidden_size * self.encoder.num_directions * config.num_layers + config.context_size, 488 | config.context_size, 489 | num_layers=2, 490 | hidden_size=config.context_size, 491 | 
activation=config.activation) 492 | self.sent_posterior_mu = nn.Linear(config.context_size, 493 | config.z_sent_size) 494 | self.sent_posterior_var = nn.Linear(config.context_size, 495 | config.z_sent_size) 496 | 497 | if config.tie_embedding: 498 | self.decoder.embedding = self.encoder.embedding 499 | 500 | def conv_prior(self): 501 | # Standard gaussian prior 502 | return to_var(torch.FloatTensor([0.0])), to_var(torch.FloatTensor([1.0])) 503 | 504 | def conv_posterior(self, context_inference_hidden): 505 | h_posterior = self.conv_posterior_h(context_inference_hidden) 506 | mu_posterior = self.conv_posterior_mu(h_posterior) 507 | var_posterior = self.softplus(self.conv_posterior_var(h_posterior)) 508 | return mu_posterior, var_posterior 509 | 510 | def sent_prior(self, context_outputs, z_conv): 511 | # Context dependent prior 512 | h_prior = self.sent_prior_h(torch.cat([context_outputs, z_conv], dim=1)) 513 | mu_prior = self.sent_prior_mu(h_prior) 514 | var_prior = self.softplus(self.sent_prior_var(h_prior)) 515 | return mu_prior, var_prior 516 | 517 | def sent_posterior(self, context_outputs, encoder_hidden, z_conv): 518 | h_posterior = self.sent_posterior_h(torch.cat([context_outputs, encoder_hidden, z_conv], 1)) 519 | mu_posterior = self.sent_posterior_mu(h_posterior) 520 | var_posterior = self.softplus(self.sent_posterior_var(h_posterior)) 521 | return mu_posterior, var_posterior 522 | 523 | def forward(self, sentences, sentence_length, 524 | input_conversation_length, target_sentences, decode=False): 525 | """ 526 | Args: 527 | sentences: (Variable, LongTensor) [num_sentences + batch_size, seq_len] 528 | target_sentences: (Variable, LongTensor) [num_sentences, seq_len] 529 | Return: 530 | decoder_outputs: (Variable, FloatTensor) 531 | - train: [batch_size, seq_len, vocab_size] 532 | - eval: [batch_size, seq_len] 533 | """ 534 | batch_size = input_conversation_length.size(0) 535 | num_sentences = sentences.size(0) - batch_size 536 | max_len = input_conversation_length.data.max().item() 537 | 538 | # encoder_outputs: [num_sentences + batch_size, max_source_length, hidden_size] 539 | # encoder_hidden: [num_layers * direction, num_sentences + batch_size, hidden_size] 540 | encoder_outputs, encoder_hidden = self.encoder(sentences, 541 | sentence_length) 542 | 543 | # encoder_hidden: [num_sentences + batch_size, num_layers * direction * hidden_size] 544 | encoder_hidden = encoder_hidden.transpose( 545 | 1, 0).contiguous().view(num_sentences + batch_size, -1) 546 | 547 | # pad and pack encoder_hidden 548 | start = torch.cumsum(torch.cat((to_var(input_conversation_length.data.new(1).zero_()), 549 | input_conversation_length[:-1] + 1)), 0) 550 | # encoder_hidden: [batch_size, max_len + 1, num_layers * direction * hidden_size] 551 | encoder_hidden = torch.stack([pad(encoder_hidden.narrow(0, s, l + 1), max_len + 1) 552 | for s, l in zip(start.data.tolist(), 553 | input_conversation_length.data.tolist())], 0) 554 | 555 | # encoder_hidden_inference: [batch_size, max_len, num_layers * direction * hidden_size] 556 | encoder_hidden_inference = encoder_hidden[:, 1:, :] 557 | encoder_hidden_inference_flat = torch.cat( 558 | [encoder_hidden_inference[i, :l, :] for i, l in enumerate(input_conversation_length.data)]) 559 | 560 | # encoder_hidden_input: [batch_size, max_len, num_layers * direction * hidden_size] 561 | encoder_hidden_input = encoder_hidden[:, :-1, :] 562 | 563 | # Standard Gaussian prior 564 | conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size])) 565 | conv_mu_prior, 
conv_var_prior = self.conv_prior() 566 | 567 | if not decode: 568 | if self.config.sentence_drop > 0.0: 569 | indices = np.where(np.random.rand(max_len) < self.config.sentence_drop)[0] 570 | if len(indices) > 0: 571 | encoder_hidden_input[:, indices, :] = self.unk_sent 572 | 573 | # context_inference_outputs: [batch_size, max_len, num_directions * context_size] 574 | # context_inference_hidden: [num_layers * num_directions, batch_size, hidden_size] 575 | context_inference_outputs, context_inference_hidden = self.context_inference(encoder_hidden, 576 | input_conversation_length + 1) 577 | 578 | # context_inference_hidden: [batch_size, num_layers * num_directions * hidden_size] 579 | context_inference_hidden = context_inference_hidden.transpose( 580 | 1, 0).contiguous().view(batch_size, -1) 581 | conv_mu_posterior, conv_var_posterior = self.conv_posterior(context_inference_hidden) 582 | z_conv = conv_mu_posterior + torch.sqrt(conv_var_posterior) * conv_eps 583 | log_q_zx_conv = normal_logpdf(z_conv, conv_mu_posterior, conv_var_posterior).sum() 584 | 585 | log_p_z_conv = normal_logpdf(z_conv, conv_mu_prior, conv_var_prior).sum() 586 | kl_div_conv = normal_kl_div(conv_mu_posterior, conv_var_posterior, 587 | conv_mu_prior, conv_var_prior).sum() 588 | 589 | context_init = self.z_conv2context(z_conv).view( 590 | self.config.num_layers, batch_size, self.config.context_size) 591 | 592 | z_conv_expand = z_conv.view(z_conv.size(0), 1, z_conv.size( 593 | 1)).expand(z_conv.size(0), max_len, z_conv.size(1)) 594 | context_outputs, context_last_hidden = self.context_encoder( 595 | torch.cat([encoder_hidden_input, z_conv_expand], 2), 596 | input_conversation_length, 597 | hidden=context_init) 598 | 599 | # flatten outputs 600 | # context_outputs: [num_sentences, context_size] 601 | context_outputs = torch.cat([context_outputs[i, :l, :] 602 | for i, l in enumerate(input_conversation_length.data)]) 603 | 604 | z_conv_flat = torch.cat( 605 | [z_conv_expand[i, :l, :] for i, l in enumerate(input_conversation_length.data)]) 606 | sent_mu_prior, sent_var_prior = self.sent_prior(context_outputs, z_conv_flat) 607 | eps = to_var(torch.randn((num_sentences, self.config.z_sent_size))) 608 | 609 | sent_mu_posterior, sent_var_posterior = self.sent_posterior( 610 | context_outputs, encoder_hidden_inference_flat, z_conv_flat) 611 | z_sent = sent_mu_posterior + torch.sqrt(sent_var_posterior) * eps 612 | log_q_zx_sent = normal_logpdf(z_sent, sent_mu_posterior, sent_var_posterior).sum() 613 | 614 | log_p_z_sent = normal_logpdf(z_sent, sent_mu_prior, sent_var_prior).sum() 615 | # kl_div: [num_sentences] 616 | kl_div_sent = normal_kl_div(sent_mu_posterior, sent_var_posterior, 617 | sent_mu_prior, sent_var_prior).sum() 618 | 619 | kl_div = kl_div_conv + kl_div_sent 620 | log_q_zx = log_q_zx_conv + log_q_zx_sent 621 | log_p_z = log_p_z_conv + log_p_z_sent 622 | else: 623 | z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps 624 | context_init = self.z_conv2context(z_conv).view( 625 | self.config.num_layers, batch_size, self.config.context_size) 626 | 627 | z_conv_expand = z_conv.view(z_conv.size(0), 1, z_conv.size( 628 | 1)).expand(z_conv.size(0), max_len, z_conv.size(1)) 629 | # context_outputs: [batch_size, max_len, context_size] 630 | context_outputs, context_last_hidden = self.context_encoder( 631 | torch.cat([encoder_hidden_input, z_conv_expand], 2), 632 | input_conversation_length, 633 | hidden=context_init) 634 | # flatten outputs 635 | # context_outputs: [num_sentences, context_size] 636 | context_outputs = 
torch.cat([context_outputs[i, :l, :] 637 | for i, l in enumerate(input_conversation_length.data)]) 638 | 639 | 640 | z_conv_flat = torch.cat( 641 | [z_conv_expand[i, :l, :] for i, l in enumerate(input_conversation_length.data)]) 642 | sent_mu_prior, sent_var_prior = self.sent_prior(context_outputs, z_conv_flat) 643 | eps = to_var(torch.randn((num_sentences, self.config.z_sent_size))) 644 | 645 | z_sent = sent_mu_prior + torch.sqrt(sent_var_prior) * eps 646 | kl_div = None 647 | log_p_z = normal_logpdf(z_sent, sent_mu_prior, sent_var_prior).sum() 648 | log_p_z += normal_logpdf(z_conv, conv_mu_prior, conv_var_prior).sum() 649 | log_q_zx = None 650 | 651 | # expand z_conv to all associated sentences 652 | z_conv = torch.cat([z.view(1, -1).expand(m.item(), self.config.z_conv_size) 653 | for z, m in zip(z_conv, input_conversation_length)]) 654 | 655 | # latent_context: [num_sentences, context_size + z_sent_size + 656 | # z_conv_size] 657 | latent_context = torch.cat([context_outputs, z_sent, z_conv], 1) 658 | decoder_init = self.context2decoder(latent_context) 659 | decoder_init = decoder_init.view(-1, 660 | self.decoder.num_layers, 661 | self.decoder.hidden_size) 662 | decoder_init = decoder_init.transpose(1, 0).contiguous() 663 | 664 | # train: [batch_size, seq_len, vocab_size] 665 | # eval: [batch_size, seq_len] 666 | if not decode: 667 | decoder_outputs = self.decoder(target_sentences, 668 | init_h=decoder_init, 669 | decode=decode) 670 | return decoder_outputs, kl_div, log_p_z, log_q_zx 671 | 672 | else: 673 | # prediction: [batch_size, beam_size, max_unroll] 674 | prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init) 675 | return prediction, kl_div, log_p_z, log_q_zx 676 | 677 | def generate(self, context, sentence_length, n_context): 678 | # context: [batch_size, n_context, seq_len] 679 | batch_size = context.size(0) 680 | # n_context = context.size(1) 681 | samples = [] 682 | 683 | # Run for context 684 | 685 | conv_eps = to_var(torch.randn([batch_size, self.config.z_conv_size])) 686 | # conv_mu_prior, conv_var_prior = self.conv_prior() 687 | # z_conv = conv_mu_prior + torch.sqrt(conv_var_prior) * conv_eps 688 | 689 | encoder_hidden_list = [] 690 | for i in range(n_context): 691 | # encoder_outputs: [batch_size, seq_len, hidden_size * direction] 692 | # encoder_hidden: [num_layers * direction, batch_size, hidden_size] 693 | encoder_outputs, encoder_hidden = self.encoder(context[:, i, :], 694 | sentence_length[:, i]) 695 | 696 | # encoder_hidden: [batch_size, num_layers * direction * hidden_size] 697 | encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1) 698 | encoder_hidden_list.append(encoder_hidden) 699 | 700 | encoder_hidden = torch.stack(encoder_hidden_list, 1) 701 | context_inference_outputs, context_inference_hidden = self.context_inference(encoder_hidden, 702 | to_var(torch.LongTensor([n_context] * batch_size))) 703 | context_inference_hidden = context_inference_hidden.transpose( 704 | 1, 0).contiguous().view(batch_size, -1) 705 | conv_mu_posterior, conv_var_posterior = self.conv_posterior(context_inference_hidden) 706 | z_conv = conv_mu_posterior + torch.sqrt(conv_var_posterior) * conv_eps 707 | 708 | context_init = self.z_conv2context(z_conv).view( 709 | self.config.num_layers, batch_size, self.config.context_size) 710 | 711 | context_hidden = context_init 712 | for i in range(n_context): 713 | # encoder_outputs: [batch_size, seq_len, hidden_size * direction] 714 | # encoder_hidden: [num_layers * direction, batch_size, 
hidden_size]
715 |             encoder_outputs, encoder_hidden = self.encoder(context[:, i, :],
716 |                                                            sentence_length[:, i])
717 | 
718 |             # encoder_hidden: [batch_size, num_layers * direction * hidden_size]
719 |             encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)
720 | 
721 |             # context_outputs: [batch_size, 1, context_hidden_size * direction]
722 |             # context_hidden: [num_layers * direction, batch_size, context_hidden_size]
723 |             context_outputs, context_hidden = self.context_encoder.step(torch.cat([encoder_hidden, z_conv], 1),
724 |                                                                         context_hidden)
725 | 
726 |         # Run for generation
727 |         for j in range(self.config.n_sample_step):
728 |             # context_outputs: [batch_size, context_hidden_size * direction]
729 |             context_outputs = context_outputs.squeeze(1)
730 | 
731 |             mu_prior, var_prior = self.sent_prior(context_outputs, z_conv)
732 |             eps = to_var(torch.randn((batch_size, self.config.z_sent_size)))
733 |             z_sent = mu_prior + torch.sqrt(var_prior) * eps
734 | 
735 |             latent_context = torch.cat([context_outputs, z_sent, z_conv], 1)
736 |             decoder_init = self.context2decoder(latent_context)
737 |             decoder_init = decoder_init.view(self.decoder.num_layers, -1, self.decoder.hidden_size)
738 | 
739 |             if self.config.sample:
740 |                 prediction = self.decoder(None, decoder_init, decode=True)
741 |                 p = prediction.data.cpu().numpy()
742 |                 length = torch.from_numpy(np.where(p == EOS_ID)[1])
743 |             else:
744 |                 prediction, final_score, length = self.decoder.beam_decode(init_h=decoder_init)
745 |                 # prediction: [batch_size, seq_len]
746 |                 prediction = prediction[:, 0, :]
747 |                 # length: [batch_size]
748 |                 length = [l[0] for l in length]
749 |                 length = to_var(torch.LongTensor(length))
750 | 
751 |             samples.append(prediction)
752 | 
753 |             encoder_outputs, encoder_hidden = self.encoder(prediction,
754 |                                                            length)
755 | 
756 |             encoder_hidden = encoder_hidden.transpose(1, 0).contiguous().view(batch_size, -1)
757 | 
758 |             context_outputs, context_hidden = self.context_encoder.step(torch.cat([encoder_hidden, z_conv], 1),
759 |                                                                         context_hidden)
760 | 
761 |         samples = torch.stack(samples, 1)
762 |         return samples
763 | 
--------------------------------------------------------------------------------
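
Note on the latent-variable machinery in models.py: VHRED and VHCR both draw z with the reparameterization trick (z = mu + sqrt(var) * eps), keeping the variance positive with a softplus so the ELBO terms stay differentiable, and normal_kl_div in utils/probability.py supplies the closed-form Gaussian KL. The following is a minimal, self-contained sketch of that scheme; it is illustrative only and not part of the repository, and the sizes, the seed, and the nn.Sequential stand-in for layers.FeedForward are all chosen for the example.

import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, context_size, z_size = 4, 8, 3

# Stand-ins for VHRED.prior_h / prior_mu / prior_var
prior_h = nn.Sequential(nn.Linear(context_size, context_size), nn.Tanh())
prior_mu = nn.Linear(context_size, z_size)
prior_var = nn.Linear(context_size, z_size)
softplus = nn.Softplus()

context_outputs = torch.randn(batch_size, context_size)

h = prior_h(context_outputs)
mu = prior_mu(h)
var = softplus(prior_var(h))      # softplus keeps the variance strictly positive

eps = torch.randn(batch_size, z_size)
z = mu + torch.sqrt(var) * eps    # reparameterization: gradients flow through mu and var

# KL(N(mu, var) || N(0, I)) in closed form; this matches normal_kl_div
# evaluated with its default arguments mu2=0, var2=1
kl = 0.5 * (-torch.log(var) + var + mu.pow(2) - 1.0).sum(dim=1)
print(z.shape, kl.shape)          # torch.Size([4, 3]) torch.Size([4])

A second sketch, also illustrative only: how the forward passes turn the flat [num_sentences, ...] encoder states back into padded per-conversation batches. torch.cumsum over the conversation lengths yields each conversation's start offset, narrow slices its sentences out, and zero-padding to max_len lets torch.stack build the batch. pad_rows below is a hypothetical CPU stand-in for pad() in utils/pad.py.

import torch

conversation_lengths = torch.LongTensor([3, 1, 2])   # 3 conversations, 6 sentences total
hidden = torch.arange(6).float().unsqueeze(1)        # stand-in for encoder_hidden [6, dim]

start = torch.cumsum(torch.cat((conversation_lengths.new_zeros(1),
                                conversation_lengths[:-1])), 0)  # tensor([0, 3, 4])
max_len = conversation_lengths.max().item()

def pad_rows(t, length):
    # zero-pad dim 0 up to `length`, like utils/pad.py without the .cuda() call
    if length > t.size(0):
        return torch.cat([t, t.new_zeros(length - t.size(0), *t.size()[1:])])
    return t

batched = torch.stack([pad_rows(hidden.narrow(0, s, l), max_len)
                       for s, l in zip(start.tolist(), conversation_lengths.tolist())], 0)
print(batched.squeeze(-1))
# tensor([[0., 1., 2.],
#         [3., 0., 0.],
#         [4., 5., 0.]])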