├── .gitignore ├── src ├── images │ ├── ball.png │ ├── book.png │ └── hat.png ├── models │ ├── __init__.py │ ├── utils.py │ ├── ctx_encoder.py │ ├── modules.py │ ├── selection_model.py │ ├── rnn_model.py │ └── attn.py ├── engines │ ├── rnn_engine.py │ ├── selection_engine.py │ ├── __init__.py │ ├── engine.py │ └── latent_clustering_engine.py ├── config.py ├── test.py ├── vis.py ├── split.py ├── chat.py ├── utils.py ├── eval_selfplay.py ├── train.py ├── metric.py ├── selfplay.py ├── avg_rank.py ├── domain.py ├── reinforce.py ├── dialog.py └── data.py ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /src/images/ball.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/end-to-end-negotiator/HEAD/src/images/ball.png -------------------------------------------------------------------------------- /src/images/book.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/end-to-end-negotiator/HEAD/src/images/book.png -------------------------------------------------------------------------------- /src/images/hat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/end-to-end-negotiator/HEAD/src/images/hat.png -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 
Please [read the full text](https://code.fb.com/codeofconduct) so that you can understand what actions will and will not be tolerated.
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from models.latent_clustering_model import LatentClusteringModel, LatentClusteringLanguageModel
8 | from models.latent_clustering_model import LatentClusteringPredictionModel, BaselineClusteringModel
9 | from models.selection_model import SelectionModel
10 | from models.rnn_model import RnnModel
11 |
12 |
13 | MODELS = {
14 |     'latent_clustering_model': LatentClusteringModel,
15 |     'latent_clustering_prediction_model': LatentClusteringPredictionModel,
16 |     'latent_clustering_language_model': LatentClusteringLanguageModel,
17 |     'baseline_clustering_model': BaselineClusteringModel,
18 |     'selection_model': SelectionModel,
19 |     'rnn_model': RnnModel,
20 | }
21 |
22 |
23 | def get_model_names():
24 |     """Return the names under which model classes are registered in MODELS."""
25 |     return MODELS.keys()
26 |
27 |
28 | def get_model_type(name):
29 |     """Return the model class registered under `name`.
30 |
31 |     Raises:
32 |         KeyError: if `name` is not registered; the message lists the valid
33 |             names (same exception type as the plain dict lookup it wraps).
34 |     """
35 |     try:
36 |         return MODELS[name]
37 |     except KeyError:
38 |         raise KeyError('unknown model %r; expected one of: %s' % (
39 |             name, ', '.join(sorted(MODELS)))) from None
40 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to end-to-end-negotiator
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Pull Requests
6 | We actively welcome your pull requests.
7 |
8 | 1. Fork the repo and create your branch from `master`.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If you've changed APIs, update the documentation.
11 | 4. Ensure the test suite passes.
12 | 5. Make sure your code lints.
13 | 6.
If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: https://code.facebook.com/cla 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to end-to-end-negotiator, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /src/engines/rnn_engine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch.autograd import Variable
9 |
10 | from engines import EngineBase, Criterion
11 |
12 |
13 | class RnnEngine(EngineBase):
14 |     def __init__(self, model, args, verbose=False):
15 |         super(RnnEngine, self).__init__(model, args, verbose)
16 |
17 |     @staticmethod
18 |     def _forward(model, batch):
19 |         """Run `model` on one batch; returns (out, tgt, sel_out, sel_tgt)."""
20 |         ctx, inpt, tgt, sel_tgt = batch
21 |         ctx = Variable(ctx)
22 |         inpt = Variable(inpt)
23 |         tgt = Variable(tgt)
24 |         sel_tgt = Variable(sel_tgt)
25 |
26 |         out, sel_out = model(inpt, ctx)
27 |         return out, tgt, sel_out, sel_tgt
28 |
29 |     def train_batch(self, batch):
30 |         out, tgt, sel_out, sel_tgt = RnnEngine._forward(self.model, batch)
31 |         loss = self.crit(out, tgt)
32 |         loss += self.sel_crit(sel_out, sel_tgt) * self.model.args.sel_weight
33 |         self.opt.zero_grad()
34 |         loss.backward()
35 |         torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
36 |         self.opt.step()
37 |         return loss.item()
38 |
39 |     def valid_batch(self, batch):
40 |         with torch.no_grad():
41 |             out, tgt, sel_out, sel_tgt = RnnEngine._forward(self.model, batch)
42 |             valid_loss = tgt.size(0) * self.crit(out, tgt)
43 |             select_loss = self.sel_crit(sel_out, sel_tgt)
44 |         return valid_loss.item(), select_loss.item(), 0
45 |
--------------------------------------------------------------------------------
/src/models/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | A set of useful tools.
8 | """
9 |
10 | import torch
11 | import torch.nn as nn
12 | import math
13 |
14 |
15 | def init_rnn(rnn, init_range, weights=None, biases=None):
16 |     """ Uniform initialization of an RNN.
17 |
18 |     Each weight named in `weights` is drawn from U(-1/sqrt(hidden_size),
19 |     1/sqrt(hidden_size)); each bias named in `biases` is set to zero.
20 |     `init_range` is accepted for interface compatibility but is not used.
21 |     """
22 |     weights = weights or ['weight_ih_l0', 'weight_hh_l0']
23 |     biases = biases or ['bias_ih_l0', 'bias_hh_l0']
24 |     bound = 1 / math.sqrt(rnn.hidden_size)
25 |
26 |     # init weights
27 |     for w in weights:
28 |         rnn._parameters[w].data.uniform_(-bound, bound)
29 |     # init biases to zero
30 |     for b in biases:
31 |         rnn._parameters[b].data.fill_(0)
32 |
33 |
34 | def init_rnn_cell(rnn, init_range):
35 |     """ Uniform initialization of RNNCell, using the same scheme as init_rnn. """
36 |     init_rnn(rnn, init_range, ['weight_ih', 'weight_hh'], ['bias_ih', 'bias_hh'])
37 |
38 |
39 | def init_linear(linear, init_range):
40 |     """ Uniform initialization of Linear. """
41 |     linear.weight.data.uniform_(-init_range, init_range)
42 |     linear.bias.data.fill_(0)
43 |
44 |
45 | def init_cont(cont, init_range):
46 |     """ Uniform initialization of a container. """
47 |     for m in cont:
48 |         if hasattr(m, 'weight'):
49 |             m.weight.data.uniform_(-init_range, init_range)
50 |         if hasattr(m, 'bias'):
51 |             m.bias.data.fill_(0)
52 |
53 |
54 | def make_mask(n, marked, value=-1000):
55 |     """ Create a masked tensor. """
56 |     mask = torch.Tensor(n).fill_(0)
57 |     for i in marked:
58 |         mask[i] = value
59 |     return mask
60 |
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Configuration script. Stores variables and settings used across application
8 | """
9 |
10 | import logging
11 |
12 | log_level = logging.INFO
13 | log_format = '%(asctime)s : %(levelname)s : %(filename)s : %(message)s'
14 |
15 | # default training settings
16 | data_dir = 'data/negotiate'  # data corpus directory
17 | nembed_word = 256  # size of word embeddings
18 | nembed_ctx = 64  # size of context embeddings
19 | nhid_lang = 256  # size of the hidden state for the language model
20 | nhid_ctx = 64  # size of the hidden state for the context model
21 | nhid_strat = 64  # size of the hidden state for the strategy model
22 | nhid_attn = 64  # size of the hidden state for the attention module
23 | nhid_sel = 64  # size of the hidden state for the selection module
24 | lr = 20.0  # initial learning rate
25 | min_lr = 1e-5  # min threshold for learning rate annealing
26 | decay_rate = 9.0  # decrease learning rate by this factor
27 | decay_every = 1  # decrease learning rate after decay_every epochs
28 | momentum = 0.0  # momentum for SGD
29 | nesterov = False  # enable Nesterov momentum
30 | clip = 0.2  # gradient clipping
31 | dropout = 0.5  # dropout rate in embedding layer
32 | init_range = 0.1  # initialization range
33 | max_epoch = 30  # max number of epochs
34 | bsz = 25  # batch size
35 | unk_threshold = 20  # minimum word frequency to be in dictionary
36 | temperature = 0.1  # temperature
37 | sel_weight = 1.0  # selection weight
38 | seed = 1  # random seed
39 | cuda = False  # use CUDA
40 | plot_graphs = False  # use visdom
41 | domain = "object_division"  # domain for the dialogue
42 | rnn_ctx_encoder = False  # Whether to use RNN for encoding the context
43 |
44 | #fixes
45 | rl_gamma = 0.95
46 | rl_eps = 0.0
47 | rl_momentum = 0.0
48 | rl_lr = 20.0
49 | rl_clip = 0.2
50 | rl_reinforcement_lr = 20.0
51 | rl_reinforcement_clip = 0.2
52 | rl_bsz = 25
53 | rl_sv_train_freq = 4
54 | rl_nepoch = 4
55 | rl_score_threshold = 6
56 | verbose = True
57 | rl_temperature = 0.1
58 |
--------------------------------------------------------------------------------
/src/models/ctx_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Set of context encoders.
8 | """
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.nn.init
15 | from torch.autograd import Variable
16 |
17 | from models.utils import *
18 |
19 |
20 | class MlpContextEncoder(nn.Module):
21 |     """ Simple encoder for the dialogue context. Encodes counts and values via MLP. """
22 |     def __init__(self, n, k, nembed, nhid, dropout, init_range, skip_values=False):
23 |         super(MlpContextEncoder, self).__init__()
24 |
25 |         # embeddings for counts and values
26 |         self.cnt_enc = nn.Sequential(
27 |             nn.Embedding(n, nembed),
28 |             nn.Dropout(dropout))
29 |         self.val_enc = nn.Sequential(
30 |             nn.Embedding(n, nembed),
31 |             nn.Dropout(dropout))
32 |
33 |         self.encoder = nn.Sequential(
34 |             nn.Linear(k * nembed, nhid),
35 |             nn.Tanh()
36 |         )
37 |
38 |         # a flag to only use counts to encode the context
39 |         self.skip_values = skip_values
40 |
41 |         init_cont(self.cnt_enc, init_range)
42 |         init_cont(self.val_enc, init_range)
43 |         init_cont(self.encoder, init_range)
44 |
45 |     def forward(self, ctx):
46 |         cnt_idx = Variable(torch.Tensor(range(0, ctx.size(0), 2)).long())
47 |         cnt = ctx.index_select(0, cnt_idx)
48 |         cnt_emb = self.cnt_enc(cnt)
49 |
50 |         if self.skip_values:
51 |             h = cnt_emb
52 |         else:
53 |             val_idx = Variable(torch.Tensor(range(1, ctx.size(0), 2)).long())
54 |             val = ctx.index_select(0, val_idx)
55 |             val_emb = self.val_enc(val)
56 |             # element-wise multiplication of embeddings
57 |             h = torch.mul(cnt_emb, val_emb)
58 |
59 |         # run MLP to acquire fixed size representation
60 |         h = h.transpose(0, 1).contiguous().view(ctx.size(1), -1)
61 |         ctx_h = self.encoder(h)
62 |         return ctx_h
63 |
--------------------------------------------------------------------------------
/src/engines/selection_engine.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torch.autograd import Variable
9 |
10 | from engines import EngineBase, Criterion
11 |
12 |
13 | class SelectionEngine(EngineBase):
14 |     def __init__(self, model, args, verbose=False):
15 |         super(SelectionEngine, self).__init__(model, args, verbose)
16 |         self.sel_crit = Criterion(
17 |             self.model.item_dict,
18 |             bad_toks=['', ''],
19 |             reduction='mean' if args.sep_sel else 'none')
20 |
21 |     @staticmethod
22 |     def _forward(model, batch, sep_sel=False):
23 |         """Run `model` on one batch; returns (sel_out, sel_tgt)."""
24 |         ctx, _, inpts, lens, _, sel_tgt, rev_idxs, hid_idxs, _ = batch
25 |         ctx = Variable(ctx)
26 |         inpts = [Variable(inpt) for inpt in inpts]
27 |         rev_idxs = [Variable(idx) for idx in rev_idxs]
28 |         hid_idxs = [Variable(idx) for idx in hid_idxs]
29 |         if sep_sel:
30 |             sel_tgt = Variable(sel_tgt)
31 |         else:
32 |             sel_tgt = [Variable(t) for t in sel_tgt]
33 |
34 |         # remove YOU:/THEM: from the end
35 |         sel_out = model(inpts[:-1], lens[:-1], rev_idxs[:-1], hid_idxs[:-1], ctx)
36 |
37 |         return sel_out, sel_tgt
38 |
39 |     def train_batch(self, batch):
40 |         sel_out, sel_tgt = SelectionEngine._forward(self.model, batch,
41 |             sep_sel=self.args.sep_sel)
42 |         loss = 0
43 |         if self.args.sep_sel:
44 |             loss = self.sel_crit(sel_out, sel_tgt)
45 |         else:
46 |             for out, tgt in zip(sel_out, sel_tgt):
47 |                 loss += self.sel_crit(out, tgt)
48 |             loss /= sel_out[0].size(0)
49 |
50 |         self.opt.zero_grad()
51 |         loss.backward()
52 |         torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
53 |         self.opt.step()
54 |         return loss.item()
55 |
56 |     def valid_batch(self, batch):
57 |         with torch.no_grad():
58 |             sel_out, sel_tgt = SelectionEngine._forward(self.model, batch,
59 |                 sep_sel=self.args.sep_sel)
60 |             loss = 0
61 |             if self.args.sep_sel:
62 |                 loss = self.sel_crit(sel_out, sel_tgt)
63 |             else:
64 |                 for out, tgt in zip(sel_out, sel_tgt):
65 |                     loss += self.sel_crit(out, tgt)
66 |                 loss /= sel_out[0].size(0)
67 |
68 |         return 0, loss.item(), 0
69 |
70 |
71 |
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Performs evaluation of the model on the test dataset.
8 | """
9 |
10 | import argparse
11 |
12 | import numpy as np
13 | import torch
14 | from torch.autograd import Variable
15 |
16 | import data
17 | import utils
18 | from engine import Engine, Criterion  # NOTE(review): the tree lists engines/engine.py but no top-level engine.py; verify this import
19 |
20 |
21 | def main():
22 |     parser = argparse.ArgumentParser(description='testing script')
23 |     parser.add_argument('--data', type=str, default='data/negotiate',
24 |         help='location of the data corpus')
25 |     parser.add_argument('--unk_threshold', type=int, default=20,
26 |         help='minimum word frequency to be in dictionary')
27 |     parser.add_argument('--model_file', type=str,
28 |         help='pretrained model file')
29 |     parser.add_argument('--seed', type=int, default=1,
30 |         help='random seed')
31 |     parser.add_argument('--hierarchical', action='store_true', default=False,
32 |         help='use hierarchical model')
33 |     parser.add_argument('--bsz', type=int, default=16,
34 |         help='batch size')
35 |     parser.add_argument('--cuda', action='store_true', default=False,
36 |         help='use CUDA')
37 |     args = parser.parse_args()
38 |
39 |     device_id = utils.use_cuda(args.cuda)
40 |     utils.set_seed(args.seed)
41 |
42 |     corpus = data.WordCorpus(args.data, freq_cutoff=args.unk_threshold, verbose=True)
43 |     model = utils.load_model(args.model_file)
44 |
45 |     crit = Criterion(model.word_dict, device_id=device_id)
46 |     sel_crit = Criterion(model.item_dict, device_id=device_id,
47 |         bad_toks=['', ''])
48 |
49 |
50 |     testset, testset_stats = corpus.test_dataset(args.bsz, device_id=device_id)
51 |     test_loss, test_select_loss = 0, 0
52 |
53 |     N = len(corpus.word_dict)
54 |     # evaluation only: no autograd graph is needed
55 |     with torch.no_grad():
56 |         for batch in testset:
57 |             # run forward on the batch, produces output, hidden, target,
58 |             # selection output and selection target
59 |             out, hid, tgt, sel_out, sel_tgt = Engine.forward(model, batch, volatile=False)
60 |
61 |             # compute LM and selection losses; .item() extracts the Python scalar
62 |             # (`.data[0]` fails on 0-dim loss tensors in PyTorch >= 0.4)
63 |             test_loss += tgt.size(0) * crit(out.view(-1, N), tgt).item()
64 |             test_select_loss += sel_crit(sel_out, sel_tgt).item()
65 |
66 |     test_loss /= testset_stats['nonpadn']
67 |     test_select_loss /= len(testset)
68 |     print('testloss %.3f | testppl %.3f' % (test_loss, np.exp(test_loss)))
69 |     print('testselectloss %.3f | testselectppl %.3f' % (test_select_loss, np.exp(test_select_loss)))
70 |
71 |
72 | if __name__ == '__main__':
73 |     main()
74 |
--------------------------------------------------------------------------------
/src/vis.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | A visualization library. Relies on visdom.
8 | """
9 |
10 | import pdb
11 |
12 | import visdom
13 | import numpy as np
14 |
15 |
16 | class Plot(object):
17 |     """A class for plotting and updating the plot in real time."""
18 |     def __init__(self, metrics, title, ylabel, xlabel='t', running_n=100):
19 |         self.vis = visdom.Visdom()
20 |         self.metrics = metrics
21 |         self.opts = dict(
22 |             fillarea=False,
23 |             xlabel=xlabel,
24 |             ylabel=ylabel,
25 |             title=title,
26 |         )
27 |         self.win = None
28 |         self.running_n = running_n
29 |         self.vals = dict()
30 |         self.cnts = dict()
31 |
32 |     def _update_metric(self, metric, x, y):
33 |         if metric not in self.vals:
34 |             self.vals[metric] = np.zeros(self.running_n)
35 |             self.cnts[metric] = 0
36 |
37 |         self.vals[metric][self.cnts[metric] % self.running_n] = y
38 |         self.cnts[metric] += 1
39 |
40 |         y = self.vals[metric][:min(self.cnts[metric], self.running_n)].mean()
41 |         return np.array([x]), np.array([y])
42 |
43 |     def update(self, metric, x, y):
44 |         assert metric in self.metrics, 'metric %s is not in %s' % (metric, self.metrics)
45 |         X, Y = self._update_metric(metric, x, y)
46 |         if self.win is None:
47 |             self.opts['legend'] = [metric,]
48 |             self.win = self.vis.line(X=X, Y=Y, opts=self.opts)
49 |         else:
50 |             self.vis.line(X=X, Y=Y, win=self.win, update='append', name=metric)
51 |
52 |
53 | class ModulePlot(object):
54 |     """A helper class that plots norms of weights and gradients for a given module."""
55 |     def __init__(self, module, plot_weight=False, plot_grad=False, running_n=100):
56 |         self.module = module
57 |         self.plot_weight = plot_weight
58 |         self.plot_grad = plot_grad
59 |         self.plots = dict()
60 |
61 |         def make_plot(m, n):
62 |             names = m._parameters.keys()
63 |             if self.plot_weight:
64 |                 self.plots[n + '_w'] = Plot(names, n + '_w', 'norm', running_n=running_n)
65 |             if self.plot_grad:
66 |                 self.plots[n + '_g'] = Plot(names, n + '_g', 'norm', running_n=running_n)
67 |
68 |         self._for_all(make_plot, self.module)
69 |
70 |     def _for_all(self, fn, module, name=None):
71 |         name = name or module.__class__.__name__.lower()
72 |         if len(module._modules) == 0:
73 |             fn(module, name)
74 |         else:
75 |             for n, m in module._modules.items():
76 |                 self._for_all(fn, m, name + '_' + n)
77 |
78 |     def update(self, x):
79 |         def update_plot(m, n):
80 |             for k, p in m._parameters.items():
81 |                 if self.plot_weight:
82 |                     self.plots[n + '_w'].update(k, x, p.norm().item())
83 |                 if self.plot_grad and hasattr(p, 'grad') and p.grad is not None:
84 |                     self.plots[n + '_g'].update(k, x, p.grad.norm().item())
85 |
86 |         self._for_all(update_plot, self.module)
87 |
--------------------------------------------------------------------------------
/src/split.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import pdb
8 | import argparse
9 | import data
10 | import re
11 |
12 |
13 | # `itemI=V`: group 1 is the one-digit item index, group 2 the whole value.
14 | # The `+` must sit inside group 2, otherwise only the last character is captured.
15 | item_pattern = re.compile(r'^item([0-9])=([0-9\-]+)$')
16 |
17 |
18 | def find(tokens, tag):
19 |     """Return the index of the first token equal to `tag`; raise ValueError if absent."""
20 |     for i, t in enumerate(tokens):
21 |         if t == tag:
22 |             return i
23 |     raise ValueError('tag %r not found in tokens' % (tag,))
24 |
25 |
26 | def invert(cnts, sel):
27 |     """Map each `itemI=V` choice to `itemI=<cnts[I] - V>` (the complementary count)."""
28 |     inv_sel = []
29 |     for s in sel:
30 |         match = item_pattern.match(s)
31 |         i = int(match.groups()[0])
32 |         v = int(match.groups()[1])
33 |         inv_sel.append('item%d=%d' % (i, cnts[i] - v))
34 |     return inv_sel
35 |
36 |
37 | def dialog_len(line):
38 |     tokens = line.split(' ')
39 |     n = len([t for t in tokens if t in ('YOU:', 'THEM:')])
40 |     return 'dialoglen: %d' % n
41 |
42 |
43 | def select(line):
44 |     tokens = line.split(' ')
45 |     n = 0
46 |     for i in range(1, len(tokens)):
47 |         if tokens[i] == '' and tokens[i - 1] == 'YOU:':
48 |             n = 1
49 |             break
50 |     return 'botselect: %d' % n
51 |
52 |
53 | def conv(line):
54 |     tokens = line.split(' ')
55 |     context = tokens[3:9]
56 |     assert tokens[9] in ('YOU:', 'THEM:')
57 |     inv = tokens[9] == 'THEM:'
58 |     sel_start = find(tokens, '') + 1
59 |     cnts = [int(n) for n in context[0::2]]
60 |     if tokens[sel_start] == 'no' or tokens[sel_start] == '':
61 |         selection = [''] * 3
62 |     else:
63 |         selection = tokens[sel_start: sel_start + 3]
64 |         if inv:
65 |             selection = invert(cnts, selection)
66 |     return 'debug: %s %s %s' % (' '.join(context), ' '.join(selection), ' '.join(selection))
67 |
68 |
69 | def main():
70 |     parser = argparse.ArgumentParser(
71 |         description='A script to pair bot and human dialogues from a log file')
72 |     parser.add_argument('--log_file', type=str, default='',
73 |         help='location of the log file')
74 |     parser.add_argument('--output_file', type=str, default='',
75 |         help='location of the output file')
76 |     parser.add_argument('--bot_name', type=str, default='',
77 |         help='bot name')
78 |     args = parser.parse_args()
79 |
80 |     lines = data.read_lines(args.log_file)
81 |     bots = dict()
82 |     for line in lines:
83 |         if line.startswith(args.bot_name + '1') or \
84 |                 line.startswith(args.bot_name + '2'):
85 |             bots[line.split(' ')[2]] = line
86 |
87 |     humans = dict()
88 |     for line in lines:
89 |         if line.startswith('human'):
90 |             i = line.split(' ')[2]
91 |             if i in bots:
92 |                 humans[i] = line
93 |
94 |     with open(args.output_file, 'w') as f:
95 |         for i in bots.keys():
96 |             if i in humans:
97 |                 print(dialog_len(bots[i]), file=f)
98 |                 print(select(bots[i]), file=f)
99 |                 print(conv(bots[i]), file=f)
100 |                 print(conv(humans[i]), file=f)
101 |                 print('-' * 80, file=f)
102 |
103 | if __name__ == '__main__':
104 |     main()
105 |
--------------------------------------------------------------------------------
/src/chat.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import argparse
8 | import itertools
9 | import domain
10 |
11 | import utils
12 | from utils import ContextGenerator, ManualContextGenerator
13 | from agent import RnnAgent, HumanAgent, RnnRolloutAgent, HierarchicalAgent
14 | from dialog import Dialog, DialogLogger
15 |
16 |
17 | class Chat(object):
18 |     def __init__(self, dialog, ctx_gen, logger=None):
19 |         self.dialog = dialog
20 |         self.ctx_gen = ctx_gen
21 |         self.logger = logger if logger else DialogLogger()
22 |
23 |     def run(self):
24 |         self.logger.dump('Welcome to our Chatroulette!')
25 |         for dialog_id in itertools.count():
26 |             self.logger.dump('=' * 80)
27 |             self.logger.dump('Dialog %d' % dialog_id)
28 |             self.logger.dump('-' * 80)
29 |             ctxs = self.ctx_gen.sample()
30 |             self.dialog.run(ctxs, self.logger)
31 |             self.logger.dump('=' * 80)
32 |             self.logger.dump('')
33 |
34 |
35 | def main():
36 |     parser = argparse.ArgumentParser(description='chat utility')
37 |     parser.add_argument('--model_file', type=str,
38 |         help='model file')
39 |     parser.add_argument('--domain', type=str, default='object_division',
40 |         help='domain for the dialogue')
41 |     parser.add_argument('--context_file', type=str, default='',
42 |         help='context file')
43 |     parser.add_argument('--temperature', type=float, default=1.0,
44 |         help='temperature')
45 |     parser.add_argument('--num_types', type=int, default=3,
46 |         help='number of object types')
47 |     parser.add_argument('--num_objects', type=int, default=6,
48 |         help='total number of objects')
49 |     parser.add_argument('--max_score', type=int, default=10,
50 |         help='max score per object')
51 |     parser.add_argument('--score_threshold', type=int, default=6,
52 |         help='successful dialog should have more than score_threshold in score')
53 |     parser.add_argument('--seed', type=int, default=1,
54 |         help='random seed')
55 |     parser.add_argument('--smart_ai', action='store_true', default=False,
56 |         help='make AI smart again')
57 |     parser.add_argument('--ai_starts', action='store_true', default=False,
58 |         help='allow AI to start the dialog')
59 |     parser.add_argument('--ref_text', type=str,
60 |         help='file with the reference text')
61 |     parser.add_argument('--cuda', action='store_true', default=False,
62 |         help='use CUDA')
63 |     args = parser.parse_args()
64 |
65 |     utils.use_cuda(args.cuda)
66 |     utils.set_seed(args.seed)
67 |
68 |     human = HumanAgent(domain.get_domain(args.domain))
69 |
70 |     alice_ty = RnnRolloutAgent if args.smart_ai else HierarchicalAgent
71 |     ai = alice_ty(utils.load_model(args.model_file), args)
72 |
73 |
74 |     agents = [ai, human] if args.ai_starts else [human, ai]
75 |
76 |     dialog = Dialog(agents, args)
77 |     logger = DialogLogger(verbose=True)
78 |     if args.context_file == '':
79 |         ctx_gen = ManualContextGenerator(args.num_types, args.num_objects, args.max_score)
80 |     else:
81 |         ctx_gen = ContextGenerator(args.context_file)
82 |
83 |     chat = Chat(dialog, ctx_gen, logger)
84 |     chat.run()
85 |
86 |
87 | if __name__ == '__main__':
88 |     main()
89 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Various helpers.
8 | """
9 |
10 | import random
11 | import copy
12 | import pdb
13 | import sys
14 |
15 | import torch
16 | import numpy as np
17 |
18 |
19 | def backward_hook(grad):
20 |     """Hook for backward pass."""
21 |     print(grad)
22 |     pdb.set_trace()
23 |     return grad
24 |
25 |
26 | def save_model(model, file_name):
27 |     """Serializes model to a file."""
28 |     if file_name != '':
29 |         with open(file_name, 'wb') as f:
30 |             torch.save(model, f)
31 |
32 |
33 | def load_model(file_name):
34 |     """Reads model from a file (torch.load unpickles it: only load trusted files)."""
35 |     with open(file_name, 'rb') as f:
36 |         return torch.load(f)
37 |
38 |
39 | def set_seed(seed):
40 |     """Sets random seed everywhere."""
41 |     torch.manual_seed(seed)
42 |     if torch.cuda.is_available():
43 |         torch.cuda.manual_seed(seed)
44 |     random.seed(seed)
45 |     np.random.seed(seed)
46 |
47 |
48 | def use_cuda(enabled, device_id=0):
49 |     """Verifies if CUDA is available and sets default device to be device_id."""
50 |     if not enabled:
51 |         return None
52 |     assert torch.cuda.is_available(), 'CUDA is not available'
53 |     torch.set_default_tensor_type('torch.cuda.FloatTensor')
54 |     torch.cuda.set_device(device_id)
55 |     return device_id
56 |
57 |
58 | def prob_random():
59 |     """Prints out the states of various RNGs."""
60 |     print('random state: python %.3f torch %.3f numpy %.3f' % (
61 |         random.random(), torch.rand(1)[0], np.random.rand()))
62 |
63 |
64 | class ContextGenerator(object):
65 |     """Dialogue context generator. Generates contexts from the file."""
66 |     def __init__(self, context_file):
67 |         self.ctxs = []
68 |         with open(context_file, 'r') as f:
69 |             ctx_pair = []
70 |             for line in f:
71 |                 ctx = line.strip().split()
72 |                 ctx_pair.append(ctx)
73 |                 if len(ctx_pair) == 2:
74 |                     self.ctxs.append(ctx_pair)
75 |                     ctx_pair = []
76 |
77 |     def sample(self):
78 |         return random.choice(self.ctxs)
79 |
80 |     def iter(self, nepoch=1):
81 |         for _ in range(nepoch):
82 |             random.shuffle(self.ctxs)
83 |             for ctx in self.ctxs:
84 |                 yield ctx
85 |
86 |
87 | class ManualContextGenerator(object):
88 |     """Dialogue context generator. Takes contexts from stdin."""
89 |     def __init__(self, num_types=3, num_objects=10, max_score=10):
90 |         self.num_types = num_types
91 |         self.num_objects = num_objects
92 |         self.max_score = max_score
93 |
94 |     def _input_ctx(self):
95 |         while True:
96 |             try:
97 |                 ctx = input('Input context: ')
98 |                 ctx = ctx.strip().split()
99 |                 if len(ctx) != 2 * self.num_types:
100 |                     raise ValueError('wrong number of tokens')
101 |                 if np.sum([int(x) for x in ctx[0::2]]) != self.num_objects:
102 |                     raise ValueError('counts do not sum to num_objects')
103 |                 if np.max([int(x) for x in ctx[1::2]]) > self.max_score:
104 |                     raise ValueError('a value exceeds max_score')
105 |                 return ctx
106 |             except KeyboardInterrupt:
107 |                 sys.exit()
108 |             except ValueError:
109 |                 print('The context is invalid! Try again.')
110 |                 print('Reason: num_types=%d, num_objects=%d, max_score=%s' % (
111 |                     self.num_types, self.num_objects, self.max_score))
112 |
113 |     def _update_scores(self, ctx):
114 |         """Replace each value slot (odd index) with a random score in [0, max_score], as str."""
115 |         for i in range(1, len(ctx), 2):
116 |             ctx[i] = str(np.random.randint(0, self.max_score + 1))
117 |         return ctx
118 |
119 |     def sample(self):
120 |         ctx1 = self._input_ctx()
121 |         ctx2 = self._update_scores(copy.copy(ctx1))
122 |         return [ctx1, ctx2]
123 |
--------------------------------------------------------------------------------
/src/models/modules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017-present, Facebook, Inc.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | """
7 | Helper functions for module initialization, plus context encoder modules.
8 | """
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.nn.init
15 | from torch.autograd import Variable
16 |
17 |
18 | def init_rnn(rnn, init_range, weights=None, biases=None):
19 |     """Initializes RNN uniformly."""
20 |     weights = weights or ['weight_ih_l0', 'weight_hh_l0']
21 |     biases = biases or ['bias_ih_l0', 'bias_hh_l0']
22 |     # Init weights
23 |     for w in weights:
24 |         rnn._parameters[w].data.uniform_(-init_range, init_range)
25 |     # Init biases
26 |     for b in biases:
27 |         rnn._parameters[b].data.fill_(0)
28 |
29 |
30 | def init_rnn_cell(rnn, init_range):
31 |     """Initializes RNNCell uniformly."""
32 |     init_rnn(rnn, init_range, ['weight_ih', 'weight_hh'], ['bias_ih', 'bias_hh'])
33 |
34 |
35 | def init_cont(cont, init_range):
36 |     """Initializes a container uniformly."""
37 |     for m in cont:
38 |         if hasattr(m, 'weight'):
39 |             m.weight.data.uniform_(-init_range, init_range)
40 |         if hasattr(m, 'bias'):
41 |             m.bias.data.fill_(0)
42 |
43 |
44 | class CudaModule(nn.Module):
45 |     """A helper to run a module on a particular device using CUDA."""
46 |     def __init__(self, device_id):
47 |         super(CudaModule, self).__init__()
48 |         self.device_id = device_id
49 |
50 |     def to_device(self, m):
51 |         if self.device_id is not None:
52 |             return m.cuda(self.device_id)
53 |         return m
54 |
55 |
56 | class RnnContextEncoder(CudaModule):
57 |     """A module that encodes the dialogue context using an RNN."""
58 |     def __init__(self, n, k, nembed, nhid, init_range, device_id):
59 |         super(RnnContextEncoder, self).__init__(device_id)
60 |         self.nhid = nhid
61 |
62 |         # use the same embedding for counts and values
63 |         self.embeder = nn.Embedding(n, nembed)
64 |         # an RNN to encode a sequence of counts and values
65 |         self.encoder = nn.GRU(
66 |             input_size=nembed,
67 |             hidden_size=nhid,
68 |             bias=True)
69 |
70 |         self.embeder.weight.data.uniform_(-init_range, init_range)
71 |         init_rnn(self.encoder, init_range)
72 |
73 |     def forward(self, ctx):
74 |         ctx_h = self.to_device(torch.zeros(1, ctx.size(1), self.nhid))
75 |         # create embedding
76 |         ctx_emb = self.embeder(ctx)
77 |         # run it through the RNN to get a hidden representation of the context
78 |         _, ctx_h = self.encoder(ctx_emb, Variable(ctx_h))
79 |         return ctx_h
80 |
81 |
82 | class MlpContextEncoder(CudaModule):
83 |     """A module that encodes the dialogue context using an MLP."""
84 |     def __init__(self, n, k, nembed, nhid, init_range, device_id):
85 |         super(MlpContextEncoder, self).__init__(device_id)
86 |
87 |         # create separate embedding for counts and values
88 |         self.cnt_enc = nn.Embedding(n, nembed)
89 |         self.val_enc = nn.Embedding(n, nembed)
90 |
91 |         self.encoder = nn.Sequential(
92 |             nn.Tanh(),
93 |             nn.Linear(k * nembed, nhid)
94 |         )
95 |
96 |         self.cnt_enc.weight.data.uniform_(-init_range, init_range)
97 |         self.val_enc.weight.data.uniform_(-init_range, init_range)
98 |         init_cont(self.encoder, init_range)
99 |
100 |     def forward(self, ctx):
101 |         idx = np.arange(ctx.size(0) // 2, dtype=np.int64)
102 |         # extract counts and values (int64 so index_select receives a LongTensor on every platform)
103 |         cnt_idx = Variable(self.to_device(torch.from_numpy(2 * idx + 0)))
104 |         val_idx = Variable(self.to_device(torch.from_numpy(2 * idx + 1)))
105 |
106 |         cnt = ctx.index_select(0, cnt_idx)
107 |         val = ctx.index_select(0, val_idx)
108 |
109 |         # embed counts and values
110 |         cnt_emb = self.cnt_enc(cnt)
111 |         val_emb = self.val_enc(val)
112 |
113 |         # element wise multiplication to get a hidden state
114 |         h = torch.mul(cnt_emb, val_emb)
115 |         # run the hidden state through the MLP
116 |         h = h.transpose(0, 1).contiguous().view(ctx.size(1), -1)
117 |         ctx_h = self.encoder(h).unsqueeze(0)
118 |         return ctx_h
119 |
--------------------------------------------------------------------------------
# /src/eval_selfplay.py:
# --------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Script to evaluate selfplay.
It computes agreement rate, average score and Pareto optimality.
"""

import argparse
import sys
import time
import random
import itertools
import re
import pdb

import numpy as np

import data
import utils
from domain import get_domain


def parse_line(line, domain):
    """Parses one ``debug:`` line of a selfplay log.

    Returns:
        (cnts, vals, picks): counts and values from ``domain.parse_context``,
        plus one pick per parsed choice token, where -1 marks a token that
        matched the "no selection" literals below.
    """
    # skip the 'debug:' token
    tokens = line.split(' ')[1:]
    context = tokens[:2 * domain.input_length()]
    choice_str = tokens[-domain.selection_length():]

    cnts, vals = domain.parse_context(context)
    picks = []
    # only the first half of the choice tokens is parsed (presumably this
    # agent's own picks -- verify against the log format)
    for i, c in enumerate(choice_str[:domain.selection_length() // 2]):
        # NOTE(review): both literals are empty; they look like special
        # tokens lost during extraction -- confirm against the original.
        if c in ('', ''):
            picks.append(-1)
        else:
            idx, pick = domain.parse_choice(c)
            assert idx == i
            picks.append(pick)

    return cnts, vals, picks


def parse_log(file_name, domain):
    """Parse the log file produced by selfplay.
    See the format of that log file to get more details.

    Consecutive pairs of ``debug:`` lines are merged into one record
    ``(cnts, vals1, picks1, vals2, picks2)``. Both lines of a pair must report
    the same counts. A trailing unpaired line is dropped.
    """
    dataset, current = [], []
    for line in data.read_lines(file_name):
        if line.startswith('debug:'):
            cnts, vals, picks = parse_line(line, domain)
            current.append((cnts, vals, picks))
        if len(current) == 2:
            # validate that the counts match
            cnts1, vals1, picks1 = current[0]
            cnts2, vals2, picks2 = current[1]
            assert cnts1 == cnts2
            dataset.append((cnts1, vals1, picks1, vals2, picks2))
            current = []
    return dataset


def compute_score(vals, picks):
    """Compute the score of the selection: sum of value * picked count."""
    assert len(vals) == len(picks)
    return np.sum([v * p for v, p in zip(vals, picks)])


def gen_choices(cnts, idx=0, choice=None):
    """Generate all the valid choices.
    It generates both yours and your opponent choices.

    Args:
        cnts: number of items of each type.
        idx: first item type still to enumerate; types before it are fixed.
        choice: picks already fixed for item types ``[0, idx)``
            (default: none).

    Returns:
        A list of ``(yours, theirs)`` pairs of lists, where ``theirs[i] ==
        cnts[i] - yours[i]``. Earlier item types vary slowest.
    """
    # None instead of a shared mutable [] default
    prefix = list(choice) if choice is not None else []
    choices = []
    # product() over the remaining types yields the same order as the former
    # depth-first recursion
    for tail in itertools.product(*(range(n + 1) for n in cnts[idx:])):
        mine = prefix + list(tail)
        choices.append((mine, [n - c for n, c in zip(cnts, mine)]))
    return choices


def main():
    """Print Pareto optimality, agreement rate and average scores of a log."""
    parser = argparse.ArgumentParser(
        description='A script to compute Pareto efficiency')
    parser.add_argument('--log_file', type=str, default='',
        help='location of the log file')
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')

    args = parser.parse_args()
    domain = get_domain(args.domain)

    dataset = parse_log(args.log_file, domain)

    avg_agree, avg_can_improve = 0, 0
    avg_score1, avg_score2 = 0, 0
    for cnts, vals1, picks1, vals2, picks2 in dataset:
        # -1 marks a missing selection (see parse_line)
        if np.min(picks1) == -1 or np.min(picks2) == -1:
            continue
        # the agents agree only if their picks split every item type exactly
        agree = all(p1 + p2 == n for p1, p2, n in zip(picks1, picks2, cnts))
        if not agree:
            continue

        avg_agree += 1
        score1 = compute_score(vals1, picks1)
        score2 = compute_score(vals2, picks2)
        # not Pareto optimal if another split helps one agent without hurting
        # the other
        can_improve = False
        for cand1, cand2 in gen_choices(cnts):
            cand_score1 = compute_score(vals1, cand1)
            cand_score2 = compute_score(vals2, cand2)
            if ((cand_score1 > score1 and cand_score2 >= score2)
                    or (cand_score1 >= score1 and cand_score2 > score2)):
                can_improve = True
                break

        avg_score1 += score1
        avg_score2 += score2
        avg_can_improve += int(can_improve)

    num_dialogs = len(dataset)
    if num_dialogs == 0:
        # previously a ZeroDivisionError
        print('no dialogues found in %s' % args.log_file)
        return

    if avg_agree > 0:
        print('pareto opt (%%)\t:\t%.2f' % (100. * (1 - 1. * avg_can_improve / avg_agree)))
    else:
        print('pareto opt (%)\t:\tn/a')
    print('agree (%%)\t:\t%.2f' % (100. * avg_agree / num_dialogs))
    print('score (all)\t:\t%.2f vs. %.2f' % (
        1. * avg_score1 / num_dialogs, 1. * avg_score2 / num_dialogs))
    if avg_agree > 0:
        print('score (agreed)\t:\t%.2f vs. %.2f' % (
            1. * avg_score1 / avg_agree, 1. * avg_score2 / avg_agree))
    else:
        print('score (agreed)\t:\tn/a')


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
# /src/train.py:
# --------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import sys
import time
import random
import itertools
import re

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

import data
import utils
import models
from domain import get_domain


def main():
    """Train the model chosen by ``--model_type`` and save it to ``--model_file``.

    The model class supplies its own corpus type (``corpus_ty``) and training
    engine (``engine_ty``); every parsed argument is forwarded to both.
    """
    parser = argparse.ArgumentParser(description='training script')
    parser.add_argument('--data', type=str, default='data/negotiate',
        help='location of the data corpus')
    parser.add_argument('--nembed_word', type=int, default=256,
        help='size of word embeddings')
    parser.add_argument('--nembed_ctx', type=int, default=64,
        help='size of context embeddings')
    parser.add_argument('--nhid_lang', type=int, default=256,
        help='size of the hidden state for the language module')
    # help text previously duplicated the --nhid_lang description
    parser.add_argument('--nhid_cluster', type=int, default=256,
        help='size of the hidden state for the clustering module')
    parser.add_argument('--nhid_ctx', type=int, default=64,
        help='size of the hidden state for the context module')
    parser.add_argument('--nhid_strat', type=int, default=64,
        help='size of the hidden state for the strategy module')
    parser.add_argument('--nhid_attn', type=int, default=64,
        help='size of the hidden state for the attention module')
    parser.add_argument('--nhid_sel', type=int, default=64,
        help='size of the hidden state for the selection module')
    parser.add_argument('--lr', type=float, default=20.0,
        help='initial learning rate')
    parser.add_argument('--min_lr', type=float, default=1e-5,
        help='min threshold for learning rate annealing')
    parser.add_argument('--decay_rate', type=float, default=9.0,
        help='decrease learning rate by this factor')
    parser.add_argument('--decay_every', type=int, default=1,
        help='decrease learning rate after decay_every epochs')
    parser.add_argument('--momentum', type=float, default=0.0,
        help='momentum for sgd')
    parser.add_argument('--clip', type=float, default=0.2,
        help='gradient clipping')
    parser.add_argument('--dropout', type=float, default=0.5,
        help='dropout rate in embedding layer')
    parser.add_argument('--init_range', type=float, default=0.1,
        help='initialization range')
    parser.add_argument('--max_epoch', type=int, default=30,
        help='max number of epochs')
    parser.add_argument('--num_clusters', type=int, default=50,
        help='number of clusters')
    parser.add_argument('--bsz', type=int, default=25,
        help='batch size')
    parser.add_argument('--unk_threshold', type=int, default=20,
        help='minimum word frequency to be in dictionary')
    parser.add_argument('--temperature', type=float, default=0.1,
        help='temperature')
    # help text previously duplicated the --sel_weight description
    parser.add_argument('--partner_ctx_weight', type=float, default=0.0,
        help='partner context weight')
    parser.add_argument('--sel_weight', type=float, default=0.6,
        help='selection weight')
    parser.add_argument('--seed', type=int, default=1,
        help='random seed')
    parser.add_argument('--cuda', action='store_true', default=False,
        help='use CUDA')
    parser.add_argument('--model_file', type=str, default='',
        help='path to save the final model')
    parser.add_argument('--prediction_model_file', type=str, default='',
        help='path to save the prediction model')
    parser.add_argument('--selection_model_file', type=str, default='',
        help='path to save the selection model')
    parser.add_argument('--cluster_model_file', type=str, default='',
        help='path to save the cluster model')
    parser.add_argument('--lang_model_file', type=str, default='',
        help='path to save the language model')
    parser.add_argument('--visual', action='store_true', default=False,
        help='plot graphs')
    parser.add_argument('--skip_values', action='store_true', default=False,
        help='skip values in ctx encoder')
    parser.add_argument('--model_type', type=str, default='rnn_model',
        help='model type', choices=models.get_model_names())
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')
    parser.add_argument('--clustering', action='store_true', default=False,
        help='use clustering')
    parser.add_argument('--sep_sel', action='store_true', default=False,
        help='use separate classifiers for selection')

    args = parser.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    domain = get_domain(args.domain)
    model_ty = models.get_model_type(args.model_type)
    corpus = model_ty.corpus_ty(domain, args.data, freq_cutoff=args.unk_threshold,
        verbose=True, sep_sel=args.sep_sel)
    model = model_ty(corpus.word_dict, corpus.item_dict_old,
        corpus.context_dict, corpus.count_dict, args)
    if args.cuda:
        model.cuda()
    engine = model_ty.engine_ty(model, args, verbose=True)
    train_loss, valid_loss, select_loss, extra = engine.train(corpus)

    utils.save_model(engine.get_model(), args.model_file)


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
# /src/models/selection_model.py:
# --------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sys
import re
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init
from torch.autograd import Variable
import torch.nn.functional as F

import data
from engines.selection_engine import SelectionEngine
from domain import get_domain
from models.utils import *
from models.ctx_encoder import MlpContextEncoder
from models.attn import Attention, HierarchicalAttention


class SelectionModule(nn.Module):
    """Attention-based selection head with ``num_heads`` linear decoders.

    Attends over utterance encodings conditioned on a query, then predicts
    ``num_heads`` sets of ``output_size`` logits.
    """
    def __init__(self, query_size, value_size, hidden_size, selection_size, num_heads, output_size, args):
        super(SelectionModule, self).__init__()

        self.hidden_size = hidden_size
        self.output_size = output_size

        self.attn = HierarchicalAttention(query_size, value_size, hidden_size,
            args.dropout, args.init_range)

        # the attention summary has 2 * hidden_size features; q is appended
        self.sel_encoder = nn.Sequential(
            nn.Linear(2 * hidden_size + query_size, selection_size),
            nn.Tanh()
        )

        self.dropout = nn.Dropout(args.dropout)

        self.sel_decoders = nn.ModuleList()
        for i in range(num_heads):
            self.sel_decoders.append(nn.Linear(selection_size, output_size))

        init_cont(self.sel_encoder, args.init_range)
        init_cont(self.sel_decoders, args.init_range)

    def flatten_parameters(self):
        self.attn.flatten_parameters()

    def forward(self, q, hs, lens, rev_idxs, hid_idxs):
        """Returns ``(logits, sent_p, word_ps)``.

        ``logits`` has shape (-1, output_size): the head outputs are
        concatenated on dim 1 and then reshaped. ``sent_p`` and ``word_ps``
        are passed through from the attention module.
        """
        # run attention over hs, condition on q
        (h, sent_p), (_, word_ps) = self.attn(q, hs, lens, rev_idxs, hid_idxs)

        # add q to h
        h = torch.cat([h, q], 1)

        h = self.sel_encoder(h)
        h = self.dropout(h)

        outs = [decoder(h) for decoder in self.sel_decoders]
        outs = torch.cat(outs, 1).view(-1, self.output_size)
        return outs, sent_p, word_ps


class SelectionModel(nn.Module):
    """Predicts the final item selection from the context and the dialogue."""
    corpus_ty = data.SentenceCorpus
    engine_ty = SelectionEngine
    def __init__(self, word_dict, item_dict, context_dict, count_dict, args):
        super(SelectionModel, self).__init__()

        self.nhid_pos = 32
        self.nhid_speaker = 32
        # utterances further back than this share the last position embedding
        self.len_cutoff = 10

        domain = get_domain(args.domain)

        self.word_dict = word_dict
        self.item_dict = item_dict
        self.context_dict = context_dict
        self.count_dict = count_dict
        self.args = args

        self.word_encoder = nn.Embedding(len(self.word_dict), args.nembed_word)
        self.pos_encoder = nn.Embedding(self.len_cutoff, self.nhid_pos)
        # the speaker embedding is looked up by a word id, hence word_dict size
        self.speaker_encoder = nn.Embedding(len(self.word_dict), self.nhid_speaker)
        self.ctx_encoder = MlpContextEncoder(len(self.context_dict), domain.input_length(),
            args.nembed_ctx, args.nhid_ctx, args.dropout, args.init_range, args.skip_values)

        self.sel_head = SelectionModule(
            query_size=args.nhid_ctx,
            value_size=args.nembed_word + self.nhid_pos + self.nhid_speaker,
            hidden_size=args.nhid_attn,
            selection_size=args.nhid_sel,
            num_heads=6,
            output_size=len(item_dict),
            args=args)

        self.dropout = nn.Dropout(args.dropout)

        # init embeddings
        self.word_encoder.weight.data.uniform_(-self.args.init_range, self.args.init_range)
        self.pos_encoder.weight.data.uniform_(-self.args.init_range, self.args.init_range)
        self.speaker_encoder.weight.data.uniform_(-self.args.init_range, self.args.init_range)

    def flatten_parameters(self):
        self.sel_head.flatten_parameters()

    def forward_inpts(self, inpts, ctx_h):
        """Embeds each utterance together with its speaker and position.

        Assumes each ``inpt`` is a (seq_len, batch) LongTensor whose first
        row is the speaker token (TODO confirm). The last utterance gets
        position 0, earlier ones count up, capped at ``len_cutoff - 1``.
        ``ctx_h`` is currently unused.
        """
        hs = []
        for i, inpt in enumerate(inpts):
            speaker_emb = self.speaker_encoder(inpt[0]).unsqueeze(0)
            inpt_emb = self.word_encoder(inpt)
            inpt_emb = self.dropout(inpt_emb)
            pos = Variable(torch.Tensor([min(self.len_cutoff, len(inpts) - i) - 1]).long())
            pos_emb = self.pos_encoder(pos).unsqueeze(0)

            # broadcast the speaker and position embeddings over every time
            # step and concatenate them with the word embeddings (dim 2)
            h = torch.cat([
                inpt_emb,
                speaker_emb.expand(inpt_emb.size(0), speaker_emb.size(1), speaker_emb.size(2)),
                pos_emb.expand(inpt_emb.size(0), inpt_emb.size(1), pos_emb.size(2))],
                2)
            hs.append(h)

        return hs

    def forward(self, inpts, lens, rev_idxs, hid_idxs, ctx):
        """Returns selection logits for the whole dialogue."""
        ctx_h = self.ctx_encoder(ctx)
        ctx_h = self.dropout(ctx_h)
        hs = self.forward_inpts(inpts, ctx_h)
        sel, _, _ = self.sel_head(ctx_h, hs, lens, rev_idxs, hid_idxs)
        return sel

    def forward_each_timestamp(self, inpts, lens, rev_idxs, hid_idxs, ctx):
        """Returns a list of selection logits, one per dialogue prefix
        (the first 1, 2, ..., len(inpts) utterances)."""
        ctx_h = self.ctx_encoder(ctx)
        ctx_h = self.dropout(ctx_h)
        hs = self.forward_inpts(inpts, ctx_h)
        sels = []
        for i in range(len(hs)):
            sel, _, _ = self.sel_head(ctx_h, hs[:i + 1], lens[: i + 1],
                rev_idxs[:i + 1], hid_idxs[:i + 1])
            sels.append(sel)
        return sels


# --------------------------------------------------------------------------------
# /src/metric.py:
# --------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import OrderedDict
import numpy as np
import pdb
import time

import data


class TimeMetric(object):
    """Average wall-clock seconds per recorded event.

    ``reset()`` starts the clock; ``record(n)`` adds the time elapsed since
    the last ``reset()`` and counts ``n`` events.
    """
    def __init__(self):
        self.t = 0
        self.n = 0
        # start the clock so record() works even before the first reset()
        self.last_t = time.time()

    def reset(self):
        self.last_t = time.time()

    def record(self, n=1):
        self.t += time.time() - self.last_t
        # previously the argument was ignored and 1 was always added
        self.n += n

    def value(self):
        # guard the division without mutating the counter
        return 1.0 * self.t / max(1, self.n)

    def show(self):
        return '%.3fs' % (1. * self.value())


class NumericMetric(object):
    """Running ratio ``sum(k) / sum(n)`` over all recorded values."""
    def __init__(self):
        self.k = 0
        self.n = 0

    def reset(self):
        pass

    def record(self, k, n=1):
        self.k += k
        self.n += n

    def value(self):
        # a getter must not bump self.n: that skewed all later averages
        return 1.0 * self.k / max(1, self.n)


class PercentageMetric(NumericMetric):
    def show(self):
        return '%2.2f%%' % (100. * self.value())


class AverageMetric(NumericMetric):
    def show(self):
        return '%.2f' % (1. * self.value())


class MovingNumericMetric(object):
    """Mean of the last ``window`` recorded values (circular buffer)."""
    def __init__(self, window=100):
        self.window = window
        self.a = np.zeros(window)
        self.n = 0

    def reset(self):
        pass

    def record(self, k):
        self.a[self.n % self.window] = k
        self.n += 1

    def value(self):
        s = np.sum(self.a)
        # exactly min(window, n) slots hold recorded values; the former
        # ``self.n + 1`` divided by one too many until the buffer filled up
        n = max(1, min(self.a.size, self.n))
        return 1.0 * s / n


class MovingAverageMetric(MovingNumericMetric):
    def show(self):
        return '%.2f' % (1. * self.value())


class MovingPercentageMetric(MovingNumericMetric):
    def show(self):
        return '%2.2f%%' % (100. * self.value())


class TextMetric(object):
    """Base for metrics that compare sentences against a reference text."""
    def __init__(self, text):
        self.text = text
        self.k = 0
        self.n = 0

    def reset(self):
        pass

    def value(self):
        return 1. * self.k / max(1, self.n)

    def show(self):
        return '%.2f' % (1. * self.value())


class NGramMetric(TextMetric):
    """Fraction of n-grams (the whole sentence if ngram == -1) that occur as
    substrings of the reference text."""
    def __init__(self, text, ngram=-1):
        super(NGramMetric, self).__init__(text)
        self.ngram = ngram

    def record(self, sen):
        n = len(sen) if self.ngram == -1 else self.ngram
        for i in range(len(sen) - n + 1):
            self.n += 1
            target = ' '.join(sen[i:i + n])
            if self.text.find(target) != -1:
                self.k += 1


class UniquenessMetric(object):
    """Number of distinct sentences recorded so far."""
    def __init__(self):
        self.seen = set()

    def reset(self):
        pass

    def record(self, sen):
        self.seen.add(' '.join(sen))

    def value(self):
        return len(self.seen)

    def show(self):
        return str(self.value())


class SimilarityMetric(object):
    """Fraction of sentences that repeat one recorded since the last reset."""
    def __init__(self):
        self.reset()
        self.k = 0
        self.n = 0

    def reset(self):
        self.history = []

    def record(self, sen):
        self.n += 1
        sen = ' '.join(sen)
        if sen in self.history:
            self.k += 1
        self.history.append(sen)

    def value(self):
        return 1. * self.k / max(1, self.n)

    def show(self):
        return '%.2f' % (1. * self.value())


class MetricsContainer(object):
    """Ordered registry of named metrics; names are case-insensitive."""
    def __init__(self):
        self.metrics = OrderedDict()

    def _register(self, name, ty, *args, **kwargs):
        name = name.lower()
        assert name not in self.metrics
        self.metrics[name] = ty(*args, **kwargs)

    def register_average(self, name, *args, **kwargs):
        self._register(name, AverageMetric, *args, **kwargs)

    def register_moving_average(self, name, *args, **kwargs):
        self._register(name, MovingAverageMetric, *args, **kwargs)

    def register_time(self, name, *args, **kwargs):
        self._register(name, TimeMetric, *args, **kwargs)

    def register_percentage(self, name, *args, **kwargs):
        self._register(name, PercentageMetric, *args, **kwargs)

    def register_moving_percentage(self, name, *args, **kwargs):
        self._register(name, MovingPercentageMetric, *args, **kwargs)

    def register_ngram(self, name, *args, **kwargs):
        self._register(name, NGramMetric, *args, **kwargs)

    def register_similarity(self, name, *args, **kwargs):
        self._register(name, SimilarityMetric, *args, **kwargs)

    def register_uniqueness(self, name, *args, **kwargs):
        self._register(name, UniquenessMetric, *args, **kwargs)

    def record(self, name, *args, **kwargs):
        name = name.lower()
        assert name in self.metrics
        self.metrics[name].record(*args, **kwargs)

    def reset(self):
        for m in self.metrics.values():
            m.reset()

    def value(self, name):
        # lower-case like _register/record so lookups are consistent
        return self.metrics[name.lower()].value()

    def show(self):
        # dict.iteritems() does not exist in Python 3
        return ' '.join(['%s=%s' % (k, v.show()) for k, v in self.metrics.items()])

    def dict(self):
        d = OrderedDict()
        for k, v in self.metrics.items():
            d[k] = v.show()
        return d
# --------------------------------------------------------------------------------
# /src/selfplay.py:
# --------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import pdb
import re
import random

import numpy as np
import torch
from torch import optim
from torch import autograd
import torch.nn as nn

from agent import *
import utils
from utils import ContextGenerator
from dialog import Dialog, DialogLogger
from models.rnn_model import RnnModel
from models.latent_clustering_model import LatentClusteringPredictionModel, BaselineClusteringModel
import domain


class SelfPlay(object):
    """Runs one dialogue per context produced by ``ctx_gen``."""
    def __init__(self, dialog, ctx_gen, args, logger=None):
        self.dialog = dialog
        self.ctx_gen = ctx_gen
        self.args = args
        self.logger = logger if logger else DialogLogger()

    def run(self):
        """Plays through the contexts and logs metrics every 100 dialogues.

        With ``--smart_alice`` the run stops after 1000 dialogues.
        """
        for n, ctxs in enumerate(self.ctx_gen.iter(), 1):
            if self.args.smart_alice and n > 1000:
                break
            self.logger.dump('=' * 80)
            self.dialog.run(ctxs, self.logger)
            self.logger.dump('=' * 80)
            self.logger.dump('')
            if n % 100 == 0:
                self.logger.dump('%d: %s' % (n, self.dialog.show_metrics()), forced=True)


def get_agent_type(model, smart=False):
    """Maps a model instance to its agent class (rollout variant if ``smart``).

    Raises:
        ValueError: if the model type is not supported. (This used to be
        ``assert False``, which is stripped under ``python -O`` and made the
        function silently return None.)
    """
    if isinstance(model, LatentClusteringPredictionModel):
        if smart:
            return LatentClusteringRolloutAgent
        else:
            return LatentClusteringAgent
    elif isinstance(model, RnnModel):
        if smart:
            return RnnRolloutAgent
        else:
            return RnnAgent
    elif isinstance(model, BaselineClusteringModel):
        if smart:
            return BaselineClusteringRolloutAgent
        else:
            return BaselineClusteringAgent
    raise ValueError('unknown model type: %s' % (model))


def main():
    """Load Alice and Bob, then let them negotiate over the context file."""
    parser = argparse.ArgumentParser(description='selfplaying script')
    parser.add_argument('--alice_model_file', type=str,
        help='Alice model file')
    parser.add_argument('--alice_forward_model_file', type=str,
        help='Alice forward model file')
    parser.add_argument('--bob_model_file', type=str,
        help='Bob model file')
    parser.add_argument('--context_file', type=str,
        help='context file')
    parser.add_argument('--temperature', type=float, default=1.0,
        help='temperature')
    parser.add_argument('--pred_temperature', type=float, default=1.0,
        help='prediction temperature')
    parser.add_argument('--verbose', action='store_true', default=False,
        help='print out conversations')
    parser.add_argument('--seed', type=int, default=1,
        help='random seed')
    parser.add_argument('--score_threshold', type=int, default=6,
        help='successful dialog should have more than score_threshold in score')
    parser.add_argument('--max_turns', type=int, default=20,
        help='maximum number of turns in a dialog')
    parser.add_argument('--log_file', type=str, default='',
        help='log successful dialogs to file for training')
    parser.add_argument('--smart_alice', action='store_true', default=False,
        help='make Alice smart again')
    # the help texts below were copy-pasted from unrelated flags
    parser.add_argument('--diverse_alice', action='store_true', default=False,
        help='make Alice diverse')
    parser.add_argument('--rollout_bsz', type=int, default=3,
        help='rollout batch size')
    parser.add_argument('--rollout_count_threshold', type=int, default=3,
        help='rollout count threshold')
    parser.add_argument('--smart_bob', action='store_true', default=False,
        help='make Bob smart again')
    parser.add_argument('--selection_model_file', type=str, default='',
        help='path to the selection model')
    parser.add_argument('--rollout_model_file', type=str, default='',
        help='path to the rollout model')
    parser.add_argument('--diverse_bob', action='store_true', default=False,
        help='make Bob diverse')
    parser.add_argument('--ref_text', type=str,
        help='file with the reference text')
    parser.add_argument('--cuda', action='store_true', default=False,
        help='use CUDA')
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')
    parser.add_argument('--visual', action='store_true', default=False,
        help='plot graphs')
    parser.add_argument('--eps', type=float, default=0.0,
        help='eps greedy')
    parser.add_argument('--data', type=str, default='data/negotiate',
        help='location of the data corpus')
    parser.add_argument('--unk_threshold', type=int, default=20,
        help='minimum word frequency to be in dictionary')
    parser.add_argument('--bsz', type=int, default=16,
        help='batch size')
    parser.add_argument('--validate', action='store_true', default=False,
        help='enable validation mode')

    args = parser.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    alice_model = utils.load_model(args.alice_model_file)
    alice_ty = get_agent_type(alice_model, args.smart_alice)
    alice = alice_ty(alice_model, args, name='Alice', train=False, diverse=args.diverse_alice)
    alice.vis = args.visual

    bob_model = utils.load_model(args.bob_model_file)
    bob_ty = get_agent_type(bob_model, args.smart_bob)
    bob = bob_ty(bob_model, args, name='Bob', train=False, diverse=args.diverse_bob)

    # only Alice may plot
    bob.vis = False

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    selfplay = SelfPlay(dialog, ctx_gen, args, logger)
    selfplay.run()


if __name__ == '__main__':
    main()
# --------------------------------------------------------------------------------
# /src/avg_rank.py: --------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Script to compute average rank for both supervised models
and models with planning.
"""

import argparse
import sys
import time
import random
import itertools
import re
import pdb

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

import data
from agent import LstmAgent
from dialog import DialogLogger
import utils
from domain import get_domain


TAGS = ['YOU:', 'THEM:']

# Token that ends the spoken part of a dialogue (the turn where items are selected).
SELECTION_TOKEN = '<selection>'


def read_dataset(file_name):
    """Read the dataset and collect every unique sentence spoken by 'YOU'.

    Returns:
        dataset: list of (ctx, [(tokens, is_you), ...]) pairs, shuffled with
            the global `random` state.
        sents: the unique 'YOU' sentences, each as a list of tokens. They are
            sorted so that candidate order (and therefore the RNG consumption
            of rollout scoring) is reproducible; plain set iteration order
            depends on per-process string hash randomization.
    """
    lines = data.read_lines(file_name)
    dataset = []
    all_sents = set()

    for line in lines:
        tokens = line.split(' ')
        ctx = data.get_tag(tokens, 'input')
        sents, sent = [], []
        you = None
        for t in data.get_tag(tokens, 'dialogue'):
            if t in TAGS:
                # a speaker tag closes the previous utterance, if any
                if you is not None:
                    sents.append((sent, you))
                    if you:
                        all_sents.add(' '.join(sent))
                sent = []
                you = t == 'YOU:'
            else:
                assert you is not None
                sent.append(t)
                if t == SELECTION_TOKEN:
                    break

        # flush the trailing utterance
        if sent:
            sents.append((sent, you))
            if you:
                all_sents.add(' '.join(sent))

        dataset.append((ctx, sents))

    sents = [sent.split(' ') for sent in sorted(all_sents)]
    random.shuffle(dataset)
    return dataset, sents


def rollout(sent, ai, domain, temperature):
    """Score `sent` by the reward the agent expects after saying it.

    Runs 5 rollouts. Each one continues the conversation from the hidden
    state after `sent` (unless `sent` is itself the selection turn), lets the
    agent pick items greedily, and adds `p_agree * reward` to the score. The
    sum, not the mean, is returned; that is fine for ranking because every
    candidate gets the same number of rollouts.
    """
    enc = ai._encode(sent, ai.model.word_dict)
    _, lang_h, lang_hs = ai.model.score_sent(enc, ai.lang_h, ai.ctx_h, temperature)

    is_selection = len(sent) == 1 and sent[0] == SELECTION_TOKEN

    score = 0
    for _ in range(5):
        combined_lang_hs = ai.lang_hs + [lang_hs]
        combined_words = ai.words + [ai.model.word2var('YOU:'), Variable(enc)]

        if not is_selection:
            # sample a continuation of the conversation
            # (100 is presumably the word limit of `write` -- TODO confirm)
            _, rollout_words, _, rollout_lang_hs = ai.model.write(
                lang_h, ai.ctx_h, 100, temperature,
                stop_tokens=[SELECTION_TOKEN], resume=True)
            combined_lang_hs += [rollout_lang_hs]
            combined_words += [rollout_words]

        combined_lang_hs = torch.cat(combined_lang_hs)
        combined_words = torch.cat(combined_words)

        # choose items greedily; weight the reward by p_agree
        # (presumably the model's probability of agreement -- TODO confirm)
        rollout_choice, _, p_agree = ai._choose(combined_lang_hs, combined_words, sample=False)
        rollout_score = domain.score(ai.context, rollout_choice)
        score += p_agree * rollout_score

    return score


def likelihood(sent, ai, domain, temperature):
    """Computes likelihood of a given sentence according the giving model.

    `domain` is unused; it is accepted so this matches `rollout`'s signature.
    """
    enc = ai._encode(sent, ai.model.word_dict)
    score, _, _ = ai.model.score_sent(enc, ai.lang_h, ai.ctx_h, temperature)
    return score


def compute_rank(target, sents, ai, domain, temperature, score_func):
    """Computes rank of the target sentence.

    Every candidate in `sents` is scored with `score_func` and sorted by
    descending score. The rank is the 1-based position at which the target
    would be inserted: one plus the number of candidates whose score is
    greater than or equal to the target's score.
    """
    scores = []
    # score each unique sentence
    for sent in sents:
        score = score_func(sent, ai, domain, temperature)
        scores.append((score, sent))
    scores = sorted(scores, key=lambda x: -x[0])

    # score the target sentence
    target_score = score_func(target, ai, domain, temperature)

    # find the position of the target sentence in the sorted list of all the sentences
    for rank, (score, _) in enumerate(scores):
        if target_score > score:
            return rank + 1
    return len(scores) + 1


def _mean_rank(total, count):
    """Return total / count, or NaN when nothing has been ranked yet."""
    return 1. * total / count if count else float('nan')


def main():
    parser = argparse.ArgumentParser(description='Negotiator')
    parser.add_argument('--dataset', type=str, default='./data/negotiate/val.txt',
        help='location of the dataset')
    parser.add_argument('--model_file', type=str,
        help='model file')
    parser.add_argument('--smart_ai', action='store_true', default=False,
        help='to use rollouts')
    parser.add_argument('--seed', type=int, default=1,
        help='random seed')
    parser.add_argument('--temperature', type=float, default=1.0,
        help='temperature')
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')
    parser.add_argument('--log_file', type=str, default='',
        help='log file')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    model = utils.load_model(args.model_file)
    ai = LstmAgent(model, args)
    logger = DialogLogger(verbose=True, log_file=args.log_file)
    domain = get_domain(args.domain)

    score_func = rollout if args.smart_ai else likelihood

    dataset, sents = read_dataset(args.dataset)
    ranks, n, k = 0, 0, 0
    for ctx, dialog in dataset:
        start_time = time.time()
        # start new conversation
        ai.feed_context(ctx)
        for sent, you in dialog:
            if you:
                # if it is your turn to say, take the target sentence and compute its rank
                rank = compute_rank(sent, sents, ai, domain, args.temperature, score_func)
                # compute lang_h for the groundtruth sentence
                enc = ai._encode(sent, ai.model.word_dict)
                _, ai.lang_h, lang_hs = ai.model.score_sent(enc, ai.lang_h, ai.ctx_h, args.temperature)
                # save hidden states and the utterance
                ai.lang_hs.append(lang_hs)
                ai.words.append(ai.model.word2var('YOU:'))
                ai.words.append(Variable(enc))
                ranks += rank
                n += 1
            else:
                ai.read(sent)
        k += 1
        time_elapsed = time.time() - start_time
        # n can still be 0 if no 'YOU' turn has been seen yet
        logger.dump('dialogue %d | avg rank %.3f | raw %d/%d | time %.3f' % (
            k, _mean_rank(ranks, n), ranks, n, time_elapsed))

    logger.dump('final avg rank %.3f' % _mean_rank(ranks, n))


if __name__ == '__main__':
    main()
# /src/domain.py: --------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import itertools
import re
from collections import OrderedDict


def get_domain(name):
    """Return a new Domain instance for `name`.

    Raises:
        ValueError: if `name` is neither 'object_division' nor 'trade'.
    """
    if name == 'object_division':
        return ObjectDivisionDomain()
    if name == 'trade':
        return ObjectTradeDomain()
    # previously `raise()`, which fails with an unrelated TypeError
    raise ValueError('unknown domain: %r' % (name,))


class Domain(object):
    """ Domain interface. """
    def selection_length(self):
        pass

    def input_length(self):
        pass

    def generate_choices(self, input):
        pass

    def parse_context(self, ctx):
        pass

    def score(self, context, choice):
        pass

    def parse_choice(self, choice):
        pass

    def parse_human_choice(self, input, output):
        pass

    def score_choices(self, choices, ctxs):
        pass


class ObjectDivisionDomain(Domain):
    """Two agents split a shared pool of items.

    A context is a flat token list `[count0, value0, count1, value1, ...]`.
    A selection is `selection_length()` tokens 'item<i>=<count>': this agent's
    counts followed by the partner's counts.
    """

    # First-selection tokens that mean no deal was reached; such choices score 0.
    NO_DEAL_TOKENS = ('<no_agreement>', '<disconnect>', '<disagree>')

    def __init__(self):
        # The count group must contain the repetition; the old pattern
        # '(-?[0-9])+' captured only the last digit ('item0=10' -> 0).
        self.item_pattern = re.compile(r'^item([0-9])=(-?[0-9]+)$')

    def selection_length(self):
        return 6

    def input_length(self):
        return 3

    def num_choices(self):
        # NOTE(review): `idx2sel` is never assigned in this class, so this
        # raises AttributeError unless some caller sets it -- confirm.
        return len(self.idx2sel)

    def generate_choices(self, input, with_disagreement=True):
        """Enumerate every possible split of the items in `input`.

        Each choice is this agent's 'item<i>=<c>' tokens followed by the
        partner's 'item<i>=<n - c>' tokens. The first item varies slowest.
        If `with_disagreement`, a '<no_agreement>' choice and a
        '<disconnect>' choice are appended.
        """
        cnts, _ = self.parse_context(input)
        choices = []
        for split in itertools.product(*[range(n + 1) for n in cnts]):
            left_choice = ['item%d=%d' % (i, c) for i, c in enumerate(split)]
            right_choice = ['item%d=%d' % (i, n - c) for i, (n, c) in enumerate(zip(cnts, split))]
            choices.append(left_choice + right_choice)
        if with_disagreement:
            choices.append(['<no_agreement>'] * self.selection_length())
            choices.append(['<disconnect>'] * self.selection_length())
        return choices

    def parse_context(self, ctx):
        """Split a context into (counts, values), both lists of ints."""
        cnts = [int(n) for n in ctx[0::2]]
        vals = [int(v) for v in ctx[1::2]]
        return cnts, vals

    def score(self, context, choice):
        """Value of this agent's half of `choice` under `context` (0 if no deal)."""
        assert len(choice) == (self.selection_length())
        # only the first half describes this agent's items
        choice = choice[0:len(choice) // 2]
        if choice[0] in self.NO_DEAL_TOKENS:
            return 0
        _, vals = self.parse_context(context)
        score = 0
        for i, (c, v) in enumerate(zip(choice, vals)):
            idx, cnt = self.parse_choice(c)
            # Verify that the item idx is correct
            assert idx == i
            score += cnt * v
        return score

    def parse_choice(self, choice):
        """Parse 'item<i>=<count>' into (i, count); asserts on malformed tokens."""
        match = self.item_pattern.match(choice)
        assert match is not None, 'choice %s' % choice
        # Returns item idx and its count
        return (int(match.groups()[0]), int(match.groups()[1]))

    def parse_human_choice(self, input, output):
        """Parse a typed choice such as '1 0 2' into selection tokens.

        Raises:
            ValueError: if a field is not an int, the number of fields differs
                from the number of items, or a count is outside [0, available].
        """
        cnts = self.parse_context(input)[0]
        choice = [int(x) for x in output.strip().split()]

        # a bare `raise` here used to produce RuntimeError("No active exception")
        if len(choice) != len(cnts):
            raise ValueError('expected %d counts, got %d' % (len(cnts), len(choice)))
        for x, n in zip(choice, cnts):
            if x < 0 or x > n:
                raise ValueError('count %d out of range [0, %d]' % (x, n))
        return ['item%d=%d' % (i, x) for i, x in enumerate(choice)]

    def _to_int(self, x):
        """int(x), or 0 when x is not an integer literal."""
        try:
            return int(x)
        except (TypeError, ValueError):
            return 0

    def _parse_count(self, token):
        """Count after '=' in `token` ('item0=2' -> 2); 0 for tokens without a count."""
        return self._to_int(token[token.find('=') + 1:])

    def score_choices(self, choices, ctxs):
        """Check whether the agents' choices agree and compute their rewards.

        Returns:
            (agree, scores): agree is True iff, for every item, the counts
            taken by all agents add up to the item count. scores[k] is the value
            of agent k's items under its own context; it is computed even when
            agree is False.
        """
        assert len(choices) == len(ctxs)
        cnts = [int(x) for x in ctxs[0][0::2]]
        agree, scores = True, [0 for _ in range(len(ctxs))]
        for i, n in enumerate(cnts):
            for agent_id, (choice, ctx) in enumerate(zip(choices, ctxs)):
                # parse the whole number after '=' (previously only the last char)
                taken = self._parse_count(choice[i])
                n -= taken
                scores[agent_id] += int(ctx[2 * i + 1]) * taken
            agree = agree and (n == 0)
        return agree, scores


class ObjectTradeDomain(ObjectDivisionDomain):
    """Trading variant: a choice is one signed count per item.

    Counts range from -count up to `max_items`. Choices agree when, for every
    item, the agents' counts sum to zero.
    """

    def __init__(self, max_items=1):
        super(ObjectTradeDomain, self).__init__()
        self.max_items = max_items

    def selection_length(self):
        return 3

    def input_length(self):
        return 3

    def generate_choices(self, input, with_disagreement=True):
        """Enumerate every signed choice in [-count, max_items] per item.

        `with_disagreement` matches the parent signature. When True (the
        previous, unconditional behaviour) the two no-deal choices are appended.
        """
        cnts, _ = self.parse_context(input)
        ranges = [range(-n, self.max_items + 1) for n in cnts]
        choices = [['item%d=%d' % (i, c) for i, c in enumerate(split)]
                   for split in itertools.product(*ranges)]
        if with_disagreement:
            choices.append(['<no_agreement>'] * self.selection_length())
            choices.append(['<disconnect>'] * self.selection_length())
        return choices

    def score_choices(self, choices, ctxs):
        """Like the parent method, except agreement means per-item counts sum to 0."""
        assert len(choices) == len(ctxs)
        cnts = [int(x) for x in ctxs[0][0::2]]
        agree, scores = True, [0 for _ in range(len(ctxs))]
        for i in range(len(cnts)):
            n = 0
            for agent_id, (choice, ctx) in enumerate(zip(choices, ctxs)):
                taken = self._parse_count(choice[i])
                n += taken
                scores[agent_id] += int(ctx[2 * i + 1]) * taken
            agree = agree and (n == 0)
        return agree, scores

    def score(self, context, choice):
        """Value of `choice` under `context` (0 for any no-deal choice)."""
        assert len(choice) == (self.selection_length())
        # generate_choices also emits '<disconnect>', so check every no-deal token,
        # not only '<no_agreement>' (the others would fail parse_choice)
        if choice[0] in self.NO_DEAL_TOKENS:
            return 0
        _, vals = self.parse_context(context)
        score = 0
        for i, (c, v) in enumerate(zip(choice, vals)):
            idx, cnt = self.parse_choice(c)
            # Verify that the item idx is correct
            assert idx == i
            score += cnt * v
        return score

    def parse_human_choice(self, input, output):
        """Parse a typed signed choice such as '-1 0 1' into selection tokens.

        Raises:
            ValueError: on malformed input, a wrong number of fields, or a
                count outside [-available, 4].
        """
        cnts = self.parse_context(input)[0]
        choice = [int(x) for x in output.strip().split()]

        if len(choice) != len(cnts):
            raise ValueError('expected %d counts, got %d' % (len(cnts), len(choice)))
        for x, n in zip(choice, cnts):
            # NOTE(review): the upper bound 4 is hard-coded, unlike max_items
            # used by generate_choices -- confirm this is intended.
            if x < -n or x > 4:
                raise ValueError('count %d out of range [%d, 4]' % (x, -n))
        return ['item%d=%d' % (i, x) for i, x in enumerate(choice)]
# /src/engines/__init__.py: --------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import random
import time
import itertools
import sys
import copy
import re

import torch
from torch import optim
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

import vis


class Criterion(object):
    """Weighted CrossEntropyLoss.

    Every token in `dictionary` gets weight 1, except those in `bad_toks`,
    which get weight 0 and so add nothing to the loss.
    """
    def __init__(self, dictionary, device_id=None, bad_toks=(), reduction='mean'):
        w = torch.Tensor(len(dictionary)).fill_(1)
        for tok in bad_toks:
            w[dictionary.get_idx(tok)] = 0.0
        if device_id is not None:
            w = w.cuda(device_id)
        # https://pytorch.org/docs/stable/nn.html
        self.crit = nn.CrossEntropyLoss(w, reduction=reduction)

    def __call__(self, out, tgt):
        return self.crit(out, tgt)


class EngineBase(object):
    """Base class for training engine.

    Subclasses implement `train_batch` (returning a float loss) and
    `valid_batch` (returning (valid_loss, select_loss, partner_ctx_loss)).
    """
    def __init__(self, model, args, verbose=False):
        self.model = model
        self.args = args
        self.verbose = verbose
        self.opt = self.make_opt(self.args.lr)
        self.crit = Criterion(self.model.word_dict)
        # failed-negotiation outputs carry no selection signal; zero-weight them
        self.sel_crit = Criterion(
            self.model.item_dict, bad_toks=['<disconnect>', '<disagree>'])
        if self.args.visual:
            self.model_plot = vis.ModulePlot(self.model, plot_weight=True, plot_grad=False)
            self.loss_plot = vis.Plot(['train', 'valid', 'valid_select'],
                'loss', 'loss', 'epoch', running_n=1, write_to_file=False)
            self.ppl_plot = vis.Plot(['train', 'valid', 'valid_select'],
                'perplexity', 'ppl', 'epoch', running_n=1, write_to_file=False)

    def make_opt(self, lr):
        """Build a fresh RMSprop optimizer over the CURRENT self.model's parameters."""
        return optim.RMSprop(
            self.model.parameters(),
            lr=lr,
            momentum=self.args.momentum)

    def get_model(self):
        return self.model

    def train_batch(self, batch):
        pass

    def valid_batch(self, batch):
        pass

    def train_pass(self, trainset):
        """One epoch over `trainset`; returns (mean batch loss, seconds elapsed)."""
        self.model.train()

        total_loss = 0
        start_time = time.time()

        for batch in trainset:
            self.t += 1
            loss = self.train_batch(batch)

            if self.args.visual and self.t % 100 == 0:
                self.model_plot.update(self.t)

            total_loss += loss

        total_loss /= len(trainset)
        time_elapsed = time.time() - start_time
        return total_loss, time_elapsed

    def valid_pass(self, validset, validset_stats):
        """Evaluate on `validset`.

        Returns:
            (valid_loss, select_loss, partner_ctx_loss, extra). valid_loss is
            normalised by validset_stats['nonpadn']; the other two losses are
            averaged per batch; extra is an empty dict here.
        """
        self.model.eval()

        total_valid_loss, total_select_loss, total_partner_ctx_loss = 0, 0, 0
        for batch in validset:
            valid_loss, select_loss, partner_ctx_loss = self.valid_batch(batch)
            total_valid_loss += valid_loss
            total_select_loss += select_loss
            total_partner_ctx_loss += partner_ctx_loss

        # Dividing by the number of words in the input, not the tokens modeled,
        # because the latter includes padding
        total_valid_loss /= validset_stats['nonpadn']
        total_select_loss /= len(validset)
        total_partner_ctx_loss /= len(validset)
        return total_valid_loss, total_select_loss, total_partner_ctx_loss, {}

    def iter(self, epoch, lr, traindata, validdata):
        """Run one train + validation epoch, then log and plot the results."""
        trainset, _ = traindata
        validset, validset_stats = validdata

        train_loss, train_time = self.train_pass(trainset)
        valid_loss, valid_select_loss, valid_partner_ctx_loss, extra = \
            self.valid_pass(validset, validset_stats)

        if self.verbose:
            print('| epoch %03d | trainloss %.3f | trainppl %.3f | s/epoch %.2f | lr %0.8f' % (
                epoch, train_loss, np.exp(train_loss), train_time, lr))
            print('| epoch %03d | validloss %.3f | validppl %.3f' % (
                epoch, valid_loss, np.exp(valid_loss)))
            print('| epoch %03d | validselectloss %.3f | validselectppl %.3f' % (
                epoch, valid_select_loss, np.exp(valid_select_loss)))
            if self.model.args.partner_ctx_weight != 0:
                print('| epoch %03d | validpartnerctxloss %.3f | validpartnerctxppl %.3f' % (
                    epoch, valid_partner_ctx_loss, np.exp(valid_partner_ctx_loss)))

        if self.args.visual:
            self.loss_plot.update('train', epoch, train_loss)
            self.loss_plot.update('valid', epoch, valid_loss)
            self.loss_plot.update('valid_select', epoch, valid_select_loss)
            self.ppl_plot.update('train', epoch, np.exp(train_loss))
            self.ppl_plot.update('valid', epoch, np.exp(valid_loss))
            self.ppl_plot.update('valid_select', epoch, np.exp(valid_select_loss))

        return train_loss, valid_loss, valid_select_loss, extra

    def combine_loss(self, lang_loss, select_loss):
        return lang_loss + select_loss

    def train(self, corpus):
        """Train with model selection, then anneal the learning rate.

        Phase 1 trains for `args.max_epoch` epochs at `args.lr` and keeps a
        copy of the model with the lowest combined validation loss. Phase 2
        resumes from that copy (epochs up to 99). Every `args.decay_every`
        epochs it divides the learning rate by `args.decay_rate`, stopping
        once the rate drops below `args.min_lr`.

        Returns:
            (train_loss, valid_loss, valid_select_loss, extra) from the last
            epoch that ran; (None, None, None, {}) if no epoch ran.
        """
        best_model, best_combined_valid_loss = copy.deepcopy(self.model), 1e100
        lr = self.args.lr
        last_decay_epoch = 0
        self.t = 0
        # Defined up front: phase 1 may not run (max_epoch == 0), and phase 2
        # can `break` before its first epoch; the return below still needs these.
        train_loss, valid_loss, valid_select_loss, extra = None, None, None, {}

        validdata = corpus.valid_dataset(self.args.bsz)
        for epoch in range(1, self.args.max_epoch + 1):
            traindata = corpus.train_dataset(self.args.bsz)
            train_loss, valid_loss, valid_select_loss, extra = self.iter(epoch, lr, traindata, validdata)

            combined_valid_loss = self.combine_loss(valid_loss, valid_select_loss)
            if combined_valid_loss < best_combined_valid_loss:
                best_combined_valid_loss = combined_valid_loss
                best_model = copy.deepcopy(self.model)
                best_model.flatten_parameters()

        if self.verbose:
            print('| start annealing | best combined loss %.3f | best combined ppl %.3f' % (
                best_combined_valid_loss, np.exp(best_combined_valid_loss)))

        self.model = best_model
        # The existing optimizer still references the parameters of the model
        # just replaced (best_model is a deep copy). Without rebuilding it,
        # annealing epochs before the first decay would step the wrong tensors.
        # NOTE(review): self.model_plot (if visual) still tracks the old model.
        self.opt = self.make_opt(lr)
        for epoch in range(self.args.max_epoch + 1, 100):
            if epoch - last_decay_epoch >= self.args.decay_every:
                last_decay_epoch = epoch
                lr /= self.args.decay_rate
                if lr < self.args.min_lr:
                    break
                self.opt = self.make_opt(lr)

            traindata = corpus.train_dataset(self.args.bsz)
            train_loss, valid_loss, valid_select_loss, extra = self.iter(
                epoch, lr, traindata, validdata)

        return train_loss, valid_loss, valid_select_loss, extra
# /src/reinforce.py: --------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import pdb
import random
import re
import time

import numpy as np
import torch
from torch import optim
from torch import autograd
import torch.nn as nn

import data
import utils
from utils import ContextGenerator
from agent import RnnAgent, RnnRolloutAgent, RlAgent, HierarchicalAgent
from dialog import Dialog, DialogLogger
from selfplay import get_agent_type
from domain import get_domain


class Reinforce(object):
    """Runs self-play dialogues, optionally mixed with supervised updates."""
    def __init__(self, dialog, ctx_gen, args, engine, corpus, logger=None):
        self.dialog = dialog
        self.ctx_gen = ctx_gen
        self.args = args
        self.engine = engine
        self.corpus = corpus
        self.logger = logger if logger else DialogLogger()

    def run(self):
        """Play one dialogue per context, then report final losses and metrics.

        Every `args.sv_train_freq` dialogues (when > 0), one random supervised
        batch is trained on with `self.engine`.
        """
        validset, validset_stats = self.corpus.valid_dataset(self.args.bsz)
        trainset, trainset_stats = self.corpus.train_dataset(self.args.bsz)

        n = 0
        for ctxs in self.ctx_gen.iter(self.args.nepoch):
            n += 1
            if self.args.sv_train_freq > 0 and n % self.args.sv_train_freq == 0:
                batch = random.choice(trainset)
                self.engine.model.train()
                self.engine.train_batch(batch)
                self.engine.model.eval()

            self.logger.dump('=' * 80)
            self.dialog.run(ctxs, self.logger)
            self.logger.dump('=' * 80)
            self.logger.dump('')
            if n % 100 == 0:
                self.logger.dump('%d: %s' % (n, self.dialog.show_metrics()), forced=True)

        def dump_stats(dataset, stats, name):
            # EngineBase.valid_pass(validset, validset_stats) returns
            # (loss, select_loss, partner_ctx_loss, extra). The old call passed
            # an undefined `N` (NameError) and unpacked only two values.
            # NOTE(review): assumes engine_ty derives from EngineBase -- confirm.
            loss, select_loss = self.engine.valid_pass(dataset, stats)[:2]
            self.logger.dump('final: %s_loss %.3f %s_ppl %.3f' % (
                name, float(loss), name, np.exp(float(loss))),
                forced=True)
            self.logger.dump('final: %s_select_loss %.3f %s_select_ppl %.3f' % (
                name, float(select_loss), name, np.exp(float(select_loss))),
                forced=True)

        dump_stats(trainset, trainset_stats, 'train')
        dump_stats(validset, validset_stats, 'valid')

        self.logger.dump('final: %s' % self.dialog.show_metrics(), forced=True)


def main():
    parser = argparse.ArgumentParser(description='Reinforce')
    parser.add_argument('--alice_model_file', type=str,
        help='Alice model file')
    parser.add_argument('--bob_model_file', type=str,
        help='Bob model file')
    parser.add_argument('--output_model_file', type=str,
        help='output model file')
    parser.add_argument('--context_file', type=str,
        help='context file')
    parser.add_argument('--temperature', type=float, default=1.0,
        help='temperature')
    parser.add_argument('--pred_temperature', type=float, default=1.0,
        help='temperature')
    parser.add_argument('--cuda', action='store_true', default=False,
        help='use CUDA')
    parser.add_argument('--verbose', action='store_true', default=False,
        help='print out converations')
    parser.add_argument('--seed', type=int, default=1,
        help='random seed')
    parser.add_argument('--score_threshold', type=int, default=6,
        help='successful dialog should have more than score_threshold in score')
    parser.add_argument('--log_file', type=str, default='',
        help='log successful dialogs to file for training')
    parser.add_argument('--smart_bob', action='store_true', default=False,
        help='make Bob smart again')
    parser.add_argument('--gamma', type=float, default=0.99,
        help='discount factor')
    parser.add_argument('--eps', type=float, default=0.5,
        help='eps greedy')
    parser.add_argument('--momentum', type=float, default=0.1,
        help='momentum for sgd')
    parser.add_argument('--lr', type=float, default=0.1,
        help='learning rate')
    parser.add_argument('--clip', type=float, default=0.1,
        help='gradient clip')
    parser.add_argument('--rl_lr', type=float, default=0.002,
        help='RL learning rate')
    parser.add_argument('--rl_clip', type=float, default=2.0,
        help='RL gradient clip')
    parser.add_argument('--ref_text', type=str,
        help='file with the reference text')
    parser.add_argument('--sv_train_freq', type=int, default=-1,
        help='supervision train frequency')
    parser.add_argument('--nepoch', type=int, default=1,
        help='number of epochs')
    parser.add_argument('--hierarchical', action='store_true', default=False,
        help='use hierarchical training')
    parser.add_argument('--visual', action='store_true', default=False,
        help='plot graphs')
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')
    parser.add_argument('--selection_model_file', type=str, default='',
        help='path to save the final model')
    parser.add_argument('--data', type=str, default='data/negotiate',
        help='location of the data corpus')
    parser.add_argument('--unk_threshold', type=int, default=20,
        help='minimum word frequency to be in dictionary')
    parser.add_argument('--bsz', type=int, default=16,
        help='batch size')
    parser.add_argument('--validate', action='store_true', default=False,
        help='plot graphs')
    parser.add_argument('--scratch', action='store_true', default=False,
        help='erase prediciton weights')
    parser.add_argument('--sep_sel', action='store_true', default=False,
        help='use separate classifiers for selection')

    args = parser.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    alice_model = utils.load_model(args.alice_model_file)
    alice_ty = get_agent_type(alice_model)
    alice = alice_ty(alice_model, args, name='Alice', train=True)
    alice.vis = args.visual

    bob_model = utils.load_model(args.bob_model_file)
    bob_ty = get_agent_type(bob_model)
    bob = bob_ty(bob_model, args, name='Bob', train=False)

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    domain = get_domain(args.domain)
    corpus = alice_model.corpus_ty(domain, args.data, freq_cutoff=args.unk_threshold,
        verbose=True, sep_sel=args.sep_sel)
    engine = alice_model.engine_ty(alice_model, args)

    reinforce = Reinforce(dialog, ctx_gen, args, engine, corpus, logger)
    reinforce.run()

    utils.save_model(alice.model, args.output_model_file)


if __name__ == '__main__':
    main()
# /src/dialog.py: --------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
6 | 7 | import sys 8 | import pdb 9 | 10 | import numpy as np 11 | 12 | from metric import MetricsContainer 13 | import data 14 | import utils 15 | import domain 16 | 17 | 18 | class DialogLogger(object): 19 | CODE2ITEM = [ 20 | ('item0', 'book'), 21 | ('item1', 'hat'), 22 | ('item2', 'ball'), 23 | ] 24 | 25 | def __init__(self, verbose=False, log_file=None, append=False): 26 | self.logs = [] 27 | if verbose: 28 | self.logs.append(sys.stderr) 29 | if log_file: 30 | flags = 'a' if append else 'w' 31 | self.logs.append(open(log_file, flags)) 32 | 33 | def _dump(self, s, forced=False): 34 | for log in self.logs: 35 | print(s, file=log) 36 | log.flush() 37 | if forced: 38 | print(s, file=sys.stdout) 39 | sys.stdout.flush() 40 | 41 | def _dump_with_name(self, name, s): 42 | self._dump('{0: <5} : {1}'.format(name, s)) 43 | 44 | def dump_ctx(self, name, ctx): 45 | assert len(ctx) == 6, 'we expect 3 objects' 46 | s = ' '.join(['%s=(count:%s value:%s)' % (self.CODE2ITEM[i][1], ctx[2 * i], ctx[2 * i + 1]) \ 47 | for i in range(3)]) 48 | self._dump_with_name(name, s) 49 | 50 | def dump_sent(self, name, sent): 51 | self._dump_with_name(name, ' '.join(sent)) 52 | 53 | def dump_choice(self, name, choice): 54 | def rep(w): 55 | p = w.split('=') 56 | if len(p) == 2: 57 | for k, v in self.CODE2ITEM: 58 | if p[0] == k: 59 | return '%s=%s' % (v, p[1]) 60 | return w 61 | 62 | self._dump_with_name(name, ' '.join([rep(c) for c in choice])) 63 | 64 | def dump_agreement(self, agree): 65 | self._dump('Agreement!' 
if agree else 'Disagreement?!') 66 | 67 | def dump_reward(self, name, agree, reward): 68 | if agree: 69 | self._dump_with_name(name, '%d points' % reward) 70 | else: 71 | self._dump_with_name(name, '0 (potential %d)' % reward) 72 | 73 | def dump(self, s, forced=False): 74 | self._dump(s, forced=forced) 75 | 76 | 77 | class DialogSelfTrainLogger(DialogLogger): 78 | def __init__(self, verbose=False, log_file=None): 79 | super(DialogSelfTrainLogger, self).__init__(verbose, log_file) 80 | self.name2example = {} 81 | self.name2choice = {} 82 | 83 | def _dump_with_name(self, name, sent): 84 | for n in self.name2example: 85 | if n == name: 86 | self.name2example[n] += " YOU: " 87 | else: 88 | self.name2example[n] += " THEM: " 89 | 90 | self.name2example[n] += sent 91 | 92 | def dump_ctx(self, name, ctx): 93 | self.name2example[name] = ' '.join(ctx) 94 | 95 | def dump_choice(self, name, choice): 96 | self.name2choice[name] = ' '.join(choice) 97 | 98 | def dump_agreement(self, agree): 99 | if agree: 100 | for name in self.name2example: 101 | for other_name in self.name2example: 102 | if name != other_name: 103 | self.name2example[name] += ' ' + self.name2choice[name] 104 | self.name2example[name] += ' ' + self.name2choice[other_name] 105 | self._dump(self.name2example[name]) 106 | 107 | def dump_reward(self, name, agree, reward): 108 | pass 109 | 110 | 111 | class Dialog(object): 112 | def __init__(self, agents, args): 113 | # For now we only suppport dialog of 2 agents 114 | assert len(agents) == 2 115 | self.agents = agents 116 | self.args = args 117 | self.domain = domain.get_domain(args.domain) 118 | self.metrics = MetricsContainer() 119 | self._register_metrics() 120 | 121 | def _register_metrics(self): 122 | self.metrics.register_average('dialog_len') 123 | self.metrics.register_average('sent_len') 124 | self.metrics.register_percentage('agree') 125 | self.metrics.register_moving_percentage('moving_agree') 126 | self.metrics.register_average('advantage') 127 | 
self.metrics.register_moving_average('moving_advantage') 128 | self.metrics.register_time('time') 129 | self.metrics.register_average('comb_rew') 130 | self.metrics.register_average('agree_comb_rew') 131 | for agent in self.agents: 132 | self.metrics.register_average('%s_rew' % agent.name) 133 | self.metrics.register_moving_average('%s_moving_rew' % agent.name) 134 | self.metrics.register_average('agree_%s_rew' % agent.name) 135 | self.metrics.register_percentage('%s_sel' % agent.name) 136 | self.metrics.register_uniqueness('%s_unique' % agent.name) 137 | # text metrics 138 | if self.args.ref_text: 139 | ref_text = ' '.join(data.read_lines(self.args.ref_text)) 140 | self.metrics.register_ngram('full_match', text=ref_text) 141 | 142 | def _is_selection(self, out): 143 | return len(out) == 1 and (out[0] in ['', '']) 144 | 145 | def show_metrics(self): 146 | return ' '.join(['%s=%s' % (k, v) for k, v in self.metrics.dict().items()]) 147 | 148 | def run(self, ctxs, logger, max_words=5000): 149 | assert len(self.agents) == len(ctxs) 150 | for agent, ctx, partner_ctx in zip(self.agents, ctxs, reversed(ctxs)): 151 | agent.feed_context(ctx) 152 | agent.feed_partner_context(partner_ctx) 153 | logger.dump_ctx(agent.name, ctx) 154 | logger.dump('-' * 80) 155 | 156 | # Choose who goes first by random 157 | if np.random.rand() < 0.5: 158 | writer, reader = self.agents 159 | else: 160 | reader, writer = self.agents 161 | 162 | conv = [] 163 | self.metrics.reset() 164 | 165 | #words_left = np.random.randint(50, 200) 166 | words_left = max_words 167 | length = 0 168 | expired = False 169 | 170 | while True: 171 | out = writer.write(max_words=words_left) 172 | words_left -= len(out) 173 | length += len(out) 174 | 175 | self.metrics.record('sent_len', len(out)) 176 | if 'full_match' in self.metrics.metrics: 177 | self.metrics.record('full_match', out) 178 | self.metrics.record('%s_unique' % writer.name, out) 179 | 180 | conv.append(out) 181 | reader.read(out) 182 | if not 
writer.human: 183 | logger.dump_sent(writer.name, out) 184 | 185 | if self._is_selection(out): 186 | self.metrics.record('%s_sel' % writer.name, 1) 187 | self.metrics.record('%s_sel' % reader.name, 0) 188 | break 189 | 190 | if words_left <= 1: 191 | break 192 | 193 | writer, reader = reader, writer 194 | 195 | 196 | choices = [] 197 | for agent in self.agents: 198 | choice = agent.choose() 199 | choices.append(choice) 200 | logger.dump_choice(agent.name, choice[: self.domain.selection_length() // 2]) 201 | 202 | agree, rewards = self.domain.score_choices(choices, ctxs) 203 | if expired: 204 | agree = False 205 | logger.dump('-' * 80) 206 | logger.dump_agreement(agree) 207 | for i, (agent, reward) in enumerate(zip(self.agents, rewards)): 208 | logger.dump_reward(agent.name, agree, reward) 209 | j = 1 if i == 0 else 0 210 | agent.update(agree, reward, choice=choices[i], 211 | partner_choice=choices[j], partner_input=ctxs[j], max_partner_reward=rewards[j]) 212 | 213 | if agree: 214 | self.metrics.record('advantage', rewards[0] - rewards[1]) 215 | self.metrics.record('moving_advantage', rewards[0] - rewards[1]) 216 | self.metrics.record('agree_comb_rew', np.sum(rewards)) 217 | for agent, reward in zip(self.agents, rewards): 218 | self.metrics.record('agree_%s_rew' % agent.name, reward) 219 | 220 | self.metrics.record('time') 221 | self.metrics.record('dialog_len', len(conv)) 222 | self.metrics.record('agree', int(agree)) 223 | self.metrics.record('moving_agree', int(agree)) 224 | self.metrics.record('comb_rew', np.sum(rewards) if agree else 0) 225 | for agent, reward in zip(self.agents, rewards): 226 | self.metrics.record('%s_rew' % agent.name, reward if agree else 0) 227 | self.metrics.record('%s_moving_rew' % agent.name, reward if agree else 0) 228 | 229 | logger.dump('-' * 80) 230 | logger.dump(self.show_metrics()) 231 | logger.dump('-' * 80) 232 | for ctx, choice in zip(ctxs, choices): 233 | logger.dump('debug: %s %s' % (' '.join(ctx), ' '.join(choice))) 234 | 235 
# | return conv, agree, rewards
# 236 |
# --------------------------------------------------------------------------------
# /src/engines/engine.py:
# --------------------------------------------------------------------------------
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Training utilities.
"""

import argparse
import random
import pdb
import time
import itertools
import sys
import copy
import re
import logging
import torch
from torch import optim
import torch.nn as nn
from torch.autograd import Variable
import numpy as np

from data import STOP_TOKENS
import vis


class Criterion(object):
    """Weighted CrossEntropyLoss.

    Every entry of `dictionary` gets class weight 1, except the tokens listed
    in `bad_toks`, which get weight 0 and therefore contribute nothing to the
    loss.
    """

    def __init__(self, dictionary, device_id=None, bad_toks=(), reduction='mean'):
        """Builds the weighted loss.

        Args:
            dictionary: vocabulary supporting `len()` and `get_idx(tok)`.
            device_id: CUDA device to move the weight vector to; None keeps it
                on the CPU.
            bad_toks: iterable of tokens whose class weight is set to 0. The
                default is an immutable tuple so no list is shared across calls.
            reduction: forwarded to `nn.CrossEntropyLoss`.
        """
        w = torch.Tensor(len(dictionary)).fill_(1)
        for tok in bad_toks:
            w[dictionary.get_idx(tok)] = 0.0
        if device_id is not None:
            w = w.cuda(device_id)
        # https://pytorch.org/docs/stable/nn.html
        self.crit = nn.CrossEntropyLoss(w, reduction=reduction)

    def __call__(self, out, tgt):
        """Returns the weighted cross-entropy of logits `out` against `tgt`."""
        return self.crit(out, tgt)


class Engine(object):
    """The training engine.

    Performs training and evaluation.
    """

    def __init__(self, model, args, device_id=None, verbose=False):
        """
        Args:
            model: model to train; must expose `word_dict`, `item_dict`,
                `args` and `parameters()`.
            args: parsed options (lr, momentum, nesterov, clip, visual, ...).
            device_id: CUDA device for the loss weights, or None for CPU.
            verbose: log per-epoch statistics when True.
        """
        self.model = model
        self.args = args
        self.device_id = device_id
        self.verbose = verbose
        # Global step counter used for periodic plotting. `train` resets it;
        # initializing it here lets `train_pass`/`train_single` run on their own.
        self.t = 0
        # Nesterov is only enabled together with a positive momentum, since
        # SGD rejects nesterov=True without momentum.
        self.opt = optim.SGD(self.model.parameters(), lr=self.args.lr,
            momentum=self.args.momentum,
            nesterov=(self.args.nesterov and self.args.momentum > 0))
        self.crit = Criterion(self.model.word_dict, device_id=device_id)
        self.sel_crit = Criterion(
            self.model.item_dict, device_id=device_id, bad_toks=['', ''])
        if self.args.visual:
            self.model_plot = vis.ModulePlot(self.model, plot_weight=False, plot_grad=True)
            self.loss_plot = vis.Plot(['train', 'valid', 'valid_select'],
                'loss', 'loss', 'epoch', running_n=1)
            self.ppl_plot = vis.Plot(['train', 'valid', 'valid_select'],
                'perplexity', 'ppl', 'epoch', running_n=1)

    @staticmethod
    def forward(model, batch, requires_grad=False):
        """A helper function to perform a forward pass on a batch.

        Declared static because it never touches an engine instance; callers
        invoke it as `Engine.forward(model, batch, ...)`.

        Args:
            model: model providing `forward_context`, `zero_hid`, `forward_lm`
                and `forward_selection`.
            batch: tuple `(ctx, inpt, tgt, sel_tgt)`.
            requires_grad: build the autograd graph when True.

        Returns:
            `(out, lang_h, tgt, sel_out, sel_tgt)`.
        """
        with torch.set_grad_enabled(requires_grad):
            # extract the batch into contxt, input, target and selection target
            ctx, inpt, tgt, sel_tgt = batch

            # create variables
            ctx = Variable(ctx)
            inpt = Variable(inpt)
            tgt = Variable(tgt)
            sel_tgt = Variable(sel_tgt)

            # get context hidden state
            ctx_h = model.forward_context(ctx)
            # create initial hidden state for the language rnn
            lang_h = model.zero_hid(ctx_h.size(1), model.args.nhid_lang)

            # perform forward for the language model
            out, lang_h = model.forward_lm(inpt, lang_h, ctx_h)
            # perform forward for the selection
            sel_out = model.forward_selection(inpt, lang_h, ctx_h)

            return out, lang_h, tgt, sel_out, sel_tgt

    def get_model(self):
        """Extracts the model."""
        return self.model

    def train_pass(self, N, trainset):
        """Runs one training epoch over `trainset`.

        Args:
            N: vocabulary size, used to flatten the LM logits.
            trainset: iterable of batches with a `len()`.

        Returns:
            `(mean_batch_loss, seconds_elapsed)`.
        """
        # make the model trainable
        self.model.train()

        total_loss = 0
        start_time = time.time()

        # training loop
        for batch in trainset:
            self.t += 1
            # forward pass
            out, hid, tgt, sel_out, sel_tgt = Engine.forward(self.model, batch, requires_grad=True)

            # compute LM loss and selection loss
            loss = self.crit(out.view(-1, N), tgt)
            loss += self.sel_crit(sel_out, sel_tgt) * self.model.args.sel_weight
            self.opt.zero_grad()
            # backward step with gradient clipping
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
            self.opt.step()

            if self.args.visual and self.t % 100 == 0:
                self.model_plot.update(self.t)

            total_loss += loss.item()

        total_loss /= len(trainset)
        time_elapsed = time.time() - start_time
        return total_loss, time_elapsed

    def train_single(self, N, trainset):
        """A helper function to train on a random batch.

        Returns:
            the combined loss tensor of that batch.
        """
        batch = random.choice(trainset)
        out, hid, tgt, sel_out, sel_tgt = Engine.forward(self.model, batch, requires_grad=True)
        loss = self.crit(out.view(-1, N), tgt) + \
            self.sel_crit(sel_out, sel_tgt) * self.model.args.sel_weight
        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
        self.opt.step()
        return loss

    def valid_pass(self, N, validset, validset_stats):
        """Validation pass.

        Returns:
            `(lm_loss_per_nonpad_word, mean_selection_loss_per_batch)`.
        """
        # put the model into the evaluation mode
        self.model.eval()

        valid_loss, select_loss = 0, 0
        for batch in validset:
            # compute forward pass
            out, hid, tgt, sel_out, sel_tgt = Engine.forward(self.model, batch, requires_grad=False)

            # evaluate LM and selection losses
            valid_loss += tgt.size(0) * self.crit(out.view(-1, N), tgt).item()
            select_loss += self.sel_crit(sel_out, sel_tgt).item()

        # dividing by the number of words in the input, not the tokens modeled,
        # because the latter includes padding
        return valid_loss / validset_stats['nonpadn'], select_loss / len(validset)

    def iter(self, N, epoch, lr, traindata, validdata):
        """Performs one iteration of the training.
        Runs one epoch on the training and validation datasets.

        Returns:
            `(train_loss, valid_loss, valid_select_loss)`.
        """
        trainset, _ = traindata
        validset, validset_stats = validdata

        train_loss, train_time = self.train_pass(N, trainset)
        valid_loss, valid_select_loss = self.valid_pass(N, validset, validset_stats)

        if self.verbose:
            logging.info('| epoch %03d | train_loss %.3f | train_ppl %.3f | s/epoch %.2f | lr %0.8f' % (
                epoch, train_loss, np.exp(train_loss), train_time, lr))
            logging.info('| epoch %03d | valid_loss %.3f | valid_ppl %.3f' % (
                epoch, valid_loss, np.exp(valid_loss)))
            logging.info('| epoch %03d | valid_select_loss %.3f | valid_select_ppl %.3f' % (
                epoch, valid_select_loss, np.exp(valid_select_loss)))

        if self.args.visual:
            self.loss_plot.update('train', epoch, train_loss)
            self.loss_plot.update('valid', epoch, valid_loss)
            self.loss_plot.update('valid_select', epoch, valid_select_loss)
            self.ppl_plot.update('train', epoch, np.exp(train_loss))
            self.ppl_plot.update('valid', epoch, np.exp(valid_loss))
            self.ppl_plot.update('valid_select', epoch, np.exp(valid_select_loss))

        return train_loss, valid_loss, valid_select_loss

    def train(self, corpus):
        """Entry point: trains on `corpus`.

        Phase 1 runs `args.max_epoch` epochs at `args.lr` and keeps a deep copy
        of the model with the lowest validation selection loss. Phase 2
        restarts from that copy and divides the learning rate by
        `args.decay_rate` every `args.decay_every` epochs. It stops once the
        rate drops below `args.min_lr` or epoch 99 has run.

        Returns:
            `(train_loss, valid_loss, valid_select_loss)` of the last epoch that
            ran. These are None if no epoch ran at all.
        """
        N = len(corpus.word_dict)
        best_model, best_valid_select_loss = None, 1e20
        lr = self.args.lr
        last_decay_epoch = 0
        self.t = 0
        # Bound up front so the return below is valid even if the annealing
        # phase breaks (or is empty) before running an epoch.
        train_loss, valid_loss, valid_select_loss = None, None, None

        validdata = corpus.valid_dataset(self.args.bsz, device_id=self.device_id)
        for epoch in range(1, self.args.max_epoch + 1):
            traindata = corpus.train_dataset(self.args.bsz, device_id=self.device_id)
            train_loss, valid_loss, valid_select_loss = self.iter(
                N, epoch, lr, traindata, validdata)

            if valid_select_loss < best_valid_select_loss:
                best_valid_select_loss = valid_select_loss
                best_model = copy.deepcopy(self.model)

        if self.verbose:
            logging.info('| start annealing | best validselectloss %.3f | best validselectppl %.3f' % (
                best_valid_select_loss, np.exp(best_valid_select_loss)))

        # Only swap when phase 1 produced a best model; otherwise keep training
        # the current one instead of replacing it with None.
        if best_model is not None:
            self.model = best_model
            # The old optimizer holds the replaced model's parameters; rebind it
            # so the annealing steps actually update the restored copy.
            self.opt = optim.SGD(self.model.parameters(), lr=lr)
        for epoch in range(self.args.max_epoch + 1, 100):
            if epoch - last_decay_epoch >= self.args.decay_every:
                last_decay_epoch = epoch
                lr /= self.args.decay_rate
                if lr < self.args.min_lr:
                    break
                self.opt = optim.SGD(self.model.parameters(), lr=lr)

            traindata = corpus.train_dataset(self.args.bsz, device_id=self.device_id)
            train_loss, valid_loss, valid_select_loss = self.iter(
                N, epoch, lr, traindata, validdata)

        return train_loss, valid_loss, valid_select_loss

# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# 1 | # Introduction
# 2 | This is a [PyTorch](http://pytorch.org/) implementation of the following research papers:
# 3 | * (1) [Hierarchical Text Generation and Planning for Strategic Dialogue](https://arxiv.org/abs/1712.05846)
# 4 | * (2) [Deal or No Deal? End-to-End Learning for Negotiation Dialogues](https://arxiv.org/abs/1706.05125)
# 5 |
# 6 |
# 7 | The code is developed by [Facebook AI Research](http://research.fb.com/category/facebook-ai-research-fair).
# 8 |
# 9 | The code trains neural networks to hold negotiations in natural language, and allows reinforcement learning self play and rollout-based planning.
10 | 11 | 12 | # Citation 13 | If you want to use this code in your research, please cite: 14 | ``` 15 | @inproceedings{DBLP:conf/icml/YaratsL18, 16 | author = {Denis Yarats and 17 | Mike Lewis}, 18 | title = {Hierarchical Text Generation and Planning for Strategic Dialogue}, 19 | booktitle = {Proceedings of the 35th International Conference on Machine Learning, 20 | {ICML} 2018, Stockholmsm{\"{a}}ssan, Stockholm, Sweden, July 21 | 10-15, 2018}, 22 | pages = {5587--5595}, 23 | year = {2018}, 24 | crossref = {DBLP:conf/icml/2018}, 25 | url = {http://proceedings.mlr.press/v80/yarats18a.html}, 26 | timestamp = {Fri, 13 Jul 2018 14:58:25 +0200}, 27 | biburl = {https://dblp.org/rec/bib/conf/icml/YaratsL18}, 28 | bibsource = {dblp computer science bibliography, https://dblp.org} 29 | } 30 | ``` 31 | 32 | 33 | # Dataset 34 | We release our dataset together with the code; you can find it under `data/negotiate`. This dataset consists of 5808 dialogues, based on 2236 unique scenarios. Take a look at §2.3 of (2) to learn about data collection. 35 | 36 | Each dialogue is converted into two training examples in the dataset, showing the complete conversation from the perspective of each agent. The perspectives differ in their input goals, output choice, and in special tokens marking whether a statement was read or written. See §3.1 of (2) for details on data representation. 37 | ``` 38 | # Perspective of Agent 1 39 | 1 4 4 1 1 2 40 | THEM: i would like 4 hats and you can have the rest . YOU: deal THEM: 41 | item0=1 item1=0 item2=1 item0=0 item1=4 item2=0 42 | 1 0 4 2 1 2 43 | 44 | # Perspective of Agent 2 45 | 1 0 4 2 1 2 46 | YOU: i would like 4 hats and you can have the rest . THEM: deal YOU: 47 | item0=0 item1=4 item2=0 item0=1 item1=0 item2=1 48 | 1 4 4 1 1 2 49 | ``` 50 | 51 | # Setup 52 | All code was developed with Python 3.0 on CentOS Linux 7, and tested on Ubuntu 16.04. In addition, we used PyTorch 1.0.0, CUDA 9.0, and Visdom 0.1.8.4.
53 | 54 | We recommend using [Anaconda](https://www.continuum.io/why-anaconda). To set up a working environment, follow the steps below: 55 | ``` 56 | # Install anaconda 57 | conda create -n py30 python=3 anaconda 58 | # Activate environment 59 | source activate py30 60 | # Install PyTorch 61 | conda install pytorch torchvision cuda90 -c pytorch 62 | # Install Visdom if you want to use visualization 63 | pip install visdom 64 | ``` 65 | 66 | # Usage 67 | ## Supervised Training 68 | 69 | ### Action Classifier 70 | We use an action classifier to compare the performance of various models. The action classifier is described in section 3 of (2). It can be trained by running the following command: 71 | ``` 72 | python train.py \ 73 | --cuda \ 74 | --bsz 16 \ 75 | --clip 2.0 \ 76 | --decay_every 1 \ 77 | --decay_rate 5.0 \ 78 | --domain object_division \ 79 | --dropout 0.1 \ 80 | --init_range 0.2 \ 81 | --lr 0.001 \ 82 | --max_epoch 7 \ 83 | --min_lr 1e-05 \ 84 | --model_type selection_model \ 85 | --momentum 0.1 \ 86 | --nembed_ctx 128 \ 87 | --nembed_word 128 \ 88 | --nhid_attn 128 \ 89 | --nhid_ctx 64 \ 90 | --nhid_lang 128 \ 91 | --nhid_sel 128 \ 92 | --nhid_strat 256 \ 93 | --unk_threshold 20 \ 94 | --skip_values \ 95 | --sep_sel \ 96 | --model_file selection_model.th 97 | ``` 98 | 99 | ### Baseline RNN Model 100 | This is the baseline RNN model that we describe in (1): 101 | ``` 102 | python train.py \ 103 | --cuda \ 104 | --bsz 16 \ 105 | --clip 0.5 \ 106 | --decay_every 1 \ 107 | --decay_rate 5.0 \ 108 | --domain object_division \ 109 | --dropout 0.1 \ 110 | --model_type rnn_model \ 111 | --init_range 0.2 \ 112 | --lr 0.001 \ 113 | --max_epoch 30 \ 114 | --min_lr 1e-07 \ 115 | --momentum 0.1 \ 116 | --nembed_ctx 64 \ 117 | --nembed_word 256 \ 118 | --nhid_attn 64 \ 119 | --nhid_ctx 64 \ 120 | --nhid_lang 128 \ 121 | --nhid_sel 128 \ 122 | --sel_weight 0.6 \ 123 | --unk_threshold 20 \ 124 | --sep_sel \ 125 | --model_file rnn_model.th 126 | ``` 127 | 128 | ###
Hierarchical Latent Model 129 | In this section, we provide guidelines on how to train the hierarchical latent model from (1). The final model requires two sub-models: the clustering model, which learns compact representations over intents; and the language model, which translates intent representations into language. Please read sections 5 and 6 of (1) for more details. 130 | 131 | **Clustering Model** 132 | ``` 133 | python train.py \ 134 | --cuda \ 135 | --bsz 16 \ 136 | --clip 2.0 \ 137 | --decay_every 1 \ 138 | --decay_rate 5.0 \ 139 | --domain object_division \ 140 | --dropout 0.2 \ 141 | --init_range 0.3 \ 142 | --lr 0.001 \ 143 | --max_epoch 15 \ 144 | --min_lr 1e-05 \ 145 | --model_type latent_clustering_model \ 146 | --momentum 0.1 \ 147 | --nembed_ctx 64 \ 148 | --nembed_word 256 \ 149 | --nhid_ctx 64 \ 150 | --nhid_lang 256 \ 151 | --nhid_sel 128 \ 152 | --nhid_strat 256 \ 153 | --unk_threshold 20 \ 154 | --num_clusters 50 \ 155 | --sep_sel \ 156 | --skip_values \ 157 | --nhid_cluster 256 \ 158 | --selection_model_file selection_model.th \ 159 | --model_file clustering_model.th 160 | ``` 161 | 162 | **Language Model** 163 | ``` 164 | python train.py \ 165 | --cuda \ 166 | --bsz 16 \ 167 | --clip 2.0 \ 168 | --decay_every 1 \ 169 | --decay_rate 5.0 \ 170 | --domain object_division \ 171 | --dropout 0.1 \ 172 | --init_range 0.2 \ 173 | --lr 0.001 \ 174 | --max_epoch 15 \ 175 | --min_lr 1e-05 \ 176 | --model_type latent_clustering_language_model \ 177 | --momentum 0.1 \ 178 | --nembed_ctx 64 \ 179 | --nembed_word 256 \ 180 | --nhid_ctx 64 \ 181 | --nhid_lang 256 \ 182 | --nhid_sel 128 \ 183 | --nhid_strat 256 \ 184 | --unk_threshold 20 \ 185 | --num_clusters 50 \ 186 | --sep_sel \ 187 | --nhid_cluster 256 \ 188 | --skip_values \ 189 | --selection_model_file selection_model.th \ 190 | --cluster_model_file clustering_model.th \ 191 | --model_file clustering_language_model.th 192 | ``` 193 | 194 | **Full Model** 195 | ``` 196 | python train.py \ 197 | --cuda \
198 | --bsz 16 \ 199 | --clip 2.0 \ 200 | --decay_every 1 \ 201 | --decay_rate 5.0 \ 202 | --domain object_division \ 203 | --dropout 0.2 \ 204 | --init_range 0.3 \ 205 | --lr 0.001 \ 206 | --max_epoch 10 \ 207 | --min_lr 1e-05 \ 208 | --model_type latent_clustering_prediction_model \ 209 | --momentum 0.2 \ 210 | --nembed_ctx 64 \ 211 | --nembed_word 256 \ 212 | --nhid_ctx 64 \ 213 | --nhid_lang 256 \ 214 | --nhid_sel 128 \ 215 | --nhid_strat 256 \ 216 | --unk_threshold 20 \ 217 | --num_clusters 50 \ 218 | --sep_sel \ 219 | --selection_model_file selection_model.th \ 220 | --lang_model_file clustering_language_model.th \ 221 | --model_file full_model.th 222 | ``` 223 | 224 | ## Selfplay 225 | If you want two pretrained models to negotiate against each other, use `selfplay.py`. For example, let's have two RNN models play against each other: 226 | ``` 227 | python selfplay.py \ 228 | --cuda \ 229 | --alice_model_file rnn_model.th \ 230 | --bob_model_file rnn_model.th \ 231 | --context_file data/negotiate/selfplay.txt \ 232 | --temperature 0.5 \ 233 | --selection_model_file selection_model.th 234 | ``` 235 | The script will output generated dialogues, as well as some statistics. For example: 236 | ``` 237 | ================================================================================ 238 | Alice : book=(count:3 value:1) hat=(count:1 value:5) ball=(count:1 value:2) 239 | Bob : book=(count:3 value:1) hat=(count:1 value:1) ball=(count:1 value:6) 240 | -------------------------------------------------------------------------------- 241 | Alice : i would like the hat and the ball . 242 | Bob : i need the ball and the hat 243 | Alice : i can give you the ball and one book . 244 | Bob : i can't make a deal without the ball 245 | Alice : okay then i will take the hat and the ball 246 | Bob : okay , that's fine .
247 | Alice : 248 | Alice : book=0 hat=1 ball=1 book=3 hat=0 ball=0 249 | Bob : book=3 hat=0 ball=0 book=0 hat=1 ball=1 250 | -------------------------------------------------------------------------------- 251 | Agreement! 252 | Alice : 7 points 253 | Bob : 3 points 254 | -------------------------------------------------------------------------------- 255 | dialog_len=4.47 sent_len=6.93 agree=86.67% advantage=3.14 time=2.069s comb_rew=10.93 alice_rew=6.93 alice_sel=60.00% alice_unique=26 bob_rew=4.00 bob_sel=40.00% bob_unique=25 full_match=0.78 256 | -------------------------------------------------------------------------------- 257 | debug: 3 1 1 5 1 2 item0=0 item1=1 item2=1 258 | debug: 3 1 1 1 1 6 item0=3 item1=0 item2=0 259 | ================================================================================ 260 | ``` 261 | 262 | ## Reinforcement Learning 263 | To fine-tune a pretrained model with RL, use the `reinforce.py` script: 264 | ``` 265 | python reinforce.py \ 266 | --cuda \ 267 | --alice_model_file rnn_model.th \ 268 | --bob_model_file rnn_model.th \ 269 | --output_model_file rnn_rl_model.th \ 270 | --context_file data/negotiate/selfplay.txt \ 271 | --temperature 0.5 \ 272 | --verbose \ 273 | --log_file rnn_rl.log \ 274 | --sv_train_freq 4 \ 275 | --nepoch 4 \ 276 | --selection_model_file selection_model.th \ 277 | --rl_lr 0.00001 \ 278 | --rl_clip 0.0001 \ 279 | --sep_sel 280 | ``` 281 | 282 | # License 283 | This project is licensed under CC-BY-NC; see the LICENSE file for details. 284 | -------------------------------------------------------------------------------- /src/models/rnn_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
import sys
import re
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init
from torch.autograd import Variable
import torch.nn.functional as F

import data
from engines.rnn_engine import RnnEngine
from domain import get_domain
from models.utils import *
from models.ctx_encoder import MlpContextEncoder


class RnnModel(nn.Module):
    """Word-level GRU negotiation model with an attention-based selection head.

    A context encoder embeds the input goals. A GRU `reader` (and a GRUCell
    `writer` that shares its weights) models the dialogue conditioned on that
    context embedding. A bidirectional GRU with attention (`sel_rnn`, `attn`,
    `sel_encoder`, `sel_decoders`) predicts the final item selection.
    """
    corpus_ty = data.WordCorpus
    engine_ty = RnnEngine

    def __init__(self, word_dict, item_dict, context_dict, count_dict, args):
        super(RnnModel, self).__init__()

        domain = get_domain(args.domain)

        self.word_dict = word_dict
        self.item_dict = item_dict
        self.context_dict = context_dict
        self.count_dict = count_dict
        self.args = args

        self.word_encoder = nn.Embedding(len(self.word_dict), args.nembed_word)
        self.word_encoder_dropout = nn.Dropout(args.dropout)

        ctx_encoder_ty = MlpContextEncoder
        self.ctx_encoder = nn.Sequential(
            ctx_encoder_ty(len(self.context_dict), domain.input_length(), args.nembed_ctx,
                args.nhid_ctx, args.dropout, args.init_range),
            nn.Dropout(args.dropout))

        self.reader = nn.GRU(args.nhid_ctx + args.nembed_word, args.nhid_lang, bias=True)
        self.reader_dropout = nn.Dropout(args.dropout)

        self.decoder = nn.Sequential(
            nn.Linear(args.nhid_lang, args.nembed_word),
            nn.Dropout(args.dropout))

        self.writer = nn.GRUCell(
            input_size=args.nhid_ctx + args.nembed_word,
            hidden_size=args.nhid_lang,
            bias=True)

        # Tie the weights of reader and writer
        self.writer.weight_ih = self.reader.weight_ih_l0
        self.writer.weight_hh = self.reader.weight_hh_l0
        self.writer.bias_ih = self.reader.bias_ih_l0
        self.writer.bias_hh = self.reader.bias_hh_l0

        self.sel_rnn = nn.GRU(
            input_size=args.nhid_lang + args.nembed_word,
            hidden_size=args.nhid_attn,
            bias=True,
            bidirectional=True)
        self.sel_dropout = nn.Dropout(args.dropout)

        self.sel_encoder = nn.Sequential(
            torch.nn.Linear(2 * args.nhid_attn + args.nhid_ctx, args.nhid_sel),
            nn.Tanh(),
            nn.Dropout(args.dropout))
        self.attn = nn.Sequential(
            torch.nn.Linear(2 * args.nhid_attn, args.nhid_attn),
            nn.Tanh(),
            torch.nn.Linear(args.nhid_attn, 1)
        )
        self.sel_decoders = nn.ModuleList()
        for i in range(domain.selection_length()):
            self.sel_decoders.append(nn.Linear(args.nhid_sel, len(self.item_dict)))

        self.init_weights()

        # Mask for disabling special tokens when generating sentences. It is
        # added to the vocabulary scores in `write` and `score_sent`.
        self.special_token_mask = make_mask(len(word_dict),
            [word_dict.get_idx(w) for w in ['', 'YOU:', 'THEM:', '']])

    def flatten_parameters(self):
        """Compacts the cuDNN weights of the multi-step GRUs."""
        self.reader.flatten_parameters()
        self.sel_rnn.flatten_parameters()

    def zero_h(self, bsz, nhid=None, copies=None):
        """Returns a zero hidden state of shape (copies, bsz, nhid).

        `nhid` defaults to `args.nhid_lang` and `copies` defaults to 1.
        """
        nhid = self.args.nhid_lang if nhid is None else nhid
        copies = 1 if copies is None else copies
        h = torch.Tensor(copies, bsz, nhid).fill_(0)
        return Variable(h)

    def word2var(self, word):
        """Returns a LongTensor of shape (1,) holding the index of `word`."""
        x = torch.Tensor(1).fill_(self.word_dict.get_idx(word)).long()
        return Variable(x)

    def init_weights(self):
        """Uniformly initializes embeddings and feed-forward layers in
        [-init_range, init_range]. The reader GRU is not re-initialized here."""
        init_cont(self.decoder, self.args.init_range)
        self.word_encoder.weight.data.uniform_(-self.args.init_range, self.args.init_range)

        init_cont(self.attn, self.args.init_range)
        init_cont(self.sel_encoder, self.args.init_range)
        init_cont(self.sel_decoders, self.args.init_range)

    def read(self, inpt, lang_h, ctx_h, prefix_token='THEM:'):
        """Feeds `prefix_token` followed by `inpt` through the reader GRU.

        Returns:
            (per-step outputs, final hidden state) of the reader.
        """
        # Add a 'THEM:' token to the start of the message
        prefix = self.word2var(prefix_token).unsqueeze(0)
        inpt = torch.cat([prefix, inpt])

        inpt_emb = self.word_encoder(inpt)

        # Append the context embedding to every input word embedding
        ctx_h_rep = ctx_h.expand(inpt_emb.size(0), ctx_h.size(1), ctx_h.size(2))
        inpt_emb = torch.cat([inpt_emb, ctx_h_rep], 2)

        # Finally read in the words
        out, lang_h = self.reader(inpt_emb, lang_h)

        return out, lang_h

    def generate_choice_logits(self, inpt, lang_h, ctx_h):
        """Returns one logit tensor per selection slot (one per sel_decoder)."""
        # Run a birnn over the concatenation of the input embeddings and language model hidden states
        inpt_emb = self.word_encoder(inpt)
        inpt_emb = self.word_encoder_dropout(inpt_emb)
        h = torch.cat([lang_h.unsqueeze(1), inpt_emb], 2)

        attn_h = self.zero_h(h.size(1), self.args.nhid_attn, copies=2)
        h, _ = self.sel_rnn(h, attn_h)
        h = h.squeeze(1)

        logit = self.attn(h).squeeze(1)
        # Attention weights over time steps (dim 0); an explicit dim matches the
        # previously implicit choice and avoids the deprecation warning.
        prob = F.softmax(logit, dim=0).unsqueeze(1).expand_as(h)
        attn = torch.sum(torch.mul(h, prob), 0, keepdim=True)

        ctx_h = ctx_h.squeeze(1)
        h = torch.cat([attn, ctx_h], 1)
        h = self.sel_encoder.forward(h)

        logits = [decoder.forward(h).squeeze(0) for decoder in self.sel_decoders]
        return logits

    def write_batch(self, bsz, lang_h, ctx_h, temperature, max_words=100):
        """Samples `bsz` continuations in parallel from the same start state.

        Sampling stops after `max_words` steps or once every row has produced
        the end-of-dialogue token.

        Returns:
            (sampled word indices of shape (steps, bsz),
             stacked writer hidden states).
        """
        eod = self.word_dict.get_idx('')

        lang_h = lang_h.squeeze(0).expand(bsz, lang_h.size(2))
        ctx_h = ctx_h.squeeze(0).expand(bsz, ctx_h.size(2))

        # One 'YOU:' token per row so the word embedding can be concatenated
        # with the (bsz, nhid_ctx) context along dim 1.
        inpt = self.word2var('YOU:').expand(bsz)

        outs, lang_hs = [], [lang_h.unsqueeze(0)]
        done = set()
        for _ in range(max_words):
            inpt_emb = torch.cat([self.word_encoder(inpt), ctx_h], 1)
            lang_h = self.writer(inpt_emb, lang_h)
            out = self.decoder(lang_h)
            scores = F.linear(out, self.word_encoder.weight).div(temperature)
            # Subtract each row's maximum to avoid overflow in exp(); keepdim
            # keeps the (bsz, 1) max broadcastable over the vocabulary.
            scores.sub_(scores.max(1, keepdim=True)[0])
            out = torch.multinomial(scores.exp(), 1).squeeze(1)
            outs.append(out.unsqueeze(0))
            lang_hs.append(lang_h.unsqueeze(0))
            inpt = out

            sampled = out.data.cpu()
            for i in range(bsz):
                if sampled[i] == eod:
                    done.add(i)
            if len(done) == bsz:
                break

        inpt_emb = torch.cat([self.word_encoder(inpt), ctx_h], 1)
        lang_h = self.writer(inpt_emb, lang_h)
        lang_hs.append(lang_h.unsqueeze(0))

        return torch.cat(outs, 0), torch.cat(lang_hs, 0)

    def write(self, lang_h, ctx_h, max_words, temperature,
            stop_tokens=data.STOP_TOKENS, resume=False):
        """
        Generate a sentence word by word and feed the output of the
        previous timestep as input to the next.

        Returns:
            (per-word log-probabilities, sampled words, final hidden state with
             the batch dimension restored, stacked hidden states).
        """
        outs, logprobs, lang_hs = [], [], []
        # Remove batch dimension
        lang_h = lang_h.squeeze(1)
        ctx_h = ctx_h.squeeze(1)
        inpt = None if resume else self.word2var('YOU:')

        for _ in range(max_words):
            if inpt is not None:
                # Add the context to the word embedding
                inpt_emb = torch.cat([self.word_encoder(inpt), ctx_h], 1)
                # Update RNN state with last word
                lang_h = self.writer(inpt_emb, lang_h)
                lang_hs.append(lang_h)

            # Decode words using the inverse of the word embedding matrix
            out = self.decoder(lang_h)
            scores = F.linear(out, self.word_encoder.weight).div(temperature)
            # Subtract constant to avoid overflows in exponentiation
            scores = scores.add(-scores.max().item()).squeeze(0)

            # Disable special tokens from being generated in normal turns
            if not resume:
                mask = Variable(self.special_token_mask)
                scores = scores.add(mask)

            prob = F.softmax(scores, dim=0)
            logprob = F.log_softmax(scores, dim=0)

            word = prob.multinomial(1).detach()
            logprob = logprob.gather(0, word)

            logprobs.append(logprob)
            outs.append(word.view(word.size()[0], 1))

            inpt = word

            # Check if we generated a stop token
            if self.word_dict.get_word(word.item()) in stop_tokens:
                break

        # Update the hidden state with the last generated token
        inpt_emb = torch.cat([self.word_encoder(inpt), ctx_h], 1)
        lang_h = self.writer(inpt_emb, lang_h)
        lang_hs.append(lang_h)

        # Add batch dimension back
        lang_h = lang_h.unsqueeze(1)

        return logprobs, torch.cat(outs), lang_h, torch.cat(lang_hs, 0)

    def score_sent(self, sent, lang_h, ctx_h, temperature):
        """Returns the summed log-probability of `sent` under the writer.

        Returns:
            (score as a Python float, final hidden state with the batch
             dimension restored, stacked hidden states).
        """
        score = 0
        lang_h = lang_h.squeeze(1)
        ctx_h = ctx_h.squeeze(1)
        inpt = self.word2var('YOU:')
        lang_hs = []

        for word in sent:
            inpt_emb = torch.cat([self.word_encoder(inpt), ctx_h], 1)
            lang_h = self.writer(inpt_emb, lang_h)
            lang_hs.append(lang_h)

            out = self.decoder(lang_h)
            scores = F.linear(out, self.word_encoder.weight).div(temperature)
            # `.item()` extracts the 0-dim max; indexing it with `.data[0]`
            # raises IndexError in PyTorch >= 0.5.
            scores = scores.add(-scores.max().item()).squeeze(0)

            mask = Variable(self.special_token_mask)
            scores = scores.add(mask)

            logprob = F.log_softmax(scores, dim=0)
            score += logprob[word[0]].item()
            inpt = Variable(word)

        inpt_emb = torch.cat([self.word_encoder(inpt), ctx_h], 1)
        lang_h = self.writer(inpt_emb, lang_h)
        lang_hs.append(lang_h)

        lang_h = lang_h.unsqueeze(1)

        return score, lang_h, torch.cat(lang_hs)

    def forward_context(self, ctx):
        """Encodes the context and adds a leading dimension of size 1."""
        ctx_h = self.ctx_encoder(ctx).unsqueeze(0)
        return ctx_h

    def forward_lm(self, inpt_emb, lang_h, ctx_h):
        """Runs the reader over embedded input and projects to vocabulary logits.

        Returns:
            (flattened vocabulary logits, reader hidden states after dropout).
        """
        # append the context embedding to every input word embedding
        ctx_h_rep = ctx_h.narrow(0, ctx_h.size(0) - 1, 1).expand(
            inpt_emb.size(0), ctx_h.size(1), ctx_h.size(2))
        inpt_emb = torch.cat([inpt_emb, ctx_h_rep], 2)

        lang_hs, _ = self.reader(inpt_emb, lang_h)
        lang_hs = self.reader_dropout(lang_hs)

        decoded = self.decoder(lang_hs.view(-1, lang_hs.size(2)))
        out = F.linear(decoded, self.word_encoder.weight)

        return out, lang_hs

    def forward_selection(self, inpt_emb, lang_h, ctx_h):
        """Returns selection logits for all slots, concatenated along dim 0."""
        # run a birnn over the concatenation of the input embeddings and language model hidden states
        h = torch.cat([lang_h, inpt_emb], 2)

        attn_h = self.zero_h(h.size(1), self.args.nhid_attn, copies=2)
        h, _ = self.sel_rnn(h, attn_h)
        h = self.sel_dropout(h)

        h = h.transpose(0, 1).contiguous()
        logit = self.attn(h.view(-1, 2 * self.args.nhid_attn)).view(h.size(0), h.size(1))
        prob = F.softmax(logit, dim=1).unsqueeze(2).expand_as(h)
        attn = torch.sum(torch.mul(h, prob), 1, keepdim=True).transpose(0, 1).contiguous()

        h = torch.cat([attn, ctx_h], 2).squeeze(0)
        h = self.sel_encoder.forward(h)

        outs = [decoder.forward(h) for decoder in self.sel_decoders]
        out = torch.cat(outs, 0)
        return out

    def forward(self, inpt, ctx):
        """Full training forward pass.

        Returns:
            (language-model logits, selection logits).
        """
        ctx_h = self.forward_context(ctx)
        lang_h = self.zero_h(ctx_h.size(1), self.args.nhid_lang)

        inpt_emb = self.word_encoder(inpt)
        inpt_emb = self.word_encoder_dropout(inpt_emb)

        out, lang_hs = self.forward_lm(inpt_emb, lang_h, ctx_h)
        sel_out = self.forward_selection(inpt_emb, lang_hs, ctx_h)

        return out, sel_out

# --------------------------------------------------------------------------------
# /src/models/attn.py:
# --------------------------------------------------------------------------------
# 1 | # Copyright 2017-present, Facebook, Inc.
# 2 | # All rights reserved.
# 3 | #
# 4 | # This source code is licensed under the license found in the
# 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.nn.init 12 | from torch.autograd import Variable 13 | 14 | from models.utils import * 15 | 16 | 17 | class Attention(nn.Module): 18 | def __init__(self, query_size, value_size, hid_size, init_range): 19 | super(Attention, self).__init__() 20 | 21 | self.value2hid = nn.Linear(value_size, hid_size) 22 | self.query2hid = nn.Linear(query_size, hid_size) 23 | self.hid2output = nn.Linear(hid_size, 1) 24 | 25 | self.value2hid.weight.data.uniform_(-init_range, init_range) 26 | self.value2hid.bias.data.fill_(0) 27 | self.query2hid.weight.data.uniform_(-init_range, init_range) 28 | self.query2hid.bias.data.fill_(0) 29 | self.hid2output.weight.data.uniform_(-init_range, init_range) 30 | self.hid2output.bias.data.fill_(0) 31 | 32 | def _bottle(self, linear, x): 33 | y = linear(x.view(-1, x.size(-1))) 34 | return y.view(x.size(0), x.size(1), -1) 35 | 36 | def forward_attn(self, h): 37 | logit = self.attn(h.view(-1, h.size(2))).view(h.size(0), h.size(1)) 38 | return logit 39 | 40 | def forward(self, q, v, mask=None): 41 | # q: [batch_size, query_size] 42 | # v: [length, batch_size, value_size] 43 | v = v.transpose(0, 1).contiguous() 44 | 45 | h_v = self._bottle(self.value2hid, v) 46 | h_q = self.query2hid(q) 47 | 48 | h = torch.tanh(h_v + h_q.unsqueeze(1).expand_as(h_v)) 49 | logit = self._bottle(self.hid2output, h).squeeze(2) 50 | 51 | logit = logit.sub(logit.max(1, keepdim=True)[0].expand_as(logit)) 52 | if mask is not None: 53 | logit = torch.add(logit, Variable(mask)) 54 | 55 | p = F.softmax(logit, dim=1) 56 | w = p.unsqueeze(2).expand_as(v) 57 | h = torch.sum(torch.mul(v, w), 1, keepdim=True) 58 | h = h.transpose(0, 1).contiguous() 59 | 60 | return h, p 61 | 62 | 63 | class KeyValueAttention(nn.Module): 64 | def __init__(self, query_size, key_size, value_size, hid_size, init_range): 65 | super(KeyValueAttention, self).__init__() 66 | 67 | self.key2hid 
= nn.Linear(key_size, hid_size) 68 | self.query2hid = nn.Linear(query_size, hid_size) 69 | self.hid2output = nn.Linear(hid_size, 1) 70 | 71 | self.key2hid.weight.data.uniform_(-init_range, init_range) 72 | self.key2hid.bias.data.fill_(0) 73 | self.query2hid.weight.data.uniform_(-init_range, init_range) 74 | self.query2hid.bias.data.fill_(0) 75 | self.hid2output.weight.data.uniform_(-init_range, init_range) 76 | self.hid2output.bias.data.fill_(0) 77 | 78 | def _bottle(self, linear, x): 79 | y = linear(x.view(-1, x.size(-1))) 80 | return y.view(x.size(0), x.size(1), -1) 81 | 82 | def forward_attn(self, h): 83 | logit = self.attn(h.view(-1, h.size(2))).view(h.size(0), h.size(1)) 84 | return logit 85 | 86 | def forward(self, q, k, v, mask=None): 87 | # q: [batch_size, query_size] 88 | # k: [length, batch_size, key_size] 89 | # v: [length, batch_size, value_size] 90 | k = k.transpose(0, 1).contiguous() 91 | v = v.transpose(0, 1).contiguous() 92 | 93 | h_k = self._bottle(self.key2hid, k) 94 | h_q = self.query2hid(q) 95 | 96 | h = F.tanh(h_k + h_q.unsqueeze(1).expand_as(h_k)) 97 | logit = self._bottle(self.hid2output, h).squeeze(2) 98 | 99 | logit = logit.sub(logit.max(1, keepdim=True)[0].expand_as(logit)) 100 | if mask is not None: 101 | logit = torch.add(logit, Variable(mask)) 102 | 103 | p = F.softmax(logit) 104 | w = p.unsqueeze(2).expand_as(v) 105 | h = torch.sum(torch.mul(v, w), 1, keepdim=True) 106 | h = h.transpose(0, 1).contiguous() 107 | 108 | return h, p 109 | 110 | 111 | class MaskedAttention(nn.Module): 112 | def __init__(self, query_size, value_size, hid_size, init_range): 113 | super(MaskedAttention, self).__init__() 114 | self.attn = Attention(query_size, value_size, hid_size, init_range) 115 | 116 | def make_mask(self, v, ln): 117 | mask = torch.Tensor(v.size(1), v.size(0)).fill_(-1000) 118 | for i in range(v.size(1)): 119 | mask.narrow(0, i, 1).narrow(1, 0, ln[i] + 1).fill_(0) 120 | return mask 121 | 122 | def forward(self, q, v, ln=None): 123 | mask = 
self.make_mask(v, ln) if ln is not None else None 124 | return self.attn(q, v, mask=mask) 125 | 126 | 127 | class ChunkedAttention(nn.Module): 128 | def __init__(self, query_size, value_size, hid_size, init_range): 129 | super(ChunkedAttention, self).__init__() 130 | self.query_size = query_size 131 | self.value_size = value_size 132 | self.hid_size = hid_size 133 | self.init_range = init_range 134 | 135 | def zero_h(self, bsz, n=1): 136 | h = torch.Tensor(n, bsz, self.hid_size).fill_(0) 137 | return Variable(h) 138 | 139 | def make_mask(self, fwd_hs, lens): 140 | bsz = fwd_hs[0].size(1) 141 | n = sum(fwd_h.size(0) for fwd_h in fwd_hs) 142 | 143 | mask = torch.Tensor(bsz, n).fill_(-1000) 144 | offset = 0 145 | for fwd_h, ln in zip(fwd_hs, lens): 146 | for i in range(bsz): 147 | mask.narrow(0, i, 1).narrow(1, offset, ln[i] + 1).fill_(0) 148 | offset += fwd_h.size(0) 149 | return mask 150 | 151 | def reverse(self, fwd_inpts, fwd_lens, rev_idxs): 152 | bwd_inpts, bwd_lens = [], [] 153 | for inpt, ln, rev_idx in zip(reversed(fwd_inpts), reversed(fwd_lens), reversed(rev_idxs)): 154 | bwd_inpt = inpt.gather(0, 155 | rev_idx.expand(rev_idx.size(0), rev_idx.size(1), inpt.size(2))) 156 | bwd_inpts.append(bwd_inpt) 157 | bwd_lens.append(ln) 158 | 159 | return bwd_inpts, bwd_lens 160 | 161 | 162 | class BiRnnAttention(ChunkedAttention): 163 | def __init__(self, query_size, value_size, hid_size, dropout, init_range): 164 | super(BiRnnAttention, self).__init__(query_size, value_size, hid_size, init_range) 165 | 166 | self.dropout = nn.Dropout(dropout) 167 | 168 | self.fwd_rnn = nn.GRU(value_size, hid_size, bias=True) 169 | self.bwd_rnn = nn.GRU(value_size, hid_size, bias=True) 170 | 171 | self.attn = Attention(query_size, 2 * hid_size, hid_size, init_range) 172 | 173 | def flatten_parameters(self): 174 | self.fwd_rnn.flatten_parameters() 175 | self.bwd_rnn.flatten_parameters() 176 | 177 | def forward_attn(self, query, fwd_hs, bwd_hs, lens): 178 | fwd_h = torch.cat(fwd_hs, 0) 
179 | bwd_h = torch.cat(bwd_hs, 0) 180 | h = torch.cat([fwd_h, bwd_h], 2) 181 | h = self.dropout(h) 182 | 183 | mask = self.make_mask(fwd_hs, lens) 184 | h, p = self.attn(query, h, mask) 185 | return h, p 186 | 187 | def forward_rnn(self, rnn, inpts, lens, hid_idxs): 188 | bsz = inpts[0].size(1) 189 | hs = [] 190 | h = self.zero_h(bsz) 191 | for inpt, ln, hid_idx in zip(inpts, lens, hid_idxs): 192 | out, _ = rnn(inpt, h) 193 | hs.append(out) 194 | h = out.gather(0, hid_idx.expand(hid_idx.size(0), hid_idx.size(1), out.size(2))) 195 | return hs 196 | 197 | def forward(self, query, fwd_inpts, fwd_lens, rev_idxs, hid_idxs): 198 | # reverse inputs 199 | bwd_inpts, bwd_lens = self.reverse(fwd_inpts, fwd_lens, rev_idxs) 200 | 201 | fwd_hs = self.forward_rnn(self.fwd_rnn, fwd_inpts, fwd_lens, hid_idxs) 202 | bwd_hs = self.forward_rnn(self.bwd_rnn, bwd_inpts, bwd_lens, reversed(hid_idxs)) 203 | 204 | # reverse them back to align with fwd_inpts 205 | bwd_hs, _ = self.reverse(bwd_hs, bwd_lens, list(reversed(rev_idxs))) 206 | 207 | h, p = self.forward_attn(query, fwd_hs, bwd_hs, fwd_lens) 208 | h = h.squeeze(0) 209 | return h, p 210 | 211 | 212 | class HierarchicalAttention(ChunkedAttention): 213 | def __init__(self, query_size, value_size, hid_size, dropout, init_range): 214 | super(HierarchicalAttention, self).__init__( 215 | query_size, value_size, hid_size, init_range) 216 | 217 | self.word_dropout = nn.Dropout(dropout) 218 | self.sent_dropout = nn.Dropout(dropout) 219 | 220 | self.fwd_word_rnn = nn.GRU(value_size, hid_size, bias=True) 221 | self.bwd_word_rnn = nn.GRU(value_size, hid_size, bias=True) 222 | self.word_attn = Attention(query_size, 2 * hid_size, hid_size, init_range) 223 | 224 | self.sent_rnn = nn.GRU(2 * hid_size, hid_size, bidirectional=True, bias=True) 225 | self.sent_attn = Attention(query_size, 2 * hid_size, hid_size, init_range) 226 | 227 | init_rnn(self.fwd_word_rnn, init_range) 228 | init_rnn(self.bwd_word_rnn, init_range) 229 | init_rnn(self.sent_rnn, 
init_range) 230 | 231 | def flatten_parameters(self): 232 | self.fwd_word_rnn.flatten_parameters() 233 | self.bwd_word_rnn.flatten_parameters() 234 | self.sent_rnn.flatten_parameters() 235 | 236 | def forward_word_attn(self, query, fwd_word_hs, bwd_word_hs, ln, rev_idx, hid_idx): 237 | # reverse bwd_word_h 238 | bwd_word_hs = bwd_word_hs.gather(0, 239 | rev_idx.expand(rev_idx.size(0), rev_idx.size(1), bwd_word_hs.size(2))) 240 | 241 | word_hs = torch.cat([fwd_word_hs, bwd_word_hs], 2) 242 | word_hs = self.word_dropout(word_hs) 243 | 244 | mask = self.make_mask([fwd_word_hs], [ln]) 245 | word_h, word_p = self.word_attn(query, word_hs, mask) 246 | return word_h, word_p 247 | 248 | def forward_word_rnn(self, rnn, bsz, inpts, lens, rev_idxs, hid_idxs): 249 | hs = [] 250 | for inpt, ln, rev_idx, hid_idx in zip(inpts, lens, rev_idxs, hid_idxs): 251 | h, _ = rnn(inpt, self.zero_h(bsz)) 252 | hs.append(h) 253 | return hs 254 | 255 | def forward_sent_attn(self, query, word_hs): 256 | bsz = query.size(0) 257 | if len(word_hs) == 0: 258 | sent_h = Variable(torch.Tensor(1, bsz, 2 * self.hid_size).fill_(0)) 259 | else: 260 | word_hs = torch.cat(word_hs, 0) 261 | 262 | sent_h, _ = self.sent_rnn(word_hs, self.zero_h(bsz, n=2)) 263 | sent_h = self.sent_dropout(sent_h) 264 | 265 | sent_h, sent_p = self.sent_attn(query, sent_h) 266 | return sent_h, sent_p 267 | 268 | def forward(self, query, fwd_inpts, fwd_lens, rev_idxs, hid_idxs): 269 | # query: [batch_size, query_size] 270 | # fwd_inpts: [length, batch_size, value_size] 271 | bwd_inpts, bwd_lens = self.reverse(fwd_inpts, fwd_lens, rev_idxs) 272 | 273 | bsz = query.size(0) 274 | fwd_word_hs = self.forward_word_rnn(self.fwd_word_rnn, bsz, fwd_inpts, 275 | fwd_lens, rev_idxs, hid_idxs) 276 | bwd_word_hs = self.forward_word_rnn(self.bwd_word_rnn, bsz, bwd_inpts, 277 | reversed(fwd_lens), reversed(rev_idxs), reversed(hid_idxs)) 278 | 279 | iterator = zip(fwd_word_hs, reversed(bwd_word_hs), fwd_lens, rev_idxs, hid_idxs) 280 | word_hs, 
word_ps = [], [] 281 | for fwd_word_h, bwd_word_h, ln, rev_idx, hid_idx in iterator: 282 | word_h, word_p = self.forward_word_attn(query, fwd_word_h, 283 | bwd_word_h, ln, rev_idx, hid_idx) 284 | word_hs.append(word_h) 285 | word_ps.append(word_p) 286 | 287 | sent_h, sent_p = self.forward_sent_attn(query, word_hs) 288 | # remove the length dimension 289 | sent_h = sent_h.squeeze(0) 290 | 291 | return (sent_h, sent_p), (word_hs, word_ps) 292 | 293 | 294 | class SentenceAttention(ChunkedAttention): 295 | def __init__(self, query_size, value_size, hid_size, dropout, init_range): 296 | super(SentenceAttention, self).__init__( 297 | query_size, value_size, hid_size, init_range) 298 | 299 | self.word_dropout = nn.Dropout(dropout) 300 | 301 | self.fwd_word_rnn = nn.GRU(value_size, hid_size, bias=True) 302 | self.bwd_word_rnn = nn.GRU(value_size, hid_size, bias=True) 303 | self.word_attn = Attention(query_size, 2 * hid_size, hid_size, init_range) 304 | 305 | init_rnn(self.fwd_word_rnn, init_range) 306 | init_rnn(self.bwd_word_rnn, init_range) 307 | 308 | def flatten_parameters(self): 309 | self.fwd_word_rnn.flatten_parameters() 310 | self.bwd_word_rnn.flatten_parameters() 311 | 312 | def forward_word_attn(self, query, fwd_word_hs, bwd_word_hs, ln, rev_idx, hid_idx): 313 | # reverse bwd_word_h 314 | bwd_word_hs = bwd_word_hs.gather(0, 315 | rev_idx.expand(rev_idx.size(0), rev_idx.size(1), bwd_word_hs.size(2))) 316 | 317 | word_hs = torch.cat([fwd_word_hs, bwd_word_hs], 2) 318 | word_hs = self.word_dropout(word_hs) 319 | 320 | mask = self.make_mask([fwd_word_hs], [ln]) 321 | word_h, word_p = self.word_attn(query, word_hs, mask) 322 | return word_h, word_p 323 | 324 | def forward_word_rnn(self, rnn, bsz, inpts, lens, rev_idxs, hid_idxs): 325 | hs = [] 326 | for inpt, ln, rev_idx, hid_idx in zip(inpts, lens, rev_idxs, hid_idxs): 327 | h, _ = rnn(inpt, self.zero_h(bsz)) 328 | hs.append(h) 329 | return hs 330 | 331 | def forward(self, query, fwd_inpt, fwd_len, rev_idx, hid_idx): 
332 | # query: [batch_size, query_size] 333 | # fwd_inpts: [length, batch_size, value_size] 334 | bwd_inpts, bwd_lens = self.reverse([fwd_inpt], [fwd_len], [rev_idx]) 335 | 336 | bsz = query.size(0) 337 | fwd_word_hs = self.forward_word_rnn(self.fwd_word_rnn, bsz, [fwd_inpt], 338 | [fwd_len], [rev_idx], [hid_idx]) 339 | bwd_word_hs = self.forward_word_rnn(self.bwd_word_rnn, bsz, bwd_inpts, 340 | reversed([fwd_len]), reversed([rev_idx]), reversed([hid_idx])) 341 | 342 | iterator = zip(fwd_word_hs, reversed(bwd_word_hs), [fwd_len], [rev_idx], [hid_idx]) 343 | word_hs, word_ps = [], [] 344 | for fwd_word_h, bwd_word_h, ln, rev_idx, hid_idx in iterator: 345 | word_h, word_p = self.forward_word_attn(query, fwd_word_h, 346 | bwd_word_h, ln, rev_idx, hid_idx) 347 | word_hs.append(word_h) 348 | word_ps.append(word_p) 349 | 350 | return word_hs[0].squeeze(0) 351 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import random 9 | import sys 10 | import pdb 11 | import copy 12 | import re 13 | from collections import OrderedDict 14 | 15 | import torch 16 | import numpy as np 17 | 18 | 19 | SPECIAL = [ 20 | '', 21 | '', 22 | '', 23 | '', 24 | ] 25 | 26 | STOP_TOKENS = [ 27 | '', 28 | '', 29 | ] 30 | 31 | 32 | def get_tag(tokens, tag): 33 | return tokens[tokens.index('<' + tag + '>') + 1:tokens.index('')] 34 | 35 | 36 | def read_lines(file_name): 37 | assert os.path.exists(file_name), 'file does not exists %s' % file_name 38 | lines = [] 39 | with open(file_name, 'r') as f: 40 | for line in f: 41 | lines.append(line.strip()) 42 | return lines 43 | 44 | 45 | class Dictionary(object): 46 | def __init__(self, init=True): 47 | self.word2idx = OrderedDict() 48 | self.idx2word = [] 49 | if init: 50 | for i, k in enumerate(SPECIAL): 51 | self.word2idx[k] = i 52 | self.idx2word.append(k) 53 | 54 | def add_word(self, word): 55 | if word not in self.word2idx: 56 | self.word2idx[word] = len(self.idx2word) 57 | self.idx2word.append(word) 58 | return self.word2idx[word] 59 | 60 | def i2w(self, idx): 61 | return [self.idx2word[i] for i in idx] 62 | 63 | def w2i(self, words): 64 | unk = self.word2idx.get('', None) 65 | return [self.word2idx.get(w, unk) for w in words] 66 | 67 | def get_idx(self, word): 68 | unk = self.word2idx.get('', None) 69 | return self.word2idx.get(word, unk) 70 | 71 | def get_word(self, idx): 72 | return self.idx2word[idx] 73 | 74 | def __len__(self): 75 | return len(self.idx2word) 76 | 77 | def read_tag(domain, file_name, tag, freq_cutoff=-1, init_dict=True): 78 | token_freqs = OrderedDict() 79 | with open(file_name, 'r') as f: 80 | for line in f: 81 | tokens = line.strip().split() 82 | tokens = get_tag(tokens, tag) 83 | for token in tokens: 84 | token_freqs[token] = token_freqs.get(token, 0) + 1 85 | dictionary = Dictionary(init=init_dict) 86 | token_freqs = sorted(token_freqs.items(), key=lambda x: x[1], reverse=True) 87 | for token, freq in 
token_freqs: 88 | if freq > freq_cutoff: 89 | dictionary.add_word(token) 90 | return dictionary 91 | 92 | 93 | class ItemDictionary(Dictionary): 94 | def __init__(self, selection_size, init=True): 95 | super(ItemDictionary, self).__init__(init) 96 | self.selection_size = selection_size 97 | 98 | def w2i(self, words, inv=False): 99 | # pick last selection_size if inv=True, otherwise first selection_size 100 | words = words[self.selection_size:] if inv else words[:self.selection_size] 101 | token = ' '.join(words) 102 | return self.word2idx[token] 103 | 104 | def read_tag(domain, file_name, tag, init_dict=False): 105 | dictionary = ItemDictionary(domain.selection_length() // 2, init=False) 106 | 107 | def generate(item_id, selection=[]): 108 | if item_id >= dictionary.selection_size: 109 | dictionary.add_word(' '.join(selection)) 110 | return 111 | for i in range(5): 112 | selection.append('item%d=%d' % (item_id, i)) 113 | generate(item_id + 1, selection) 114 | selection.pop() 115 | 116 | generate(0) 117 | 118 | for token in ['', '', '']: 119 | dictionary.add_word(' '.join([token] * dictionary.selection_size)) 120 | 121 | return dictionary 122 | 123 | 124 | class CountDictionary(Dictionary): 125 | def __init__(self, init=True): 126 | super(CountDictionary, self).__init__(init) 127 | 128 | def get_key(words): 129 | key = '_'.join(words[i] for i in range(0, len(words), 2)) 130 | return key 131 | 132 | def get_idx(self, words): 133 | key = CountDictionary.get_key(words) 134 | return self.word2idx[key] 135 | 136 | def read_tag(domain, file_name, tag, init_dict=False): 137 | token_freqs = OrderedDict() 138 | with open(file_name, 'r') as f: 139 | for line in f: 140 | tokens = line.strip().split() 141 | tokens = get_tag(tokens, tag) 142 | key = CountDictionary.get_key(tokens) 143 | token_freqs[key] = token_freqs.get(key, 0) + 1 144 | token_freqs = sorted(token_freqs.items(), key=lambda x: x[1], reverse=True) 145 | dictionary = CountDictionary(init=init_dict) 146 | for 
token, freq in token_freqs: 147 | dictionary.add_word(token) 148 | return dictionary 149 | 150 | 151 | def create_dicts_from_file(domain, file_name, freq_cutoff): 152 | assert os.path.exists(file_name) 153 | word_dict = Dictionary.read_tag(domain, file_name, 'dialogue', freq_cutoff=freq_cutoff) 154 | item_dict = ItemDictionary.read_tag(domain, file_name, 'output', init_dict=False) 155 | item_dict_old = Dictionary.read_tag(domain, file_name, 'output', init_dict=False) 156 | context_dict = Dictionary.read_tag(domain, file_name, 'input', init_dict=False) 157 | count_dict = CountDictionary.read_tag(domain, file_name, 'input', init_dict=False) 158 | return word_dict, item_dict, context_dict, item_dict_old, count_dict 159 | 160 | 161 | class WordCorpus(object): 162 | def __init__(self, domain, path, freq_cutoff=2, train='train.txt', 163 | valid='val.txt', test='test.txt', verbose=False, sep_sel=False): 164 | self.domain = domain 165 | self.verbose = verbose 166 | self.sep_sel = sep_sel 167 | # Only add words from the train dataset 168 | self.word_dict, self.item_dict, self.context_dict, self.item_dict_old, self.count_dict = create_dicts_from_file( 169 | domain, 170 | os.path.join(path, train), 171 | freq_cutoff=freq_cutoff) 172 | 173 | self.train = self.tokenize(os.path.join(path, train)) if train else [] 174 | self.valid = self.tokenize(os.path.join(path, valid)) if valid else [] 175 | self.test = self.tokenize(os.path.join(path, test)) if test else [] 176 | 177 | # find out the output length from the train dataset 178 | self.output_length = max([len(x[2]) for x in self.train]) 179 | 180 | def tokenize(self, file_name): 181 | lines = read_lines(file_name) 182 | random.shuffle(lines) 183 | 184 | def make_mask(choices, inv=False): 185 | items = torch.Tensor([self.item_dict.w2i(c, inv=inv) for c in choices]).long() 186 | mask = torch.Tensor(len(self.item_dict)).zero_() 187 | mask.scatter_(0, items, torch.Tensor(items.size(0)).fill_(1)) 188 | return mask.unsqueeze(0) 189 | 
190 | def make_indexes(choices): 191 | items = torch.Tensor([self.item_dict.w2i(c) for c in choices]).long() 192 | return items 193 | 194 | 195 | unk = self.word_dict.get_idx('') 196 | dataset, total, unks = [], 0, 0 197 | for line in lines: 198 | tokens = line.split() 199 | input_idxs = self.context_dict.w2i(get_tag(tokens, 'input')) 200 | count_idx = self.count_dict.get_idx(get_tag(tokens, 'input')) 201 | word_idxs = self.word_dict.w2i(get_tag(tokens, 'dialogue')) 202 | item_idx = self.item_dict.w2i(get_tag(tokens, 'output'), inv=False) 203 | item_idx_inv = self.item_dict.w2i(get_tag(tokens, 'output'), inv=True) 204 | items = self.item_dict_old.w2i(get_tag(tokens, 'output')) 205 | 206 | #valid_choices = self.domain.generate_choices(get_tag(tokens, 'input'), with_disagreement=False) 207 | #valid_mask = make_mask(valid_choices) 208 | partner_input_idxs = self.context_dict.w2i(get_tag(tokens, 'partner_input')) 209 | if self.sep_sel: 210 | dataset.append((input_idxs, word_idxs, items, partner_input_idxs, count_idx)) 211 | else: 212 | dataset.append((input_idxs, word_idxs, [item_idx, item_idx_inv], 213 | partner_input_idxs, count_idx)) 214 | 215 | total += len(input_idxs) + len(word_idxs) + len(partner_input_idxs) 216 | unks += np.count_nonzero([idx == unk for idx in word_idxs]) 217 | 218 | if self.verbose: 219 | print('dataset %s, total %d, unks %s, ratio %0.2f%%' % ( 220 | file_name, total, unks, 100. 
* unks / total)) 221 | return dataset 222 | 223 | def train_dataset(self, bsz, shuffle=True): 224 | return self._split_into_batches(copy.copy(self.train), bsz, shuffle=shuffle) 225 | 226 | def valid_dataset(self, bsz, shuffle=True): 227 | return self._split_into_batches(copy.copy(self.valid), bsz, shuffle=shuffle) 228 | 229 | def test_dataset(self, bsz, shuffle=True): 230 | return self._split_into_batches(copy.copy(self.test), bsz, shuffle=shuffle) 231 | 232 | def _split_into_batches(self, dataset, bsz, shuffle=True): 233 | if shuffle: 234 | random.shuffle(dataset) 235 | 236 | # Sort and pad 237 | dataset.sort(key=lambda x: len(x[1])) 238 | pad = self.word_dict.get_idx('') 239 | 240 | batches = [] 241 | stats = { 242 | 'n': 0, 243 | 'nonpadn': 0 244 | } 245 | 246 | for i in range(0, len(dataset), bsz): 247 | inputs, words = [], [] 248 | if self.sep_sel: 249 | items = [] 250 | else: 251 | items = [[], []] 252 | for j in range(i, min(i + bsz, len(dataset))): 253 | inputs.append(dataset[j][0]) 254 | words.append(dataset[j][1]) 255 | if self.sep_sel: 256 | items.append(dataset[j][2]) 257 | else: 258 | for k in range(2): 259 | items[k].append(dataset[j][2][k]) 260 | 261 | max_len = len(words[-1]) 262 | 263 | for j in range(len(words)): 264 | stats['n'] += max_len 265 | stats['nonpadn'] += len(words[j]) 266 | words[j] += [pad] * (max_len - len(words[j])) 267 | 268 | ctx = torch.Tensor(inputs).long().transpose(0, 1).contiguous() 269 | data = torch.Tensor(words).long().transpose(0, 1).contiguous() 270 | if self.sep_sel: 271 | sel_tgt = torch.Tensor(items).long().transpose(0, 1).contiguous().view(-1) 272 | else: 273 | sel_tgt = [torch.Tensor(it).long().view(-1) for it in items] 274 | 275 | inpt = data.narrow(0, 0, data.size(0) - 1) 276 | tgt = data.narrow(0, 1, data.size(0) - 1).view(-1) 277 | 278 | batches.append((ctx, inpt, tgt, sel_tgt)) 279 | 280 | if shuffle: 281 | random.shuffle(batches) 282 | 283 | return batches, stats 284 | 285 | 286 | class 
SentenceCorpus(WordCorpus): 287 | def _split_into_sentences(self, dataset): 288 | stops = [self.word_dict.get_idx(w) for w in ['YOU:', 'THEM:']] 289 | sent_dataset = [] 290 | for ctx, words, items, partner_ctx, count_idx in dataset: 291 | sents, current = [], [] 292 | for w in words: 293 | if w in stops: 294 | if len(current) > 0: 295 | sents.append(current) 296 | current = [] 297 | current.append(w) 298 | if len(current) > 0: 299 | sents.append(current) 300 | sent_dataset.append((ctx, sents, items, partner_ctx, count_idx)) 301 | # Sort by numper of sentences in a dialog 302 | sent_dataset.sort(key=lambda x: len(x[1])) 303 | 304 | return sent_dataset 305 | 306 | def _make_reverse_idxs(self, inpts, lens): 307 | idxs = [] 308 | for inpt, ln in zip(inpts, lens): 309 | idx = torch.Tensor(inpt.size(0), inpt.size(1), 1).long().fill_(-1) 310 | for i in range(inpt.size(1)): 311 | arngmt = torch.Tensor(inpt.size(0), 1, 1).long() 312 | for j in range(arngmt.size(0)): 313 | arngmt[j][0][0] = j if j > ln[i] else ln[i] - j 314 | idx.narrow(1, i, 1).copy_(arngmt) 315 | idxs.append(idx) 316 | return idxs 317 | 318 | def _make_hidden_idxs(self, lens): 319 | idxs = [] 320 | for s, ln in enumerate(lens): 321 | idx = torch.Tensor(1, ln.size(0), 1).long() 322 | for i in range(ln.size(0)): 323 | idx[0][i][0] = ln[i] 324 | idxs.append(idx) 325 | return idxs 326 | 327 | def _split_into_batches(self, dataset, bsz, shuffle=True): 328 | if shuffle: 329 | random.shuffle(dataset) 330 | 331 | dataset = self._split_into_sentences(dataset) 332 | 333 | pad = self.word_dict.get_idx('') 334 | 335 | batches = [] 336 | stats = { 337 | 'n': 0, 338 | 'nonpadn': 0 339 | } 340 | 341 | i = 0 342 | while i < len(dataset): 343 | dial_len = len(dataset[i][1]) 344 | 345 | ctxs, dials, partner_ctxs, count_idxs = [], [], [], [] 346 | if self.sep_sel: 347 | items = [] 348 | else: 349 | items = [[], []] 350 | for _ in range(bsz): 351 | if i >= len(dataset) or len(dataset[i][1]) != dial_len: 352 | break 353 | 
ctxs.append(dataset[i][0]) 354 | dials.append(dataset[i][1]) 355 | if self.sep_sel: 356 | items.append(dataset[i][2]) 357 | else: 358 | for j in range(2): 359 | items[j].append(dataset[i][2][j]) 360 | partner_ctxs.append(dataset[i][3]) 361 | count_idxs.append(dataset[i][4]) 362 | i += 1 363 | 364 | inpts, lens, tgts = [], [], [] 365 | for s in range(dial_len): 366 | batch = [] 367 | for dial in dials: 368 | batch.append(dial[s]) 369 | if s + 1 < dial_len: 370 | # add YOU:/THEM: as the last tokens in order to connect sentences 371 | for j in range(len(batch)): 372 | batch[j].append(dials[j][s + 1][0]) 373 | else: 374 | # add after 375 | for j in range(len(batch)): 376 | batch[j].append(pad) 377 | 378 | 379 | max_len = max([len(sent) for sent in batch]) 380 | ln = torch.LongTensor(len(batch)) 381 | for j in range(len(batch)): 382 | stats['n'] += max_len 383 | stats['nonpadn'] += len(batch[j]) - 1 384 | ln[j] = len(batch[j]) - 2 385 | batch[j] += [pad] * (max_len - len(batch[j])) 386 | sent = torch.Tensor(batch).long().transpose(0, 1).contiguous() 387 | inpt = sent.narrow(0, 0, sent.size(0) - 1) 388 | tgt = sent.narrow(0, 1, sent.size(0) - 1).view(-1) 389 | inpts.append(inpt) 390 | lens.append(ln) 391 | tgts.append(tgt) 392 | 393 | ctx = torch.Tensor(ctxs).long().transpose(0, 1).contiguous() 394 | partner_ctx = torch.Tensor(partner_ctxs).long()#.view(-1).contiguous() #.transpose(0, 1).contiguous() 395 | if self.sep_sel: 396 | sel_tgt = torch.Tensor(items).long().view(-1) 397 | else: 398 | sel_tgt = [torch.Tensor(it).long().view(-1) for it in items] 399 | 400 | cnt = torch.Tensor(count_idxs).long() 401 | 402 | rev_idxs = self._make_reverse_idxs(inpts, lens) 403 | hid_idxs = self._make_hidden_idxs(lens) 404 | 405 | batches.append((ctx, partner_ctx, inpts, lens, tgts, sel_tgt, rev_idxs, hid_idxs, cnt)) 406 | 407 | if shuffle: 408 | random.shuffle(batches) 409 | 410 | return batches, stats 411 | 412 | 413 | class PhraseCorpus(WordCorpus): 414 | def tokenize(self, 
file_name): 415 | lines = read_lines(file_name) 416 | random.shuffle(lines) 417 | 418 | unk = self.word_dict.get_idx('') 419 | dataset, total, unks = [], 0, 0 420 | for line in lines: 421 | tokens = line.split() 422 | dialog, items = get_tag(tokens, 'dialogue'), get_tag(tokens, 'output') 423 | 424 | for words in re.split(r'(?:YOU|THEM):', ' '.join(dialog))[1:-1]: 425 | if words is '': 426 | continue 427 | words = words.strip().split(' ') 428 | word_idxs = self.word_dict.w2i(words) 429 | item_idxs = self.item_dict.w2i(items) 430 | dataset.append(word_idxs) 431 | total += len(word_idxs) + len(item_idxs) 432 | unks += np.count_nonzero([idx == unk for idx in word_idxs]) 433 | 434 | if self.verbose: 435 | print('dataset %s, total %d, unks %s, ratio %0.2f%%' % ( 436 | file_name, total, unks, 100. * unks / total)) 437 | return dataset 438 | 439 | def _split_into_batches(self, dataset, bsz, shuffle=True): 440 | if shuffle: 441 | random.shuffle(dataset) 442 | 443 | # Sort and pad 444 | dataset.sort(key=lambda x: len(x)) 445 | pad = self.word_dict.get_idx('') 446 | 447 | batches = [] 448 | stats = { 449 | 'n': 0, 450 | 'nonpadn': 0 451 | } 452 | 453 | for i in range(0, len(dataset), bsz): 454 | words = [] 455 | for j in range(i, min(i + bsz, len(dataset))): 456 | words.append(dataset[j]) 457 | 458 | max_len = len(words[-1]) 459 | 460 | for j in range(len(words)): 461 | stats['n'] += max_len 462 | stats['nonpadn'] += len(words[j]) 463 | words[j] += [pad] * (max_len - len(words[j])) 464 | 465 | data = torch.Tensor(words).transpose(0, 1).long().contiguous() 466 | # inpt = data.narrow(0, CONTEXT_LENGTH, data.size(0) - 1 - CONTEXT_LENGTH) 467 | 468 | batches.append(data) 469 | 470 | if shuffle: 471 | random.shuffle(batches) 472 | 473 | return batches, stats 474 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 
2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. 
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. 
Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. 
Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. 
Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. 
You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. 
retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. 
Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. 
Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | 409 | -------------------------------------------------------------------------------- /src/engines/latent_clustering_engine.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017-present, Facebook, Inc. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
import time

import torch
import torch.nn as nn
from torch.autograd import Variable  # no longer used below; kept so the module namespace is unchanged
import torch.nn.functional as F
from torch import optim

from engines import EngineBase, Criterion
import utils


def _unpack_batch(batch):
    """Unpack a batch and return fresh copies of its per-turn lists.

    The batch is unpacked as the 9-tuple
    ``(ctx, _, inpts, lens, tgts, sel_tgt, rev_idxs, hid_idxs, cnt)``;
    the second field is discarded.

    Returns ``(ctx, inpts, lens, tgts, sel_tgt, rev_idxs, hid_idxs, cnt)``
    where ``inpts``, ``lens``, ``tgts``, ``rev_idxs`` and ``hid_idxs`` are
    new lists. ``_append_pad_turn`` appends to these lists in place, so the
    copies keep the caller's batch unmodified; without them a batch that is
    iterated more than once would gain an extra ``lens`` entry per pass.
    """
    ctx, _, inpts, lens, tgts, sel_tgt, rev_idxs, hid_idxs, cnt = batch
    return (ctx, list(inpts), list(lens), list(tgts), sel_tgt,
            list(rev_idxs), list(hid_idxs), cnt)


def _selection_target_probs(sel_model, inpts, lens, rev_idxs, hid_idxs, ctx):
    """Return the selection model's softmax output after each dialogue prefix.

    Entry ``i`` is ``softmax(sel_model(<first i turns>, ctx), dim=1)``, so
    entry 0 comes from an empty prefix and the list has ``len(inpts)``
    entries. The results are only used as detached targets, so the forward
    passes run under ``torch.no_grad()`` and no autograd graph is built
    through ``sel_model``. The values are identical to the detached outputs.
    """
    sel_tgt_probs = []
    with torch.no_grad():
        for i in range(len(inpts)):
            sel_prob = sel_model(
                inpts[:i], lens[:i], rev_idxs[:i], hid_idxs[:i], ctx)
            sel_tgt_probs.append(F.softmax(sel_prob.detach(), dim=1))
    return sel_tgt_probs


def _long_zeros(shape, like):
    """Return an int64 zero tensor on the same device as ``like[0]``.

    Falls back to the default device when ``like`` is empty.
    """
    device = like[0].device if like else None
    return torch.zeros(shape, dtype=torch.long, device=device)


def _append_pad_turn(word_dict, inpts, tgts, sel_tgt_probs, lens, rev_idxs,
                     hid_idxs):
    """Append one all-padding turn to every per-turn list, in place.

    The batch size is read from ``inpts[0].size(1)``. New tensors are
    created on the device of the existing entries, so GPU batches stay on
    the GPU. Turn lengths are always created on the CPU. Returns the same
    (mutated) lists.
    """
    bsz = inpts[0].size(1)
    # NOTE(review): the token literal below is an empty string as written; a
    # special token such as a padding symbol may have been lost when this
    # copy of the file was produced -- confirm against the word dictionary.
    pad = torch.full((bsz,), word_dict.get_idx(''), dtype=torch.long,
                     device=inpts[0].device)
    inpts.append(pad.unsqueeze(0))
    tgts.append(pad)
    # The padding turn reuses the last selection target.
    sel_tgt_probs.append(sel_tgt_probs[-1].clone())
    lens.append(torch.zeros(bsz, dtype=torch.long, device='cpu'))
    rev_idxs.append(_long_zeros((1, bsz, 1), rev_idxs))
    hid_idxs.append(_long_zeros((1, bsz, 1), hid_idxs))
    return inpts, tgts, sel_tgt_probs, lens, rev_idxs, hid_idxs


def _summed_loss_and_count(losses, lens):
    """Return ``(sum of losses, sum of ln.sum() for ln in lens)``, pairwise-zipped."""
    total, n = 0, 0
    for loss, ln in zip(losses, lens):
        total += loss
        n += ln.sum()
    return total, n


def _as_stats_tuple(stats, num_stats):
    """Return ``stats`` as a tuple, raising ValueError unless it has ``num_stats`` items."""
    stats = tuple(stats)
    if len(stats) != num_stats:
        raise ValueError('expected %d stats per batch, got %d' % (
            num_stats, len(stats)))
    return stats


def _run_train_pass(engine, trainset, num_stats):
    """Run one training epoch over ``trainset`` with ``engine.train_batch``.

    ``engine.train_batch(batch)`` must return ``(loss, stats)`` with
    ``num_stats`` stats. Increments ``engine.t`` per batch and refreshes
    ``engine.model_plot`` every 100 steps when ``engine.args.visual`` is set.

    Returns ``(mean_loss, mean_stats, time_elapsed)``, where the means are
    taken over the number of batches.
    """
    engine.model.train()

    total_loss = 0
    totals = [0] * num_stats
    start_time = time.time()

    for batch in trainset:
        engine.t += 1
        loss, stats = engine.train_batch(batch)
        stats = _as_stats_tuple(stats, num_stats)

        if engine.args.visual and engine.t % 100 == 0:
            engine.model_plot.update(engine.t)

        total_loss += loss
        totals = [total + stat for total, stat in zip(totals, stats)]

    num_batches = len(trainset)
    total_loss /= num_batches
    means = [total / num_batches for total in totals]
    time_elapsed = time.time() - start_time
    return total_loss, means, time_elapsed


def _run_valid_pass(engine, validset, validset_stats, num_stats):
    """Evaluate ``engine`` over ``validset`` with ``engine.valid_batch``.

    ``engine.valid_batch(batch)`` must return
    ``(valid_loss, select_loss, partner_ctx_loss, stats)`` with ``num_stats``
    stats. Returns ``(valid_loss, select_loss, partner_ctx_loss, mean_stats)``.
    """
    engine.model.eval()

    total_valid_loss, total_select_loss, total_partner_ctx_loss = 0, 0, 0
    totals = [0] * num_stats
    for batch in validset:
        valid_loss, select_loss, partner_ctx_loss, stats = engine.valid_batch(batch)
        stats = _as_stats_tuple(stats, num_stats)
        total_valid_loss += valid_loss
        total_select_loss += select_loss
        total_partner_ctx_loss += partner_ctx_loss
        totals = [total + stat for total, stat in zip(totals, stats)]

    # Dividing by the number of words in the input, not the tokens modeled,
    # because the latter includes padding
    total_valid_loss /= validset_stats['nonpadn']
    total_select_loss /= len(validset)
    total_partner_ctx_loss /= len(validset)
    means = [total / len(validset) for total in totals]
    return total_valid_loss, total_select_loss, total_partner_ctx_loss, means


def _stats_extra(stats):
    """Build the ``extra`` dict returned by ``valid_pass`` from the first three stats."""
    return {
        'entropy': stats[0],
        'avg_max_prob': stats[1],
        'avg_top3_prob': stats[2],
    }


def _print_three_stats(split, stats):
    """Print the entropy / max-prob / top-3-prob line for ``split``."""
    print('| %s | avg entropy %.3f | avg max prob %.3f | avg top3 prob %.3f' % (
        (split,) + tuple(stats[:3])))


class LatentClusteringEngine(EngineBase):
    """Trains a latent clustering model on four summed losses.

    The losses are: language cross-entropy, a selection loss
    (``self.sel_crit``), a KL term against a frozen selection model's
    per-prefix predictions, and a cluster NLL term.
    """

    def __init__(self, model, args, verbose=False):
        super(LatentClusteringEngine, self).__init__(model, args, verbose)
        self.crit = nn.CrossEntropyLoss(reduction='sum')
        self.kldiv = nn.KLDivLoss(reduction='sum')
        self.cluster_crit = nn.NLLLoss(reduction='sum')
        # NOTE(review): both bad_toks entries are empty strings as written;
        # special tokens may have been lost when this copy of the file was
        # produced -- confirm against the item dictionary.
        self.sel_crit = Criterion(
            self.model.item_dict,
            bad_toks=['', ''],
            reduction='mean' if args.sep_sel else 'none')

        # Frozen selection model that provides soft selection targets.
        self.sel_model = utils.load_model(args.selection_model_file)
        self.sel_model.eval()

    def _make_sel_tgt_probs(self, inpts, lens, rev_idxs, hid_idxs, ctx):
        """Return the frozen selection model's softmax output after each prefix."""
        return _selection_target_probs(
            self.sel_model, inpts, lens, rev_idxs, hid_idxs, ctx)

    def _append_pad(self, inpts, tgts, sel_tgt_probs, lens, rev_idxs, hid_idxs):
        """Append one all-padding turn to every per-turn list, in place; return them."""
        return _append_pad_turn(self.model.word_dict, inpts, tgts, sel_tgt_probs,
                                lens, rev_idxs, hid_idxs)

    def _forward(self, batch, norm_lang=True):
        """Compute the losses for one batch.

        Returns ``(lang_loss, sel_loss, kldiv_loss, z_loss, stats)``:

        - ``lang_loss``: summed cross-entropy of ``outs`` against ``tgts[1:]``.
          When ``norm_lang`` is true it is divided by the token count
          ``sum(lens[1:])``.
        - ``sel_loss``: ``self.sel_crit`` on the last selection output versus
          ``sel_tgt``.
        - ``kldiv_loss``: ``nn.KLDivLoss`` between the model's selection
          log-probabilities and ``sel_tgt_probs[1:]``, divided by the total
          number of rows.
        - ``z_loss``: NLL of ``log(z_prob)`` against ``z_tgt``, divided by the
          total number of targets.
        - ``stats``: passed through from the model.
        """
        ctx, inpts, lens, tgts, sel_tgt, rev_idxs, hid_idxs, cnt = _unpack_batch(batch)
        sel_tgt_probs = self._make_sel_tgt_probs(inpts, lens, rev_idxs, hid_idxs, ctx)

        inpts, tgts, sel_tgt_probs, lens, rev_idxs, hid_idxs = self._append_pad(
            inpts, tgts, sel_tgt_probs, lens, rev_idxs, hid_idxs)

        outs, sel_outs, z_probs, z_tgts, stats = self.model(
            inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt)

        lang_loss, n = 0, 0
        for out, tgt, ln in zip(outs, tgts[1:], lens[1:]):
            lang_loss += self.crit(out, tgt)
            n += ln.sum()
        if norm_lang:
            lang_loss /= n

        z_loss, n = 0, 0
        for z_prob, z_tgt in zip(z_probs, z_tgts):
            z_loss += self.cluster_crit(z_prob.log(), z_tgt)
            n += z_tgt.size(0)
        z_loss /= n

        kldiv_loss, n = 0, 0
        for sel_out, sel_tgt_prob in zip(sel_outs, sel_tgt_probs[1:]):
            kldiv_loss += self.kldiv(F.log_softmax(sel_out, dim=1), sel_tgt_prob)
            n += sel_out.size(0)
        kldiv_loss /= n

        sel_loss = self.sel_crit(sel_outs[-1], sel_tgt)

        return lang_loss, sel_loss, kldiv_loss, z_loss, stats

    def combine_loss(self, lang_loss, select_loss):
        """Return ``lang_loss + select_loss``."""
        return lang_loss + select_loss

    def train_batch(self, batch):
        """Take one optimizer step on the sum of all four losses.

        Returns ``(total loss as float, stats)``.
        """
        lang_loss, sel_loss, kldiv_loss, z_loss, stats = self._forward(batch)

        loss = lang_loss + sel_loss + kldiv_loss + z_loss

        self.opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
        self.opt.step()

        return loss.item(), stats

    def valid_batch(self, batch):
        """Return ``(unnormalized lang loss, sel loss, 0, stats)`` without gradients."""
        with torch.no_grad():
            lang_loss, sel_loss, kldiv_loss, z_loss, stats = self._forward(
                batch, norm_lang=False)

        # NOTE(review): when args.sep_sel is false sel_crit uses
        # reduction='none', so sel_loss may be a per-example tensor and
        # .item() would then fail for batch sizes > 1 -- depends on Criterion.
        return lang_loss.item(), sel_loss.item(), 0, stats

    def train_pass(self, trainset):
        """Run one training epoch and return ``(mean loss, seconds elapsed)``.

        Prints the batch-averaged entropy / max-prob / top-3-prob stats and
        their ``enc`` counterparts, as reported by the model.
        """
        total_loss, stats, time_elapsed = _run_train_pass(self, trainset, 6)
        print('| train | avg entropy %.3f | avg max prob %.3f | avg top3 prob %.3f ' % (
            stats[0], stats[1], stats[2]))
        print('| train | enc avg entropy %.3f | enc avg max prob %.3f | enc avg top3 prob %.3f ' % (
            stats[3], stats[4], stats[5]))
        return total_loss, time_elapsed

    def valid_pass(self, validset, validset_stats):
        """Evaluate on ``validset``.

        Returns ``(valid_loss, select_loss, partner_ctx_loss, extra)``. The
        language loss is divided by ``validset_stats['nonpadn']``.
        """
        valid_loss, select_loss, partner_ctx_loss, stats = _run_valid_pass(
            self, validset, validset_stats, 6)
        print('| valid | avg entropy %.3f | avg max prob %.3f | avg top3 prob %.3f ' % (
            stats[0], stats[1], stats[2]))
        print('| valid | enc avg entropy %.3f | enc avg max prob %.3f | enc avg top3 prob %.3f ' % (
            stats[3], stats[4], stats[5]))
        return valid_loss, select_loss, partner_ctx_loss, _stats_extra(stats)


class LatentClusteringPredictionEngine(EngineBase):
    """Trains the cluster-prediction model on the language loss it returns.

    After backprop, ``train_batch`` zeroes the gradients of
    ``model.latent_bottleneck`` and ``model.lang_model``, so those submodules
    receive no gradient. Whether optimizer-side terms such as weight decay
    still move them depends on ``self.opt`` (not visible here).
    """

    def __init__(self, model, args, verbose=False):
        super(LatentClusteringPredictionEngine, self).__init__(model, args, verbose)
        self.crit = nn.CrossEntropyLoss(reduction='sum')
        self.model.train()

        self.sel_model = utils.load_model(args.selection_model_file)
        self.sel_model.eval()

    def _make_sel_tgt_probs(self, inpts, lens, rev_idxs, hid_idxs, ctx):
        """Return the frozen selection model's softmax output after each prefix."""
        return _selection_target_probs(
            self.sel_model, inpts, lens, rev_idxs, hid_idxs, ctx)

    def _forward(self, batch):
        """Run the model on one batch; return ``(losses, stats, lens)``."""
        ctx, inpts, lens, tgts, _, rev_idxs, hid_idxs, cnt = _unpack_batch(batch)
        # This model does not consume selection targets, so they are not
        # computed here (that would cost one sel_model pass per prefix).
        losses, stats = self.model.forward(inpts, tgts, hid_idxs, ctx, cnt)
        return losses, stats, lens

    def train_batch(self, batch):
        """Take one optimizer step on the token-normalized language loss.

        Returns ``(normalized loss as float, stats)``.
        """
        losses, stats, lens = self._forward(batch)

        lang_loss, n = _summed_loss_and_count(losses, lens)
        lang_loss /= n

        loss = lang_loss
        self.opt.zero_grad()
        loss.backward()

        # don't update clusters
        self.model.latent_bottleneck.zero_grad()
        # don't update language model
        self.model.lang_model.zero_grad()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
        self.opt.step()
        return lang_loss.item(), stats

    def valid_batch(self, batch):
        """Return ``(summed loss, 0, 0, stats)`` computed without gradients."""
        with torch.no_grad():
            losses, stats, _ = self._forward(batch)

        loss = sum(losses)
        return loss.item(), 0, 0, stats

    def train_pass(self, trainset):
        """Run one training epoch and return ``(mean loss, seconds elapsed)``."""
        total_loss, stats, time_elapsed = _run_train_pass(self, trainset, 3)
        _print_three_stats('train', stats)
        return total_loss, time_elapsed

    def valid_pass(self, validset, validset_stats):
        """Evaluate; return ``(valid_loss, select_loss, partner_ctx_loss, extra)``."""
        valid_loss, select_loss, partner_ctx_loss, stats = _run_valid_pass(
            self, validset, validset_stats, 3)
        _print_three_stats('valid', stats)
        return valid_loss, select_loss, partner_ctx_loss, _stats_extra(stats)


class LatentClusteringLanguageEngine(EngineBase):
    """Trains a language model conditioned on frozen selection-model targets."""

    def __init__(self, model, args, verbose=False):
        super(LatentClusteringLanguageEngine, self).__init__(model, args, verbose)
        self.crit = nn.CrossEntropyLoss(reduction='sum')
        self.cluster_crit = nn.NLLLoss(reduction='sum')

        self.sel_model = utils.load_model(args.selection_model_file)
        self.sel_model.eval()

    def _make_sel_tgt_probs(self, inpts, lens, rev_idxs, hid_idxs, ctx):
        """Return the frozen selection model's softmax output after each prefix."""
        return _selection_target_probs(
            self.sel_model, inpts, lens, rev_idxs, hid_idxs, ctx)

    def _append_pad(self, inpts, tgts, sel_tgt_probs, lens, rev_idxs, hid_idxs):
        """Append one all-padding turn to every per-turn list, in place; return them."""
        return _append_pad_turn(self.model.word_dict, inpts, tgts, sel_tgt_probs,
                                lens, rev_idxs, hid_idxs)

    def _forward(self, model, batch, sep_sel=False, norm_lang=False):
        """Return the summed language loss of ``model`` on ``batch``.

        The loss is divided by the token count ``sum(lens)`` when
        ``norm_lang`` is true. ``sep_sel`` is accepted for interface
        compatibility and is not used.
        """
        ctx, inpts, lens, tgts, _, rev_idxs, hid_idxs, cnt = _unpack_batch(batch)
        sel_tgt_probs = self._make_sel_tgt_probs(inpts, lens, rev_idxs, hid_idxs, ctx)

        inpts, tgts, sel_tgt_probs, lens, rev_idxs, hid_idxs = self._append_pad(
            inpts, tgts, sel_tgt_probs, lens, rev_idxs, hid_idxs)

        outs = model(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt)

        lang_loss, n = 0, 0
        for out, tgt, ln in zip(outs, tgts, lens):
            lang_loss += self.crit(out, tgt)
            n += ln.sum()
        if norm_lang:
            lang_loss /= n

        return lang_loss

    def train_batch(self, batch):
        """Take one optimizer step on the normalized language loss; return it as a float."""
        lang_loss = self._forward(
            self.model, batch, sep_sel=self.args.sep_sel, norm_lang=True)

        loss = lang_loss

        self.opt.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
        self.opt.step()
        return lang_loss.item()

    def valid_batch(self, batch):
        """Return ``(unnormalized language loss, 0, 0)`` computed without gradients."""
        with torch.no_grad():
            lang_loss = self._forward(
                self.model, batch, sep_sel=self.args.sep_sel, norm_lang=False)

        return lang_loss.item(), 0, 0


class BaselineClusteringEngine(EngineBase):
    """Trains the baseline clustering model on the language loss it returns."""

    def __init__(self, model, args, verbose=False):
        super(BaselineClusteringEngine, self).__init__(model, args, verbose)
        self.model.train()

    def _forward(self, batch):
        """Run the model on one batch; return ``(losses, stats, lens)``."""
        ctx, inpts, lens, tgts, _, rev_idxs, hid_idxs, cnt = _unpack_batch(batch)
        losses, stats = self.model.forward(inpts, tgts, hid_idxs, ctx, cnt)
        return losses, stats, lens

    def train_batch(self, batch):
        """Take one optimizer step on the token-normalized language loss.

        Returns ``(normalized loss as float, stats)``.
        """
        losses, stats, lens = self._forward(batch)

        lang_loss, n = _summed_loss_and_count(losses, lens)
        lang_loss /= n

        loss = lang_loss
        self.opt.zero_grad()
        loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
        self.opt.step()
        return lang_loss.item(), stats

    def valid_batch(self, batch):
        """Return ``(summed loss, 0, 0, stats)`` computed without gradients."""
        with torch.no_grad():
            losses, stats, _ = self._forward(batch)

        loss = sum(losses)
        return loss.item(), 0, 0, stats

    def train_pass(self, trainset):
        """Run one training epoch and return ``(mean loss, seconds elapsed)``."""
        total_loss, stats, time_elapsed = _run_train_pass(self, trainset, 3)
        _print_three_stats('train', stats)
        return total_loss, time_elapsed

    def valid_pass(self, validset, validset_stats):
        """Evaluate; return ``(valid_loss, select_loss, partner_ctx_loss, extra)``."""
        valid_loss, select_loss, partner_ctx_loss, stats = _run_valid_pass(
            self, validset, validset_stats, 3)
        _print_three_stats('valid', stats)
        return valid_loss, select_loss, partner_ctx_loss, _stats_extra(stats)