├── __init__.py ├── tests ├── bleu │ ├── cand │ ├── ref │ └── bleu_evaluation └── data │ ├── ttest.triples.pkl │ ├── ttrain.dict.pkl │ ├── ttrain.triples.pkl │ ├── tvalid.triples.pkl │ ├── ttrain.txt │ └── tvalid.txt ├── README.md ├── .gitignore ├── numpy_compat.py ├── model.py ├── adam.py ├── SS_dataset.py ├── sample.py ├── chat.py ├── data_iterator.py ├── state.py ├── convert-text2dict.py ├── utils.py ├── search.py ├── evaluation.py ├── train.py └── dialog_encdec.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/bleu/cand: -------------------------------------------------------------------------------- 1 | i got a good fire 2 | -------------------------------------------------------------------------------- /tests/bleu/ref: -------------------------------------------------------------------------------- 1 | i got a good place 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # hed-dlg 2 | Hierarchical Encoder Decoder for Dialog Modelling 3 | -------------------------------------------------------------------------------- /tests/data/ttest.triples.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sordonia/hed-dlg/HEAD/tests/data/ttest.triples.pkl -------------------------------------------------------------------------------- /tests/data/ttrain.dict.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sordonia/hed-dlg/HEAD/tests/data/ttrain.dict.pkl -------------------------------------------------------------------------------- /tests/data/ttrain.triples.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sordonia/hed-dlg/HEAD/tests/data/ttrain.triples.pkl -------------------------------------------------------------------------------- /tests/data/tvalid.triples.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sordonia/hed-dlg/HEAD/tests/data/tvalid.triples.pkl -------------------------------------------------------------------------------- /tests/data/ttrain.txt: -------------------------------------------------------------------------------- 1 | how are you ? fine thanks ! and you ? 2 | what are you doing ? nothing much . are you serious ? 3 | -------------------------------------------------------------------------------- /tests/data/tvalid.txt: -------------------------------------------------------------------------------- 1 | how are you ? fine thanks ! and you ? 2 | what are you doing ? nothing much . are you serious ? 3 | -------------------------------------------------------------------------------- /tests/bleu/bleu_evaluation: -------------------------------------------------------------------------------- 1 | how are you ? fine thanks ! and you ? 2 | what are you doing ? nothing much . are you serious ? 3 | what are you doing ? nothing much . serious, fine ? 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | state.py 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | 47 | # Translations 48 | *.mo 49 | *.pot 50 | 51 | # Django stuff: 52 | *.log 53 | 54 | # Sphinx documentation 55 | docs/_build/ 56 | 57 | # PyBuilder 58 | target/ 59 | -------------------------------------------------------------------------------- /numpy_compat.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Compatibility with older numpy's providing argpartition replacement. 3 | 4 | ''' 5 | 6 | 7 | ''' 8 | Created on Sep 12, 2014 9 | 10 | @author: chorows 11 | ''' 12 | 13 | __all__ = ['argpartition'] 14 | 15 | import numpy 16 | import warnings 17 | 18 | if hasattr(numpy, 'argpartition'): 19 | argpartition = numpy.argpartition 20 | else: 21 | try: 22 | import bottleneck 23 | #warnings.warn('Your numpy is too old (You have %s, we need 1.7.1), but we have found argpartsort in bottleneck' % (numpy.__version__,)) 24 | def argpartition(a, kth, axis=-1): 25 | return bottleneck.argpartsort(a, kth, axis) 26 | except ImportError: 27 | warnings.warn('''Beam search will be slow! 28 | 29 | Your numpy is old (you have v. %s) and doesn't provide an argpartition function. 30 | Either upgrade numpy, or install bottleneck (https://pypi.python.org/pypi/Bottleneck). 31 | 32 | If you run this from within LISA lab you probably want to run: pip install bottleneck --user 33 | ''' % (numpy.__version__,)) 34 | def argpartition(a, kth, axis=-1, order=None): 35 | return numpy.argsort(a, axis=axis, order=order) 36 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy 3 | import theano 4 | logger = logging.getLogger(__name__) 5 | 6 | class Model(object): 7 | def __init__(self): 8 | self.floatX = theano.config.floatX 9 | # Parameters of the model 10 | self.params = [] 11 | 12 | def save(self, filename): 13 | """ 14 | Save the model to file `filename` 15 | """ 16 | vals = dict([(x.name, x.get_value()) for x in self.params]) 17 | numpy.savez(filename, **vals) 18 | 19 | def load(self, filename): 20 | """ 21 | Load the model. 22 | """ 23 | vals = numpy.load(filename) 24 | for p in self.params: 25 | if p.name in vals: 26 | logger.debug('Loading {} of {}'.format(p.name, p.get_value(borrow=True).shape)) 27 | if p.get_value().shape != vals[p.name].shape: 28 | raise Exception('Shape mismatch: {} != {} for {}'.format(p.get_value().shape, vals[p.name].shape, p.name)) 29 | p.set_value(vals[p.name]) 30 | else: 31 | logger.error('No parameter {} given: default initialization used'.format(p.name)) 32 | unknown = set(vals.keys()) - {p.name for p in self.params} 33 | if len(unknown): 34 | logger.error('Unknown parameters {} given'.format(unknown)) 35 | -------------------------------------------------------------------------------- /adam.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) 3 | 4 | Copyright (c) 2015 Alec Radford 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | import theano 26 | import theano.tensor as T 27 | 28 | def sharedX(value, name=None, borrow=False, dtype=None): 29 | if dtype is None: 30 | dtype = theano.config.floatX 31 | return theano.shared(theano._asarray(value, dtype=dtype), 32 | name=name, 33 | borrow=borrow) 34 | 35 | def Adam(grads, lr=0.0002, b1=0.1, b2=0.001, e=1e-8): 36 | updates = [] 37 | i = sharedX(0.) 38 | i_t = i + 1. 39 | fix1 = 1. - (1. - b1)**i_t 40 | fix2 = 1. - (1. - b2)**i_t 41 | lr_t = lr * (T.sqrt(fix2) / fix1) 42 | for p, g in grads.items(): 43 | m = sharedX(p.get_value() * 0.) 44 | v = sharedX(p.get_value() * 0.) 45 | m_t = (b1 * g) + ((1. - b1) * m) 46 | v_t = (b2 * T.sqr(g)) + ((1. - b2) * v) 47 | g_t = m_t / (T.sqrt(v_t) + e) 48 | p_t = p - (lr_t * g_t) 49 | updates.append((m, m_t)) 50 | updates.append((v, v_t)) 51 | updates.append((p, p_t)) 52 | updates.append((i, i_t)) 53 | return updates 54 | -------------------------------------------------------------------------------- /SS_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os, gc 3 | import cPickle 4 | import copy 5 | import logging 6 | 7 | import threading 8 | import Queue 9 | 10 | import collections 11 | 12 | logger = logging.getLogger(__name__) 13 | np.random.seed(1234) 14 | 15 | class SSFetcher(threading.Thread): 16 | def __init__(self, parent): 17 | threading.Thread.__init__(self) 18 | self.parent = parent 19 | self.indexes = np.arange(parent.data_len) 20 | 21 | def run(self): 22 | diter = self.parent 23 | # Shuffle with parents random generator 24 | self.parent.rng.shuffle(self.indexes) 25 | 26 | offset = 0 27 | # Take groups of 10000 triples and group by length 28 | while not diter.exit_flag: 29 | last_batch = False 30 | triples = [] 31 | 32 | while len(triples) < diter.batch_size: 33 | if offset == diter.data_len: 34 | if not diter.use_infinite_loop: 35 | last_batch = True 36 | break 37 | else: 38 | # Infinite loop here, we reshuffle the indexes 39 | # and reset the offset 40 | self.parent.rng.shuffle(self.indexes) 41 | offset = 0 42 | 43 | index = self.indexes[offset] 44 | s = diter.data[index] 45 | offset += 1 46 | 47 | # Append only if it is shorter than max_len 48 | if len(s) <= diter.max_len: 49 | triples.append(s) 50 | 51 | if len(triples): 52 | diter.queue.put(triples) 53 | 54 | if last_batch: 55 | diter.queue.put(None) 56 | return 57 | 58 | class SSIterator(object): 59 | def __init__(self, 60 | rng, 61 | batch_size, 62 | triple_file=None, 63 | dtype="int32", 64 | can_fit=False, 65 | queue_size=100, 66 | cache_size=100, 67 | shuffle=True, 68 | use_infinite_loop=True, 69 | max_len=1000): 70 | 71 | args = locals() 72 | args.pop("self") 73 | self.__dict__.update(args) 74 | self.rng = rng 75 | self.load_files() 76 | self.exit_flag = False 77 | 78 | def load_files(self): 79 | self.data = cPickle.load(open(self.triple_file, 'r')) 80 | self.data_len = len(self.data) 81 | logger.debug('Data len is %d' % self.data_len) 82 | 83 | def start(self): 84 | self.exit_flag = False 85 | self.queue = Queue.Queue(maxsize=self.queue_size) 86 | self.gather = SSFetcher(self) 87 | self.gather.daemon = True 88 | self.gather.start() 89 | 90 | def __del__(self): 91 | if hasattr(self, 'gather'): 92 | self.gather.exitFlag = True 93 | self.gather.join() 94 | 95 | def __iter__(self): 96 | return self 97 | 98 | def next(self): 99 | if self.exit_flag: 100 | return None 101 | 102 | batch = self.queue.get() 103 | if not batch: 104 | self.exit_flag = True 105 | return batch 106 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import cPickle 5 | import traceback 6 | import logging 7 | import time 8 | import sys 9 | 10 | import os 11 | import numpy 12 | import codecs 13 | import search 14 | import utils 15 | 16 | from dialog_encdec import DialogEncoderDecoder 17 | from numpy_compat import argpartition 18 | from state import prototype_state 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | class Timer(object): 23 | def __init__(self): 24 | self.total = 0 25 | 26 | def start(self): 27 | self.start_time = time.time() 28 | 29 | def finish(self): 30 | self.total += time.time() - self.start_time 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser("Sample (with beam-search) from the session model") 34 | 35 | parser.add_argument("--ignore-unk", 36 | action="store_false", 37 | help="Ignore unknown words") 38 | 39 | parser.add_argument("model_prefix", 40 | help="Path to the model prefix (without _model.npz or _state.pkl)") 41 | 42 | parser.add_argument("context", 43 | help="File of input contexts (pair of sentences, tab separated)") 44 | 45 | parser.add_argument("output", 46 | help="Output file") 47 | 48 | parser.add_argument("--beam_search", 49 | action="store_true", 50 | help="Use beam search instead of random search") 51 | 52 | parser.add_argument("--n-samples", 53 | default="1", type=int, 54 | help="Number of samples") 55 | 56 | parser.add_argument("--n-turns", 57 | default=1, type=int, 58 | help="Number of dialog turns to generate") 59 | 60 | parser.add_argument("--normalize", 61 | action="store_true", default=False, 62 | help="Normalize log-prob with the word count") 63 | 64 | parser.add_argument("--verbose", 65 | action="store_true", default=False, 66 | help="Be verbose") 67 | 68 | parser.add_argument("changes", nargs="?", default="", help="Changes to state") 69 | return parser.parse_args() 70 | 71 | def main(): 72 | args = parse_args() 73 | state = prototype_state() 74 | 75 | state_path = args.model_prefix + "_state.pkl" 76 | model_path = args.model_prefix + "_model.npz" 77 | 78 | with open(state_path) as src: 79 | state.update(cPickle.load(src)) 80 | 81 | logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s") 82 | 83 | model = DialogEncoderDecoder(state) 84 | 85 | sampler = search.RandomSampler(model) 86 | if args.beam_search: 87 | sampler = search.BeamSampler(model) 88 | 89 | if os.path.isfile(model_path): 90 | logger.debug("Loading previous model") 91 | model.load(model_path) 92 | else: 93 | raise Exception("Must specify a valid model path") 94 | 95 | contexts = [[]] 96 | lines = open(args.context, "r").readlines() 97 | if len(lines): 98 | contexts = [x.strip().split('\t') for x in lines] 99 | 100 | context_samples, context_costs = sampler.sample(contexts, 101 | n_samples=args.n_samples, 102 | n_turns=args.n_turns, 103 | ignore_unk=args.ignore_unk, 104 | verbose=args.verbose) 105 | 106 | # Write to output file 107 | output_handle = open(args.output, "w") 108 | for context_sample in context_samples: 109 | print >> output_handle, '\t'.join(context_sample) 110 | output_handle.close() 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /chat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | __docformat__ = 'restructedtext en' 3 | __authors__ = ("Julian Serban, Alessandro Sordoni") 4 | __contact__ = "Julian Serban " 5 | 6 | import argparse 7 | import cPickle 8 | import traceback 9 | import itertools 10 | import logging 11 | import time 12 | import sys 13 | import search 14 | 15 | import collections 16 | import string 17 | import os 18 | import numpy 19 | import codecs 20 | 21 | import nltk 22 | from random import randint 23 | 24 | from dialog_encdec import DialogEncoderDecoder 25 | from numpy_compat import argpartition 26 | from state import prototype_state 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | class Timer(object): 31 | def __init__(self): 32 | self.total = 0 33 | 34 | def start(self): 35 | self.start_time = time.time() 36 | 37 | def finish(self): 38 | self.total += time.time() - self.start_time 39 | 40 | def sample(model, seqs=[[]], n_samples=1, beam_search=None, ignore_unk=False): 41 | if beam_search: 42 | sentences = [] 43 | 44 | seq = model.words_to_indices(seqs[0]) 45 | gen_ids, gen_costs = beam_search.search(seq, n_samples, ignore_unk=ignore_unk) 46 | 47 | for i in range(len(gen_ids)): 48 | sentence = model.indices_to_words(gen_ids[i]) 49 | sentences.append(sentence) 50 | 51 | return sentences 52 | else: 53 | raise Exception("I don't know what to do") 54 | 55 | def parse_args(): 56 | parser = argparse.ArgumentParser("Sample (with beam-search) from the session model") 57 | 58 | parser.add_argument("--ignore-unk", 59 | default=True, action="store_true", 60 | help="Ignore unknown words") 61 | 62 | parser.add_argument("model_prefix", 63 | help="Path to the model prefix (without _model.npz or _state.pkl)") 64 | 65 | parser.add_argument("--normalize", 66 | action="store_true", default=False, 67 | help="Normalize log-prob with the word count") 68 | 69 | return parser.parse_args() 70 | 71 | def main(): 72 | args = parse_args() 73 | state = prototype_state() 74 | 75 | state_path = args.model_prefix + "_state.pkl" 76 | model_path = args.model_prefix + "_model.npz" 77 | 78 | with open(state_path) as src: 79 | state.update(cPickle.load(src)) 80 | 81 | logging.basicConfig(level=getattr(logging, state['level']), format="%(asctime)s: %(name)s: %(levelname)s: %(message)s") 82 | 83 | model = DialogEncoderDecoder(state) 84 | if os.path.isfile(model_path): 85 | logger.debug("Loading previous model") 86 | model.load(model_path) 87 | else: 88 | raise Exception("Must specify a valid model path") 89 | 90 | logger.info("This model uses " + model.decoder_bias_type + " bias type") 91 | 92 | beam_search = None 93 | sampler = None 94 | 95 | beam_search = search.BeamSearch(model) 96 | beam_search.compile() 97 | 98 | # Start chat loop 99 | utterances = collections.deque() 100 | 101 | while (True): 102 | var = raw_input("User - ") 103 | 104 | while len(utterances) > 2: 105 | utterances.popleft() 106 | 107 | current_utterance = [ model.start_sym_sentence ] + var.split() + [ model.end_sym_sentence ] 108 | utterances.append(current_utterance) 109 | 110 | # Sample a random reply. To spicy it up, we could pick the longest reply or the reply with the fewest placeholders... 111 | seqs = list(itertools.chain(*utterances)) 112 | 113 | sentences = sample(model, \ 114 | seqs=[seqs], ignore_unk=args.ignore_unk, \ 115 | beam_search=beam_search, n_samples=5) 116 | 117 | if len(sentences) == 0: 118 | raise ValueError("Generation error, no sentences were produced!") 119 | 120 | reply = " ".join(sentences[0]).encode('utf-8') 121 | print "AI - ", reply 122 | 123 | utterances.append(sentences[0]) 124 | 125 | if __name__ == "__main__": 126 | # Run with THEANO_FLAGS=mode=FAST_RUN,floatX=float32,allow_gc=True,scan.allow_gc=False,nvcc.flags=-use_fast_math python chat.py Model_Name 127 | main() 128 | 129 | 130 | -------------------------------------------------------------------------------- /data_iterator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano 3 | import theano.tensor as T 4 | import sys, getopt 5 | import logging 6 | 7 | from state import * 8 | from utils import * 9 | from SS_dataset import * 10 | 11 | import itertools 12 | import sys 13 | import pickle 14 | import random 15 | import datetime 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | def create_padded_batch(state, x): 20 | mx = state['seqlen'] 21 | n = state['bs'] 22 | 23 | X = numpy.zeros((mx, n), dtype='int32') 24 | Xmask = numpy.zeros((mx, n), dtype='float32') 25 | 26 | # Variables to store last utterance (used to compute mutual information metric) 27 | X_last_utterance = numpy.zeros((mx, n), dtype='int32') 28 | Xmask_last_utterance = numpy.zeros((mx, n), dtype='float32') 29 | 30 | # Fill X and Xmask 31 | # Keep track of number of predictions and maximum triple length 32 | num_preds = 0 33 | max_length = 0 34 | for idx in xrange(len(x[0])): 35 | # Insert sequence idx in a column of matrix X 36 | triple_length = len(x[0][idx]) 37 | 38 | # Fiddle-it if it is too long .. 39 | if mx < triple_length: 40 | continue 41 | 42 | X[:triple_length, idx] = x[0][idx][:triple_length] 43 | 44 | max_length = max(max_length, triple_length) 45 | 46 | # Set the number of predictions == sum(Xmask), for cost purposes 47 | num_preds += triple_length 48 | 49 | # Mark the end of phrase 50 | if len(x[0][idx]) < mx: 51 | X[triple_length:, idx] = state['eos_sym'] 52 | 53 | # Initialize Xmask column with ones in all positions that 54 | # were just set in X 55 | Xmask[:triple_length, idx] = 1. 56 | 57 | # Find start of last utterance 58 | eos_indices = numpy.where(X[:, idx] == state['eos_sym'])[0] 59 | assert (len(eos_indices) > 2) 60 | start_of_last_utterance = eos_indices[1]+1 61 | X_last_utterance[0:(triple_length-start_of_last_utterance), idx] = X[start_of_last_utterance:triple_length, idx] 62 | Xmask_last_utterance[0:(triple_length-start_of_last_utterance), idx] = Xmask[start_of_last_utterance:triple_length, idx] 63 | 64 | 65 | 66 | 67 | 68 | assert num_preds == numpy.sum(Xmask) 69 | return {'x': X, 'x_mask': Xmask, 'x_last_utterance': X_last_utterance, 'x_mask_last_utterance': Xmask_last_utterance, 'num_preds': num_preds, 'max_length': max_length} 70 | 71 | def get_batch_iterator(rng, state): 72 | class Iterator(SSIterator): 73 | def __init__(self, *args, **kwargs): 74 | SSIterator.__init__(self, rng, *args, **kwargs) 75 | self.batch_iter = None 76 | 77 | def get_homogenous_batch_iter(self, batch_size = -1): 78 | while True: 79 | k_batches = state['sort_k_batches'] 80 | batch_size = self.batch_size if (batch_size == -1) else batch_size 81 | 82 | data = [] 83 | for k in range(k_batches): 84 | batch = SSIterator.next(self) 85 | if batch: 86 | data.append(batch) 87 | 88 | if not len(data): 89 | return 90 | 91 | x = numpy.asarray(list(itertools.chain(*data))) 92 | lens = numpy.asarray([map(len, x)]) 93 | order = numpy.argsort(lens.max(axis=0)) 94 | 95 | for k in range(len(data)): 96 | indices = order[k * batch_size:(k + 1) * batch_size] 97 | batch = create_padded_batch(state, [x[indices]]) 98 | if batch: 99 | yield batch 100 | 101 | def start(self): 102 | SSIterator.start(self) 103 | self.batch_iter = None 104 | 105 | def next(self, batch_size = -1): 106 | """ 107 | We can specify a batch size, 108 | independent of the object initialization. 109 | """ 110 | if not self.batch_iter: 111 | self.batch_iter = self.get_homogenous_batch_iter(batch_size) 112 | try: 113 | batch = next(self.batch_iter) 114 | except StopIteration: 115 | return None 116 | return batch 117 | 118 | train_data = Iterator( 119 | batch_size=int(state['bs']), 120 | triple_file=state['train_triples'], 121 | queue_size=100, 122 | use_infinite_loop=True, 123 | max_len=state['seqlen']) 124 | 125 | valid_data = Iterator( 126 | batch_size=int(state['bs']), 127 | triple_file=state['valid_triples'], 128 | use_infinite_loop=False, 129 | queue_size=100, 130 | max_len=state['seqlen']) 131 | 132 | test_data = Iterator( 133 | batch_size=int(state['bs']), 134 | triple_file=state['test_triples'], 135 | use_infinite_loop=False, 136 | queue_size=100, 137 | max_len=state['seqlen']) 138 | 139 | return train_data, valid_data, test_data 140 | -------------------------------------------------------------------------------- /state.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | def prototype_state(): 4 | state = {} 5 | 6 | # Random seed 7 | state['seed'] = 1234 8 | 9 | # Logging level 10 | state['level'] = 'DEBUG' 11 | 12 | state['oov'] = '' 13 | state['len_sample'] = 40 14 | 15 | # These are end-of-sequence marks 16 | state['start_sym_sentence'] = '' 17 | state['end_sym_sentence'] = '' 18 | state['end_sym_triple'] = '' 19 | 20 | state['unk_sym'] = 0 21 | state['eot_sym'] = 3 22 | state['eos_sym'] = 2 23 | state['sos_sym'] = 1 24 | 25 | state['maxout_out'] = False 26 | state['deep_out'] = True 27 | 28 | # ----- ACTIV ---- 29 | state['sent_rec_activation'] = 'lambda x: T.tanh(x)' 30 | state['triple_rec_activation'] = 'lambda x: T.tanh(x)' 31 | 32 | state['decoder_bias_type'] = 'all' # first, or selective 33 | 34 | state['sent_step_type'] = 'gated' 35 | state['triple_step_type'] = 'gated' 36 | 37 | # ----- SIZES ---- 38 | # Dimensionality of hidden layers 39 | state['qdim'] = 512 40 | # Dimensionality of triple hidden layer 41 | state['sdim'] = 1000 42 | # Dimensionality of low-rank approximation 43 | state['rankdim'] = 256 44 | 45 | # Threshold to clip the gradient 46 | state['cutoff'] = 1. 47 | state['lr'] = 0.0001 48 | 49 | # Early stopping configuration 50 | state['patience'] = 5 51 | state['cost_threshold'] = 1.003 52 | 53 | # ----- TRAINING METHOD ----- 54 | # Choose optimization algorithm 55 | state['updater'] = 'adam' 56 | # Maximum sequence length / trim batches 57 | state['seqlen'] = 80 58 | # Batch size 59 | state['bs'] = 80 60 | # Sort by length groups of 61 | state['sort_k_batches'] = 20 62 | 63 | # Maximum number of iterations 64 | state['max_iters'] = 10 65 | # Modify this in the prototype 66 | state['save_dir'] = './' 67 | 68 | # ----- TRAINING PROCESS ----- 69 | # Frequency of training error reports (in number of batches) 70 | state['train_freq'] = 10 71 | # Validation frequency 72 | state['valid_freq'] = 5000 73 | # Number of batches to process 74 | state['loop_iters'] = 3000000 75 | # Maximum number of minutes to run 76 | state['time_stop'] = 24*60*31 77 | # Error level to stop at 78 | state['minerr'] = -1 79 | 80 | # ----- EVALUATION PROCESS ----- 81 | state['track_extrema_validation_samples'] = True # If set to true will print the extrema (lowest and highest log-likelihood scoring) validation samples 82 | state['track_extrema_samples_count'] = 100 # Set of extrema samples to track 83 | state['print_extrema_samples_count'] = 5 # Number of extrema samples to print (chosen at random from the extrema sets) 84 | 85 | state['compute_mutual_information'] = True # If true, the empirical mutural information will be calculcated on the validation set 86 | 87 | 88 | return state 89 | 90 | def prototype_test(): 91 | state = prototype_state() 92 | 93 | # Fill your paths here! 94 | state['train_triples'] = "./tests/data/ttrain.triples.pkl" 95 | state['test_triples'] = "./tests/data/ttest.triples.pkl" 96 | state['valid_triples'] = "./tests/data/tvalid.triples.pkl" 97 | state['dictionary'] = "./tests/data/ttrain.dict.pkl" 98 | state['save_dir'] = "./tests/models/" 99 | 100 | # Handle bleu evaluation 101 | state['bleu_evaluation'] = "./tests/bleu/bleu_evaluation" 102 | state['bleu_context_length'] = 2 103 | 104 | 105 | 106 | # Validation frequency 107 | state['valid_freq'] = 50 108 | 109 | # Varia 110 | state['prefix'] = "testmodel_" 111 | state['updater'] = 'adam' 112 | 113 | state['maxout_out'] = False 114 | state['deep_out'] = True 115 | 116 | # If out of memory, modify this! 117 | state['bs'] = 80 118 | state['use_nce'] = True 119 | state['decoder_bias_type'] = 'all' #'selective' 120 | 121 | state['qdim'] = 50 122 | # Dimensionality of triple hidden layer 123 | state['sdim'] = 100 124 | # Dimensionality of low-rank approximation 125 | state['rankdim'] = 25 126 | return state 127 | 128 | def prototype_moviedic(): 129 | state = prototype_state() 130 | 131 | # Fill your paths here! 132 | state['train_triples'] = "Data/Training.triples.pkl" 133 | state['test_triples'] = "Data/Test.triples.pkl" 134 | state['valid_triples'] = "Data/Validation.triples.pkl" 135 | state['dictionary'] = "Data/Training.dict.pkl" 136 | state['save_dir'] = "Output" 137 | 138 | # Handle bleu evaluation 139 | state['bleu_evaluation'] = "Data/Validation_Shuffled_Dataset.txt" 140 | state['bleu_context_length'] = 2 141 | 142 | # Validation frequency 143 | state['valid_freq'] = 5000 144 | 145 | # Varia 146 | state['prefix'] = "MovieScriptModel_" 147 | state['updater'] = 'adam' 148 | 149 | state['maxout_out'] = True 150 | state['deep_out'] = True 151 | 152 | # If out of memory, modify this! 153 | state['bs'] = 80 154 | state['use_nce'] = False 155 | state['decoder_bias_type'] = 'all' # Choose between 'first', 'all' and 'selective' 156 | 157 | # Increase sequence length to fit movie dialogues better 158 | state['seqlen'] = 160 159 | 160 | state['qdim'] = 600 161 | # Dimensionality of triple hidden layer 162 | state['sdim'] = 300 163 | # Dimensionality of low-rank approximation 164 | state['rankdim'] = 300 165 | return state 166 | 167 | 168 | -------------------------------------------------------------------------------- /convert-text2dict.py: -------------------------------------------------------------------------------- 1 | """ 2 | Takes as input a triple file and creates a processed version of it. 3 | If given an external dictionary, the input triple file will be converted 4 | using that input dictionary. 5 | 6 | @author Alessandro Sordoni 7 | """ 8 | 9 | import collections 10 | import numpy 11 | import operator 12 | import os 13 | import sys 14 | import logging 15 | import cPickle 16 | import itertools 17 | from collections import Counter 18 | 19 | logging.basicConfig(level=logging.INFO) 20 | logger = logging.getLogger('text2dict') 21 | 22 | def safe_pickle(obj, filename): 23 | if os.path.isfile(filename): 24 | logger.info("Overwriting %s." % filename) 25 | else: 26 | logger.info("Saving to %s." % filename) 27 | 28 | with open(filename, 'wb') as f: 29 | cPickle.dump(obj, f, protocol=cPickle.HIGHEST_PROTOCOL) 30 | 31 | import argparse 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("input", type=str, help="Tab separated triple file (assumed shuffled)") 34 | parser.add_argument("--cutoff", type=int, default=-1, help="Vocabulary cutoff (optional)") 35 | parser.add_argument("--dict", type=str, default="", help="External dictionary (pkl file)") 36 | parser.add_argument("--use_all_triples", action='store_true', help="If false, all training examples with more than or less than three utterances will be removed. If true, training examples with less than three utterances will have empty utterances appended at the beginning, and training examples with more than three utterances will have their first utterances discarded.") 37 | 38 | parser.add_argument("output", type=str, help="Prefix of the pickle binarized triple corpus") 39 | args = parser.parse_args() 40 | 41 | if not os.path.isfile(args.input): 42 | raise Exception("Input file not found!") 43 | 44 | unk = "" 45 | 46 | ############################### 47 | # Part I: Create the dictionary 48 | ############################### 49 | if args.dict != "": 50 | # Load external dictionary 51 | assert os.path.isfile(args.dict) 52 | vocab = dict([(x[0], x[1]) for x in cPickle.load(open(args.dict, "r"))]) 53 | 54 | # Check consistency 55 | assert '' in vocab 56 | assert '' in vocab 57 | assert '' in vocab 58 | else: 59 | word_counter = Counter() 60 | 61 | 62 | for line in open(args.input, 'r'): 63 | s = [x for x in line.strip().split()] 64 | word_counter.update(s) 65 | 66 | total_freq = sum(word_counter.values()) 67 | logger.info("Total word frequency in dictionary %d " % total_freq) 68 | 69 | if args.cutoff != -1: 70 | logger.info("Cutoff %d" % args.cutoff) 71 | vocab_count = word_counter.most_common(args.cutoff) 72 | else: 73 | vocab_count = word_counter.most_common() 74 | 75 | 76 | # Add special tokens to the vocabulary 77 | vocab = {'': 0, '': 1, '': 2} 78 | for i, (word, count) in enumerate(vocab_count): 79 | vocab[word] = i + 3 80 | 81 | 82 | logger.info("Vocab size %d" % len(vocab)) 83 | 84 | ################################# 85 | # Part II: Binarize the triples 86 | ################################# 87 | 88 | # Everything is loaded into memory for the moment 89 | binarized_corpus = [] 90 | # Some statistics 91 | mean_sl = 0. 92 | unknowns = 0. 93 | num_terms = 0. 94 | freqs = collections.defaultdict(lambda: 0) 95 | 96 | # counts the number of triples each unique word exists in; also known as document frequency 97 | df = collections.defaultdict(lambda: 0) 98 | 99 | for line, triple in enumerate(open(args.input, 'r')): 100 | triple_lst = [] 101 | 102 | utterances = triple.split('\t') 103 | for i, utterance in enumerate(utterances): 104 | 105 | utterance_lst = [] 106 | for word in utterance.strip().split(): 107 | word_id = vocab.get(word, 0) 108 | unknowns += 1 * (word_id == 0) 109 | utterance_lst.append(word_id) 110 | freqs[word_id] += 1 111 | 112 | num_terms += len(utterance_lst) 113 | 114 | # Here, we filter out unknown triple text and empty triples 115 | # i.e. or 0 116 | if utterance_lst != [0] and len(utterance_lst): 117 | triple_lst.append([1] + utterance_lst + [2]) 118 | freqs[1] += 1 119 | freqs[2] += 1 120 | df[1] += 1 121 | df[2] += 1 122 | 123 | if args.use_all_triples == True: 124 | if len(triple_lst) > 3: 125 | triple_lst = triple_lst[len(triple_lst)-2:len(triple_lst)] 126 | else: 127 | while len(triple_lst) < 3: 128 | triple_lst.insert(0, [1] + [2]) 129 | 130 | if len(triple_lst) == 3: 131 | # Flatten out binarized triple 132 | # [[a, b, c], [c, d, e]] -> [a, b, c, d, e] 133 | binarized_triple = list(itertools.chain(*triple_lst)) 134 | binarized_corpus.append(binarized_triple) 135 | 136 | unique_word_indices = [] 137 | for i in range(len(triple_lst)): 138 | for word_id in triple_lst[i]: 139 | unique_word_indices.append(word_id) 140 | 141 | unique_word_indices = set(unique_word_indices) 142 | for word_id in unique_word_indices: 143 | df[word_id] += 1 144 | 145 | safe_pickle(binarized_corpus, args.output + ".triples.pkl") 146 | 147 | if args.dict == "": 148 | safe_pickle([(word, word_id, freqs[word_id], df[word_id]) for word, word_id in vocab.items()], args.output + ".dict.pkl") 149 | 150 | logger.info("Number of unknowns %d" % unknowns) 151 | logger.info("Number of terms %d" % num_terms) 152 | logger.info("Mean triple length %f" % float(sum(map(len, binarized_corpus))/len(binarized_corpus))) 153 | logger.info("Writing training %d triples (%d left out)" % (len(binarized_corpus), line + 1 - len(binarized_corpus))) 154 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import adam 3 | import theano 4 | import theano.tensor as T 5 | from collections import OrderedDict 6 | 7 | PRINT_VARS = True 8 | 9 | def DPrint(name, var): 10 | if PRINT_VARS is False: 11 | return var 12 | 13 | return theano.printing.Print(name)(var) 14 | 15 | def sharedX(value, name=None, borrow=False, dtype=None): 16 | if dtype is None: 17 | dtype = theano.config.floatX 18 | return theano.shared(theano._asarray(value, dtype=dtype), 19 | name=name, 20 | borrow=borrow) 21 | 22 | def Adam(grads, lr=0.0002, b1=0.1, b2=0.001, e=1e-8): 23 | return adam.Adam(grads, lr, b1, b2, e) 24 | 25 | def Adagrad(grads, lr): 26 | updates = OrderedDict() 27 | for param in grads.keys(): 28 | # sum_square_grad := \sum g^2 29 | sum_square_grad = sharedX(param.get_value() * 0.) 30 | if param.name is not None: 31 | sum_square_grad.name = 'sum_square_grad_' + param.name 32 | 33 | # Accumulate gradient 34 | new_sum_squared_grad = sum_square_grad + T.sqr(grads[param]) 35 | 36 | # Compute update 37 | delta_x_t = (- lr / T.sqrt(numpy.float32(1e-5) + new_sum_squared_grad)) * grads[param] 38 | 39 | # Apply update 40 | updates[sum_square_grad] = new_sum_squared_grad 41 | updates[param] = param + delta_x_t 42 | return updates 43 | 44 | def Adadelta(grads, decay=0.95, epsilon=1e-6): 45 | updates = OrderedDict() 46 | for param in grads.keys(): 47 | # mean_squared_grad := E[g^2]_{t-1} 48 | mean_square_grad = sharedX(param.get_value() * 0.) 49 | # mean_square_dx := E[(\Delta x)^2]_{t-1} 50 | mean_square_dx = sharedX(param.get_value() * 0.) 51 | 52 | if param.name is not None: 53 | mean_square_grad.name = 'mean_square_grad_' + param.name 54 | mean_square_dx.name = 'mean_square_dx_' + param.name 55 | 56 | # Accumulate gradient 57 | new_mean_squared_grad = ( 58 | decay * mean_square_grad + 59 | (1 - decay) * T.sqr(grads[param]) 60 | ) 61 | 62 | # Compute update 63 | rms_dx_tm1 = T.sqrt(mean_square_dx + epsilon) 64 | rms_grad_t = T.sqrt(new_mean_squared_grad + epsilon) 65 | delta_x_t = - rms_dx_tm1 / rms_grad_t * grads[param] 66 | 67 | # Accumulate updates 68 | new_mean_square_dx = ( 69 | decay * mean_square_dx + 70 | (1 - decay) * T.sqr(delta_x_t) 71 | ) 72 | 73 | # Apply update 74 | updates[mean_square_grad] = new_mean_squared_grad 75 | updates[mean_square_dx] = new_mean_square_dx 76 | updates[param] = param + delta_x_t 77 | 78 | return updates 79 | 80 | def RMSProp(grads, lr, decay=0.95, eta=0.9, epsilon=1e-6): 81 | """ 82 | RMSProp gradient method 83 | """ 84 | updates = OrderedDict() 85 | for param in grads.keys(): 86 | # mean_squared_grad := E[g^2]_{t-1} 87 | mean_square_grad = sharedX(param.get_value() * 0.) 88 | mean_grad = sharedX(param.get_value() * 0.) 89 | delta_grad = sharedX(param.get_value() * 0.) 90 | 91 | if param.name is None: 92 | raise ValueError("Model parameters must be named.") 93 | 94 | mean_square_grad.name = 'mean_square_grad_' + param.name 95 | 96 | # Accumulate gradient 97 | 98 | new_mean_grad = (decay * mean_grad + (1 - decay) * grads[param]) 99 | new_mean_squared_grad = (decay * mean_square_grad + (1 - decay) * T.sqr(grads[param])) 100 | 101 | # Compute update 102 | scaled_grad = grads[param] / T.sqrt(new_mean_squared_grad - new_mean_grad ** 2 + epsilon) 103 | new_delta_grad = eta * delta_grad - lr * scaled_grad 104 | 105 | # Apply update 106 | updates[delta_grad] = new_delta_grad 107 | updates[mean_grad] = new_mean_grad 108 | updates[mean_square_grad] = new_mean_squared_grad 109 | updates[param] = param + new_delta_grad 110 | 111 | return updates 112 | 113 | class Maxout(object): 114 | def __init__(self, maxout_part): 115 | self.maxout_part = maxout_part 116 | 117 | def __call__(self, x): 118 | shape = x.shape 119 | if x.ndim == 2: 120 | shape1 = T.cast(shape[1] / self.maxout_part, 'int64') 121 | shape2 = T.cast(self.maxout_part, 'int64') 122 | x = x.reshape([shape[0], shape1, shape2]) 123 | x = x.max(2) 124 | else: 125 | shape1 = T.cast(shape[2] / self.maxout_part, 'int64') 126 | shape2 = T.cast(self.maxout_part, 'int64') 127 | x = x.reshape([shape[0], shape[1], shape1, shape2]) 128 | x = x.max(3) 129 | return x 130 | 131 | def UniformInit(rng, sizeX, sizeY, lb=-0.01, ub=0.01): 132 | """ Uniform Init """ 133 | return rng.uniform(size=(sizeX, sizeY), low=lb, high=ub).astype(theano.config.floatX) 134 | 135 | def OrthogonalInit(rng, sizeX, sizeY, sparsity=-1, scale=1): 136 | """ 137 | Orthogonal Initialization 138 | """ 139 | 140 | sizeX = int(sizeX) 141 | sizeY = int(sizeY) 142 | 143 | assert sizeX == sizeY, 'for orthogonal init, sizeX == sizeY' 144 | 145 | if sparsity < 0: 146 | sparsity = sizeY 147 | else: 148 | sparsity = numpy.minimum(sizeY, sparsity) 149 | 150 | values = numpy.zeros((sizeX, sizeY), dtype=theano.config.floatX) 151 | for dx in xrange(sizeX): 152 | perm = rng.permutation(sizeY) 153 | new_vals = rng.normal(loc=0, scale=scale, size=(sparsity,)) 154 | values[dx, perm[:sparsity]] = new_vals 155 | 156 | u,s,v = numpy.linalg.svd(values) 157 | values = u * scale 158 | return values.astype(theano.config.floatX) 159 | 160 | def GrabProbs(classProbs, target, gRange=None): 161 | if classProbs.ndim > 2: 162 | classProbs = classProbs.reshape((classProbs.shape[0] * classProbs.shape[1], classProbs.shape[2])) 163 | else: 164 | classProbs = classProbs 165 | 166 | if target.ndim > 1: 167 | tflat = target.flatten() 168 | else: 169 | tflat = target 170 | return T.diag(classProbs.T[tflat]) 171 | 172 | def NormalInit(rng, sizeX, sizeY, scale=0.01, sparsity=-1): 173 | """ 174 | Normal Initialization 175 | """ 176 | 177 | sizeX = int(sizeX) 178 | sizeY = int(sizeY) 179 | 180 | if sparsity < 0: 181 | sparsity = sizeY 182 | 183 | sparsity = numpy.minimum(sizeY, sparsity) 184 | values = numpy.zeros((sizeX, sizeY), dtype=theano.config.floatX) 185 | for dx in xrange(sizeX): 186 | perm = rng.permutation(sizeY) 187 | new_vals = rng.normal(loc=0, scale=scale, size=(sparsity,)) 188 | values[dx, perm[:sparsity]] = new_vals 189 | 190 | return values.astype(theano.config.floatX) 191 | 192 | def ConvertTimedelta(seconds_diff): 193 | hours = seconds_diff // 3600 194 | minutes = (seconds_diff % 3600) // 60 195 | seconds = (seconds_diff % 60) 196 | return hours, minutes, seconds 197 | 198 | def SoftMax(x): 199 | x = T.exp(x - T.max(x, axis=x.ndim-1, keepdims=True)) 200 | return x / T.sum(x, axis=x.ndim-1, keepdims=True) 201 | -------------------------------------------------------------------------------- /search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import cPickle 5 | import traceback 6 | import logging 7 | import time 8 | import sys 9 | 10 | import os 11 | import numpy 12 | import codecs 13 | 14 | from dialog_encdec import DialogEncoderDecoder 15 | from numpy_compat import argpartition 16 | from state import prototype_state 17 | logger = logging.getLogger(__name__) 18 | 19 | def sample_wrapper(sample_logic): 20 | def sample_apply(*args, **kwargs): 21 | sampler = args[0] 22 | contexts = args[1] 23 | 24 | verbose = kwargs.get('verbose', False) 25 | 26 | if verbose: 27 | logger.info("Starting {} : {} start sequences in total".format(sampler.name, len(contexts))) 28 | 29 | context_samples = [] 30 | context_costs = [] 31 | 32 | # Start loop for each sentence 33 | for context_id, context_sentences in enumerate(contexts): 34 | if verbose: 35 | logger.info("Searching for {}".format(context_sentences)) 36 | 37 | # Convert contextes into list of ids 38 | joined_context = [] 39 | if len(context_sentences) == 0: 40 | joined_context = [sampler.model.eos_sym] 41 | else: 42 | for sentence in context_sentences: 43 | sentence_ids = sampler.model.words_to_indices(sentence.split()) 44 | # Add sos and eos tokens 45 | joined_context += [sampler.model.sos_sym] + sentence_ids + [sampler.model.eos_sym] 46 | 47 | samples, costs = sample_logic(sampler, joined_context, **kwargs) 48 | 49 | # Convert back indices to list of words 50 | converted_samples = map(lambda sample : sampler.model.indices_to_words(sample), samples) 51 | # Join the list of words 52 | converted_samples = map(' '.join, converted_samples) 53 | 54 | if verbose: 55 | for i in range(len(converted_samples)): 56 | print "{}: {}".format(costs[i], converted_samples[i].encode('utf-8')) 57 | 58 | context_samples.append(converted_samples) 59 | context_costs.append(costs) 60 | 61 | return context_samples, context_costs 62 | return sample_apply 63 | 64 | class Sampler(object): 65 | """ 66 | An abstract sampler class 67 | """ 68 | def __init__(self, model): 69 | # Compile beam search 70 | self.name = 'Sampler' 71 | self.model = model 72 | self.compiled = False 73 | 74 | def compile(self): 75 | self.next_probs_predictor = self.model.build_next_probs_function() 76 | self.compute_encoding = self.model.build_encoder_function() 77 | self.compiled = True 78 | 79 | def select_next_words(self, next_probs, step_num, how_many): 80 | pass 81 | 82 | def count_n_turns(self, sentence): 83 | return len([w for w in sentence \ 84 | if w == self.model.eos_sym]) 85 | 86 | @sample_wrapper 87 | def sample(self, *args, **kwargs): 88 | context = args[0] 89 | 90 | n_samples = kwargs.get('n_samples', 1) 91 | ignore_unk = kwargs.get('ignore_unk', True) 92 | min_length = kwargs.get('min_length', 1) 93 | max_length = kwargs.get('max_length', 100) 94 | beam_diversity = kwargs.get('beam_diversity', 1) 95 | normalize_by_length = kwargs.get('normalize_by_length', True) 96 | verbose = kwargs.get('verbose', False) 97 | n_turns = kwargs.get('n_turns', 1) 98 | 99 | if not self.compiled: 100 | self.compile() 101 | 102 | # Convert to matrix, each column is a context 103 | # [[1,1,1],[4,4,4],[2,2,2]] 104 | context = numpy.repeat(numpy.array(context, dtype='int32')[:,None], 105 | n_samples, axis=1) 106 | if context[-1, 0] != self.model.eos_sym: 107 | raise Exception('Last token of context, when present,' 108 | 'should be the end of sentence: %d' % self.model.eos_sym) 109 | 110 | prev_hd = numpy.zeros((n_samples, self.model.qdim), dtype='float32') 111 | prev_hs = numpy.zeros((n_samples, self.model.sdim), dtype='float32') 112 | 113 | fin_gen = [] 114 | fin_costs = [] 115 | 116 | gen = [[] for i in range(n_samples)] 117 | costs = [0. for i in range(n_samples)] 118 | beam_empty = False 119 | 120 | for k in range(max_length): 121 | if len(fin_gen) >= n_samples or beam_empty: 122 | break 123 | 124 | if verbose: 125 | logger.info("{} : sampling step {}, beams alive {}".format(self.name, k, len(gen))) 126 | 127 | # Here we aggregate the context and recompute the hidden state 128 | # at both session level and query level. 129 | # Stack only when we sampled something 130 | if k > 0: 131 | context = numpy.vstack([context, \ 132 | numpy.array(map(lambda g: g[-1], gen))]).astype('int32') 133 | prev_words = context[-1, :] 134 | 135 | # Recompute hs only for those particular sentences 136 | # that met the end-of-sentence token 137 | indx_update_hs = [num for num, prev_word in enumerate(prev_words) 138 | if prev_word == self.model.eos_sym] 139 | if len(indx_update_hs): 140 | encoder_states = self.compute_encoding(context[:, indx_update_hs]) 141 | prev_hs[indx_update_hs] = encoder_states[-1][-1] 142 | 143 | # ... done 144 | next_probs, new_hd = self.next_probs_predictor(prev_hs, prev_words, prev_hd) 145 | assert next_probs.shape[1] == self.model.idim 146 | 147 | # Adjust log probs according to search restrictions 148 | if ignore_unk: 149 | next_probs[:, self.model.unk_sym] = 0 150 | if k <= min_length: 151 | next_probs[:, self.model.eos_sym] = 0 152 | 153 | # Update costs 154 | next_costs = numpy.array(costs)[:, None] - numpy.log(next_probs) 155 | 156 | # Select next words here 157 | (beam_indx, word_indx), costs = self.select_next_words(next_costs, next_probs, k, n_samples) 158 | 159 | # Update the stacks 160 | new_gen = [] 161 | new_costs = [] 162 | new_sources = [] 163 | 164 | for num, (beam_ind, word_ind, cost) in enumerate(zip(beam_indx, word_indx, costs)): 165 | if len(new_gen) > n_samples: 166 | break 167 | 168 | hypothesis = gen[beam_ind] + [word_ind] 169 | 170 | # End of sentence has been detected 171 | n_turns_hypothesis = self.count_n_turns(hypothesis) 172 | if n_turns_hypothesis == n_turns: 173 | if verbose: 174 | logger.debug("adding sentence {} from beam {}".format(hypothesis, beam_ind)) 175 | 176 | # We finished sampling 177 | fin_gen.append(hypothesis) 178 | fin_costs.append(cost) 179 | else: 180 | # Hypothesis recombination 181 | # TODO: pick the one with lowest cost 182 | has_similar = False 183 | if self.hyp_rec > 0: 184 | has_similar = len([g for g in new_gen if \ 185 | g[-self.hyp_rec:] == hypothesis[-self.hyp_rec:]]) != 0 186 | 187 | if not has_similar: 188 | new_sources.append(beam_ind) 189 | new_gen.append(hypothesis) 190 | new_costs.append(cost) 191 | 192 | if verbose: 193 | for gen in new_gen: 194 | logger.debug("partial -> {}".format(' '.join(self.model.indices_to_words(gen)))) 195 | 196 | prev_hd = new_hd[new_sources] 197 | prev_hs = prev_hs[new_sources] 198 | context = context[:, new_sources] 199 | gen = new_gen 200 | costs = new_costs 201 | beam_empty = len(gen) == 0 202 | 203 | # If we have not sampled anything 204 | # then force include stuff 205 | if len(fin_gen) == 0: 206 | fin_gen = gen 207 | fin_costs = costs 208 | 209 | # Normalize costs 210 | if normalize_by_length: 211 | fin_costs = [(fin_costs[num]/len(fin_gen[num])) \ 212 | for num in range(len(fin_gen))] 213 | 214 | fin_gen = numpy.array(fin_gen)[numpy.argsort(fin_costs)] 215 | fin_costs = numpy.array(sorted(fin_costs)) 216 | return fin_gen[:n_samples], fin_costs[:n_samples] 217 | 218 | class RandomSampler(Sampler): 219 | def __init__(self, model): 220 | Sampler.__init__(self, model) 221 | self.name = 'RandomSampler' 222 | self.hyp_rec = 0 223 | 224 | def select_next_words(self, next_costs, next_probs, step_num, how_many): 225 | # Choice is complaining 226 | next_probs = next_probs.astype("float64") 227 | word_indx = numpy.array([self.model.rng.choice(self.model.idim, p = x/numpy.sum(x)) 228 | for x in next_probs], dtype='int32') 229 | beam_indx = range(next_probs.shape[0]) 230 | 231 | args = numpy.ravel_multi_index(numpy.array([beam_indx, word_indx]), next_costs.shape) 232 | return (beam_indx, word_indx), next_costs.flatten()[args] 233 | 234 | class BeamSampler(Sampler): 235 | def __init__(self, model): 236 | Sampler.__init__(self, model) 237 | self.name = 'BeamSampler' 238 | self.hyp_rec = 3 239 | 240 | def select_next_words(self, next_costs, next_probs, step_num, how_many): 241 | # Pick only on the first line (for the beginning of sampling) 242 | # This will avoid duplicate token. 243 | if step_num == 0: 244 | flat_next_costs = next_costs[:1, :].flatten() 245 | else: 246 | # Set the next cost to infinite for finished sentences (they will be replaced) 247 | # by other sentences in the beam 248 | flat_next_costs = next_costs.flatten() 249 | 250 | voc_size = next_costs.shape[1] 251 | 252 | args = numpy.argpartition(flat_next_costs, how_many)[:how_many] 253 | args = args[numpy.argsort(flat_next_costs[args])] 254 | 255 | return numpy.unravel_index(args, next_costs.shape), flat_next_costs[args] 256 | 257 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Computes BLEU@n / Jaccard. 3 | """ 4 | __docformat__ = 'restructedtext en' 5 | __authors__ = ("Alessandro Sordoni") 6 | __contact__ = "Alessandro Sordoni " 7 | 8 | import sys 9 | import math 10 | import copy 11 | import re 12 | import operator 13 | import collections 14 | from collections import Counter 15 | 16 | import numpy 17 | 18 | def get_ref_length(ref_lens, candidate_len, method='closest'): 19 | if method == 'closest': 20 | len_diff = [(x, numpy.abs(x - candidate_len)) for x in ref_lens] 21 | min_len = sorted(len_diff, key=operator.itemgetter(1))[0][0] 22 | elif method == 'shortest': 23 | min_len = min(ref_lens) 24 | elif method == 'average': 25 | min_len = float(sum(ref_lens))/len(ref_lens) 26 | return min_len 27 | 28 | def normalize(sentence): 29 | return sentence.strip().split() 30 | 31 | def count_ngrams(sentences, n=4): 32 | global_counts = {} 33 | for sentence in sentences: 34 | local_counts = {} 35 | list_len = len(sentence) 36 | 37 | for k in xrange(1, n + 1): 38 | for i in range(list_len - k + 1): 39 | ngram = tuple(sentence[i:i+k]) 40 | local_counts[ngram] = local_counts.get(ngram, 0) + 1 41 | 42 | ### Store maximum occurrence; useful for multireference bleu 43 | for ngram, count in local_counts.items(): 44 | global_counts[ngram] = max(global_counts.get(ngram, 0), count) 45 | return global_counts 46 | 47 | def count_letter_ngram(sentence, n=3): 48 | local_counts = set() 49 | for k in range(len(sentence.strip()) - n + 1): 50 | local_counts.add(sentence[k:k+n]) 51 | return local_counts 52 | 53 | class Jaccard: 54 | """ 55 | Jaccard n-letter-gram similarity. 56 | Use: 57 | >>> j = Jaccard() 58 | >>> j.update("i have it", "i have is") 59 | >>> print j.compute() 60 | 0.75 61 | >>> j.reset() 62 | """ 63 | def __init__(self, n=3): 64 | self.n = n 65 | self.statistics = [] 66 | 67 | def aggregate(self): 68 | if len(self.statistics) == 0: 69 | return numpy.zeros((1,)) 70 | stat_matrix = numpy.array(self.statistics) 71 | return numpy.mean(stat_matrix) 72 | 73 | def update(self, candidate, ref): 74 | stats = numpy.zeros((1,)) 75 | 76 | cand_ngrams = count_letter_ngram(candidate, self.n) 77 | ref_ngrams = count_letter_ngram(ref, self.n) 78 | stats[0] = float(len(cand_ngrams & ref_ngrams)) / len(cand_ngrams | ref_ngrams) 79 | self.statistics.append(stats) 80 | 81 | def compute(self): 82 | stats = self.aggregate() 83 | #return stats[0] 84 | return stats 85 | 86 | def reset(self): 87 | self.statistics = [] 88 | 89 | class JaccardEvaluator(object): 90 | """ Jaccard evaluator 91 | """ 92 | def __init__(self): 93 | self.jaccard = Jaccard() 94 | 95 | def evaluate(self, prediction, target): 96 | if len(target) != len(prediction): 97 | raise ValueError('Target and predictions length mismatch!') 98 | 99 | # Assume ordered list and take only the first one 100 | if isinstance(prediction[0], list): 101 | prediction = [x[0] for x in prediction] 102 | 103 | self.jaccard.reset() 104 | for ts, ps in zip(target, prediction): 105 | self.jaccard.update(ps, *ts) 106 | return self.jaccard.compute() 107 | 108 | class Bleu: 109 | """ 110 | Bleu score. 111 | Use: 112 | >>> b = Bleu() 113 | >>> b.update("i have this", "i have this :)", "oh my my") # multi-references 114 | >>> b.compute() 115 | >>> b.reset() 116 | """ 117 | def __init__(self, n=4): 118 | # Statistics are 119 | # - 1-gramcount, 120 | # - 2-gramcount, 121 | # - 3-gramcount, 122 | # - 4-gramcount, 123 | # - 1-grammatch, 124 | # - 2-grammatch, 125 | # - 3-grammatch, 126 | # - 4-grammatch, 127 | # - reflen 128 | self.n = n 129 | self.statistics = [] 130 | 131 | def aggregate(self): 132 | if len(self.statistics) == 0: 133 | return numpy.zeros((2 * self.n + 1,)) 134 | stat_matrix = numpy.array(self.statistics) 135 | return numpy.sum(stat_matrix, axis=0) 136 | 137 | def update(self, candidate, *refs): 138 | refs = [normalize(ref) for ref in refs] 139 | candidate = normalize(candidate) 140 | 141 | stats = numpy.zeros((2 * self.n + 1,)) 142 | stats[-1] = get_ref_length(map(len, refs), len(candidate)) 143 | 144 | cand_ngram_counts = count_ngrams([candidate], self.n) 145 | refs_ngram_counts = count_ngrams(refs, self.n) 146 | 147 | for ngram, count in cand_ngram_counts.items(): 148 | stats[len(ngram) + self.n - 1] += min(count, refs_ngram_counts.get(ngram, 0)) 149 | for k in xrange(1, self.n + 1): 150 | stats[k - 1] = max(len(candidate) - k + 1, 0) 151 | self.statistics.append(stats) 152 | 153 | def compute(self, smoothing=0, length_penalty=1): 154 | precs = numpy.zeros((self.n + 1,)) 155 | stats = self.aggregate() 156 | log_bleu = 0. 157 | 158 | for k in range(self.n): 159 | correct = float(stats[self.n + k] + smoothing) 160 | if correct == 0.: 161 | return 0., precs 162 | total = float(stats[k] + 2*smoothing) 163 | precs[k] = numpy.log(correct) - numpy.log(total) 164 | log_bleu += precs[k] 165 | 166 | log_bleu /= float(self.n) 167 | stats[-1] = stats[-1] * length_penalty 168 | log_bleu += min(0, 1 - float(stats[0]/stats[-1])) 169 | return numpy.exp(log_bleu), numpy.exp(precs) 170 | 171 | def reset(self): 172 | self.statistics = [] 173 | 174 | class BleuEvaluator(object): 175 | """ Bleu evaluator 176 | """ 177 | def __init__(self): 178 | self.bleu = Bleu() 179 | 180 | def evaluate(self, prediction, target): 181 | if len(target) != len(prediction): 182 | raise ValueError('Target and predictions length mismatch!') 183 | 184 | # Assume ordered list and take only the first one 185 | if isinstance(prediction[0], list): 186 | prediction = [x[0] for x in prediction] 187 | 188 | self.bleu.reset() 189 | for ts, ps in zip(target, prediction): 190 | self.bleu.update(ps, *ts) 191 | return self.bleu.compute() 192 | 193 | 194 | 195 | class Recall: 196 | """ 197 | Evaluate mean recall at utterance level. 198 | Use: 199 | >>> r = Recall() 200 | >>> r.update("i have it", ["i have is", "i have some"]) 201 | >>> r.update("i have it", ["i have is", "i have it"]) 202 | >>> print r.compute() 203 | 0.5 204 | >>> r.reset() 205 | """ 206 | def __init__(self, n): 207 | self.n = n 208 | self.statistics = [] 209 | 210 | def aggregate(self): 211 | if len(self.statistics) == 0: 212 | return numpy.zeros((1,)) 213 | stat_matrix = numpy.array(self.statistics) 214 | return float(numpy.mean(stat_matrix)) 215 | 216 | def update(self, candidates, ref): 217 | stats = numpy.zeros((1,)) 218 | 219 | for candidate in candidates: 220 | if candidate == ref: 221 | stats[0] = 1 222 | self.statistics.append(stats) 223 | return 224 | 225 | stats[0] = 0 226 | self.statistics.append(stats) 227 | 228 | def compute(self): 229 | stats = self.aggregate() 230 | return stats 231 | 232 | def reset(self): 233 | self.statistics = [] 234 | 235 | class RecallEvaluator(object): 236 | """ Recall evaluator 237 | """ 238 | def __init__(self, n=5): 239 | self.recall = Recall(n) 240 | self.n = n 241 | 242 | def evaluate(self, prediction, target): 243 | if len(target) != len(prediction): 244 | raise ValueError('Target and predictions length mismatch!') 245 | 246 | self.recall.reset() 247 | for ts, ps in zip(target, prediction): 248 | #assert(len(ps) >= self.n) 249 | # Replace missing samples with last sample instead of throwing an error 250 | samples_len = len(ps) 251 | if samples_len >= self.n: 252 | ps_complete = ps[0:self.n] 253 | else: 254 | ps_complete = ps[0:samples_len] 255 | miss = self.n - samples_len 256 | last_element = ps[samples_len-1] 257 | for i in range(miss): 258 | ps_complete.append(last_element) 259 | 260 | self.recall.update(ps_complete, *ts) 261 | 262 | return self.recall.compute() 263 | 264 | 265 | class MRR: 266 | """ 267 | Evaluate mean reciprocal rank. 268 | Use: 269 | >>> r = MRR() 270 | >>> r.update("i have it", ["i have is", "i have some"]) 271 | >>> r.update("i have it", ["i have is", "i have it"]) 272 | >>> print r.compute() 273 | 0.25 274 | >>> r.reset() 275 | """ 276 | def __init__(self, n): 277 | self.n = n 278 | self.statistics = [] 279 | 280 | def aggregate(self): 281 | if len(self.statistics) == 0: 282 | return numpy.zeros((1,)) 283 | stat_matrix = numpy.array(self.statistics) 284 | return float(numpy.mean(stat_matrix)) 285 | 286 | def update(self, candidates, ref): 287 | stats = numpy.zeros((1,)) 288 | 289 | for index in range(len(candidates)): 290 | if candidates[index] == ref: 291 | stats[0] = 1/(index+1) 292 | self.statistics.append(stats) 293 | return 294 | 295 | self.statistics.append(stats) 296 | 297 | def compute(self): 298 | stats = self.aggregate() 299 | return stats 300 | 301 | def reset(self): 302 | self.statistics = [] 303 | 304 | class MRREvaluator(object): 305 | """ Mean reciprocal rank evaluator 306 | """ 307 | def __init__(self, n=5): 308 | self.mrr = MRR(n) 309 | self.n = n 310 | 311 | def evaluate(self, prediction, target): 312 | if len(target) != len(prediction): 313 | raise ValueError('Target and predictions length mismatch!') 314 | 315 | self.mrr.reset() 316 | for ts, ps in zip(target, prediction): 317 | #assert(len(ps) >= self.n) 318 | # Replace missing samples with last sample instead of throwing an error 319 | samples_len = len(ps) 320 | if samples_len >= self.n: 321 | ps_complete = ps[0:self.n] 322 | else: 323 | ps_complete = ps[0:samples_len] 324 | miss = self.n - samples_len 325 | last_element = ps[samples_len-1] 326 | for i in range(miss): 327 | ps_complete.append(last_element) 328 | 329 | self.mrr.update(ps_complete, *ts) 330 | 331 | return self.mrr.compute() 332 | 333 | 334 | class TFIDF_CS: 335 | """ 336 | Evaluate TF-IDF-based cosine similarity. 337 | Use: 338 | >>> tfidf_cs = TFIDF_CS() 339 | >>> tfidf_cs.update("i have it", ["i have is", "i have some"]) 340 | >>> tfidf_cs.update("i have it", ["i have is", "i have it"]) 341 | >>> print tfidf_cs.compute() 342 | >>> tfidf_cs.reset() 343 | """ 344 | def __init__(self, model, document_count, n): 345 | self.model = model 346 | self.document_count = document_count 347 | self.n = n 348 | self.statistics = [] 349 | 350 | def aggregate(self): 351 | if len(self.statistics) == 0: 352 | return numpy.zeros((1,)) 353 | stat_matrix = numpy.array(self.statistics) 354 | return float(numpy.mean(stat_matrix)) 355 | 356 | def update(self, candidates, ref): 357 | stats = numpy.zeros((1,)) 358 | 359 | # Split reference (target) into word indices and count each word 360 | ref_words = normalize(ref) 361 | 362 | # We don't count empty targets, since these would always give cosine similarity one with empty responses! 363 | if len(ref_words) == 0: 364 | return 365 | 366 | ref_indices = self.model.words_to_indices(ref_words) 367 | 368 | ref_counter = Counter(ref_indices) 369 | ref_indices_unique = list(set(ref_indices)) 370 | 371 | # Compute reference (target) vector 372 | ref_vector = numpy.zeros((len(ref_indices))) 373 | for i in range(len(ref_indices_unique)): 374 | word_index = ref_indices_unique[i] 375 | ref_vector[i] = ref_counter[word_index] * math.log(self.document_count/max(1, self.model.document_freq[word_index])) 376 | 377 | ref_vector_norm = numpy.sqrt(numpy.sum(ref_vector**2)) 378 | 379 | # We don't count references which we cannot match (this should never happen in the dataset anyway, but it does happen in our tests...) 380 | if ref_vector_norm < 0.0000001: 381 | return 382 | 383 | best_score = 0 384 | for candidate in candidates: 385 | # Split candidate into word indices and count each word 386 | cand_words = normalize(candidate) 387 | cand_indices = self.model.words_to_indices(cand_words) 388 | cand_counter = Counter(cand_indices) 389 | 390 | # Loop over unique indices (to speed up calculations) and compute un-normalized cosine similarity 391 | current_score = 0 392 | cand_norm = 0 393 | ref_norm = 0 394 | cand_vector = numpy.zeros((len(ref_indices))) 395 | cand_vector_norm = 0 396 | 397 | # Compute irrespective of reference 398 | for word_index in cand_counter.keys(): 399 | cand_vector_norm += (cand_counter[word_index] * math.log(self.document_count/max(1, self.model.document_freq[word_index])))**2 400 | cand_vector_norm = numpy.sqrt(cand_vector_norm) 401 | 402 | # Compute candidate vector 403 | for i in range(len(ref_indices_unique)): 404 | word_index = ref_indices_unique[i] 405 | if cand_counter[word_index] > 0: 406 | cand_vector[i] = cand_counter[word_index] * math.log(self.document_count/max(1, self.model.document_freq[word_index])) 407 | 408 | if cand_vector_norm > 0: 409 | current_score = float(numpy.dot(cand_vector.T, ref_vector) / (cand_vector_norm*ref_vector_norm)) 410 | 411 | if current_score > best_score: 412 | best_score = current_score 413 | 414 | 415 | self.statistics.append(best_score) 416 | 417 | def compute(self): 418 | stats = self.aggregate() 419 | return stats 420 | 421 | def reset(self): 422 | self.statistics = [] 423 | 424 | class TFIDF_CS_Evaluator(object): 425 | """ Mean TF-IDF-based cosine similarity evaluator 426 | """ 427 | def __init__(self, model, document_count, n): 428 | self.tfidf_cs = TFIDF_CS(model, document_count, n) 429 | self.n = n 430 | 431 | def evaluate(self, prediction, target): 432 | if len(target) != len(prediction): 433 | raise ValueError('Target and predictions length mismatch!') 434 | 435 | self.tfidf_cs.reset() 436 | for ts, ps in zip(target, prediction): 437 | #assert(len(ps) >= self.n) 438 | # Replace missing samples with last sample instead of throwing an error 439 | samples_len = len(ps) 440 | if samples_len >= self.n: 441 | ps_complete = ps[0:self.n] 442 | else: 443 | ps_complete = ps[0:samples_len] 444 | miss = self.n - samples_len 445 | last_element = ps[samples_len-1] 446 | for i in range(miss): 447 | ps_complete.append(last_element) 448 | 449 | self.tfidf_cs.update(ps_complete, *ts) 450 | 451 | return self.tfidf_cs.compute() 452 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #!/usr/bin/env python 3 | 4 | from data_iterator import * 5 | from state import * 6 | from dialog_encdec import * 7 | from utils import * 8 | from evaluation import * 9 | 10 | import time 11 | import traceback 12 | import os.path 13 | import sys 14 | import argparse 15 | import cPickle 16 | import logging 17 | import search 18 | import pprint 19 | import numpy 20 | import collections 21 | import signal 22 | import math 23 | 24 | import matplotlib 25 | matplotlib.use('Agg') 26 | import pylab 27 | 28 | class Unbuffered: 29 | def __init__(self, stream): 30 | self.stream = stream 31 | 32 | def write(self, data): 33 | self.stream.write(data) 34 | self.stream.flush() 35 | 36 | def __getattr__(self, attr): 37 | return getattr(self.stream, attr) 38 | 39 | sys.stdout = Unbuffered(sys.stdout) 40 | logger = logging.getLogger(__name__) 41 | 42 | ### Unique RUN_ID for this execution 43 | RUN_ID = str(time.time()) 44 | 45 | ### Additional measures can be set here 46 | measures = ["train_cost", "train_misclass", 47 | "valid_cost", "valid_misclass", 48 | "valid_emi", "valid_bleu", 49 | "valid_jaccard", "valid_recall_at_1", 50 | "valid_recall_at_5", "valid_mrr_at_5", 51 | "tfidf_cs_at_1", "tfidf_cs_at_5"] 52 | 53 | def init_timings(): 54 | timings = {} 55 | for m in measures: 56 | timings[m] = [] 57 | return timings 58 | 59 | def save(model, timings): 60 | print "Saving the model..." 61 | 62 | # ignore keyboard interrupt while saving 63 | start = time.time() 64 | s = signal.signal(signal.SIGINT, signal.SIG_IGN) 65 | 66 | model.save(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + 'model.npz') 67 | cPickle.dump(model.state, open(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + 'state.pkl', 'w')) 68 | numpy.savez(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + 'timing.npz', **timings) 69 | signal.signal(signal.SIGINT, s) 70 | 71 | print "Model saved, took {}".format(time.time() - start) 72 | 73 | def load(model, filename): 74 | print "Loading the model..." 75 | 76 | # ignore keyboard interrupt while saving 77 | start = time.time() 78 | s = signal.signal(signal.SIGINT, signal.SIG_IGN) 79 | model.load(filename) 80 | signal.signal(signal.SIGINT, s) 81 | 82 | print "Model loaded, took {}".format(time.time() - start) 83 | 84 | def main(args): 85 | logging.basicConfig(level = logging.DEBUG, 86 | format = "%(asctime)s: %(name)s: %(levelname)s: %(message)s") 87 | 88 | state = eval(args.prototype)() 89 | timings = init_timings() 90 | 91 | 92 | if args.resume != "": 93 | logger.debug("Resuming %s" % args.resume) 94 | 95 | state_file = args.resume + '_state.pkl' 96 | timings_file = args.resume + '_timing.npz' 97 | 98 | if os.path.isfile(state_file) and os.path.isfile(timings_file): 99 | logger.debug("Loading previous state") 100 | 101 | state = cPickle.load(open(state_file, 'r')) 102 | timings = dict(numpy.load(open(timings_file, 'r'))) 103 | for x, y in timings.items(): 104 | timings[x] = list(y) 105 | else: 106 | raise Exception("Cannot resume, cannot find files!") 107 | 108 | logger.debug("State:\n{}".format(pprint.pformat(state))) 109 | logger.debug("Timings:\n{}".format(pprint.pformat(timings))) 110 | 111 | model = DialogEncoderDecoder(state) 112 | rng = model.rng 113 | 114 | if args.resume != "": 115 | filename = args.resume + '_model.npz' 116 | if os.path.isfile(filename): 117 | logger.debug("Loading previous model") 118 | load(model, filename) 119 | else: 120 | raise Exception("Cannot resume, cannot find model file!") 121 | 122 | if 'run_id' not in model.state: 123 | raise Exception('Backward compatibility not ensured! (need run_id in state)') 124 | else: 125 | # assign new run_id key 126 | model.state['run_id'] = RUN_ID 127 | 128 | logger.debug("Compile trainer") 129 | if not state["use_nce"]: 130 | train_batch = model.build_train_function() 131 | else: 132 | train_batch = model.build_nce_function() 133 | 134 | eval_batch = model.build_eval_function() 135 | eval_misclass_batch = model.build_eval_misclassification_function() 136 | 137 | random_sampler = search.RandomSampler(model) 138 | beam_sampler = search.BeamSampler(model) 139 | 140 | logger.debug("Load data") 141 | train_data, \ 142 | valid_data, \ 143 | test_data = get_batch_iterator(rng, state) 144 | 145 | train_data.start() 146 | 147 | # Build the data structures for Bleu evaluation 148 | if 'bleu_evaluation' in state: 149 | bleu_eval = BleuEvaluator() 150 | jaccard_eval = JaccardEvaluator() 151 | recall_at_1_eval = RecallEvaluator(n=1) 152 | recall_at_5_eval = RecallEvaluator(n=5) 153 | mrr_at_5_eval = MRREvaluator(n=5) 154 | tfidf_cs_at_1_eval = TFIDF_CS_Evaluator(model, train_data.data_len, 1) 155 | tfidf_cs_at_5_eval = TFIDF_CS_Evaluator(model, train_data.data_len, 5) 156 | 157 | samples = open(state['bleu_evaluation'], 'r').readlines() 158 | n = state['bleu_context_length'] 159 | 160 | contexts = [] 161 | targets = [] 162 | for x in samples: 163 | sentences = x.strip().split('\t') 164 | assert len(sentences) > n 165 | contexts.append(sentences[:n]) 166 | targets.append(sentences[n:]) 167 | 168 | # Start looping through the dataset 169 | step = 0 170 | patience = state['patience'] 171 | start_time = time.time() 172 | 173 | train_cost = 0 174 | train_misclass = 0 175 | train_done = 0 176 | ex_done = 0 177 | 178 | while (step < state['loop_iters'] and 179 | (time.time() - start_time)/60. < state['time_stop'] and 180 | patience >= 0): 181 | 182 | # Sample stuff 183 | if step % 200 == 0: 184 | for param in model.params: 185 | print "%s = %.4f" % (param.name, numpy.sum(param.get_value() ** 2) ** 0.5) 186 | 187 | samples, costs = random_sampler.sample([[]], n_samples=1, n_turns=3) 188 | print "Sampled : {}".format(samples[0]) 189 | 190 | # Training phase 191 | batch = train_data.next() 192 | 193 | # Train finished 194 | if not batch: 195 | # Restart training 196 | logger.debug("Got None...") 197 | break 198 | 199 | logger.debug("[TRAIN] - Got batch %d,%d" % (batch['x'].shape[1], batch['max_length'])) 200 | 201 | x_data = batch['x'] 202 | max_length = batch['max_length'] 203 | x_cost_mask = batch['x_mask'] 204 | 205 | if state['use_nce']: 206 | y_neg = rng.choice(size=(10, max_length, x_data.shape[1]), a=model.idim, p=model.noise_probs).astype('int32') 207 | c = train_batch(x_data, y_neg, max_length, x_cost_mask) 208 | else: 209 | c = train_batch(x_data, max_length, x_cost_mask) 210 | 211 | if numpy.isinf(c) or numpy.isnan(c): 212 | logger.warn("Got NaN cost .. skipping") 213 | continue 214 | 215 | train_cost += c 216 | 217 | # Compute word-error rate 218 | miscl = eval_misclass_batch(x_data, max_length, x_cost_mask) 219 | if numpy.isinf(c) or numpy.isnan(c): 220 | logger.warn("Got NaN misclassification .. skipping") 221 | continue 222 | 223 | train_misclass += miscl 224 | 225 | train_done += batch['num_preds'] 226 | 227 | this_time = time.time() 228 | if step % state['train_freq'] == 0: 229 | elapsed = this_time - start_time 230 | h, m, s = ConvertTimedelta(this_time - start_time) 231 | print ".. %.2d:%.2d:%.2d %4d mb # %d bs %d maxl %d acc_cost = %.4f acc_word_perplexity = %.4f acc_mean_word_error = %.4f " % (h, m, s,\ 232 | state['time_stop'] - (time.time() - start_time)/60.,\ 233 | step, \ 234 | batch['x'].shape[1], \ 235 | batch['max_length'], \ 236 | float(train_cost/train_done), \ 237 | math.exp(float(train_cost/train_done)), \ 238 | float(train_misclass)/float(train_done)) 239 | 240 | 241 | 242 | 243 | if valid_data is not None and\ 244 | step % state['valid_freq'] == 0 and step > 1: 245 | valid_data.start() 246 | valid_cost = 0 247 | valid_misclass = 0 248 | valid_empirical_mutual_information = 0 249 | valid_wordpreds_done = 0 250 | valid_triples_done = 0 251 | 252 | 253 | # Prepare variables for plotting histogram over word-perplexities and mutual information 254 | valid_data_len = valid_data.data_len 255 | valid_cost_list = numpy.zeros((valid_data_len,)) 256 | valid_pmi_list = numpy.zeros((valid_data_len,)) 257 | 258 | 259 | # Prepare variables for printing the training examples the model performs best and worst on 260 | valid_extrema_setsize = min(state['track_extrema_samples_count'], valid_data_len) 261 | valid_extrema_samples_to_print = min(state['print_extrema_samples_count'], valid_extrema_setsize) 262 | 263 | valid_lowest_costs = numpy.ones((valid_extrema_setsize,))*1000 264 | valid_lowest_triples = numpy.ones((valid_extrema_setsize,state['seqlen']))*1000 265 | valid_highest_costs = numpy.ones((valid_extrema_setsize,))*(-1000) 266 | valid_highest_triples = numpy.ones((valid_extrema_setsize,state['seqlen']))*(-1000) 267 | 268 | 269 | logger.debug("[VALIDATION START]") 270 | 271 | while True: 272 | batch = valid_data.next() 273 | # Train finished 274 | if not batch: 275 | break 276 | 277 | logger.debug("[VALID] - Got batch %d,%d" % (batch['x'].shape[1], batch['max_length'])) 278 | 279 | x_data = batch['x'] 280 | max_length = batch['max_length'] 281 | x_cost_mask = batch['x_mask'] 282 | 283 | 284 | c, c_list = eval_batch(x_data, max_length, x_cost_mask) 285 | c_list = c_list.reshape((batch['x'].shape[1],max_length), order=(1,0)) 286 | c_list = numpy.sum(c_list, axis=1) 287 | 288 | words_in_triples = numpy.sum(x_cost_mask, axis=0) 289 | c_list = c_list / words_in_triples 290 | 291 | if numpy.isinf(c) or numpy.isnan(c): 292 | continue 293 | 294 | valid_cost += c 295 | nxt = min((valid_triples_done+batch['x'].shape[1]), valid_data_len) 296 | triples_in_batch = nxt-valid_triples_done 297 | valid_cost_list[(nxt-triples_in_batch):nxt] = numpy.exp(c_list[0:triples_in_batch]) 298 | 299 | # Store best and worst validation costs 300 | con_costs = np.concatenate([valid_lowest_costs, c_list[0:triples_in_batch]]) 301 | con_triples = np.concatenate([valid_lowest_triples, x_data[:, 0:triples_in_batch].T], axis=0) 302 | con_indices = con_costs.argsort()[0:valid_extrema_setsize][::1] 303 | valid_lowest_costs = con_costs[con_indices] 304 | valid_lowest_triples = con_triples[con_indices] 305 | 306 | con_costs = np.concatenate([valid_highest_costs, c_list[0:triples_in_batch]]) 307 | con_triples = np.concatenate([valid_highest_triples, x_data[:, 0:triples_in_batch].T], axis=0) 308 | con_indices = con_costs.argsort()[-valid_extrema_setsize:][::-1] 309 | valid_highest_costs = con_costs[con_indices] 310 | valid_highest_triples = con_triples[con_indices] 311 | 312 | 313 | # Compute word-error rate 314 | miscl = eval_misclass_batch(x_data, max_length, x_cost_mask) 315 | if numpy.isinf(c) or numpy.isnan(c): 316 | continue 317 | 318 | valid_misclass += miscl 319 | 320 | # Compute empirical mutual information 321 | if state['compute_mutual_information'] == True: 322 | # Compute marginal log-likelihood of last utterance in triple: 323 | # We approximate it with the margina log-probabiltiy of the utterance being observed first in the triple 324 | x_data_last_utterance = batch['x_last_utterance'] 325 | x_cost_mask_last_utterance = batch['x_mask_last_utterance'] 326 | marginal_last_utterance_loglikelihood, marginal_last_utterance_loglikelihood_list = eval_batch(x_data_last_utterance, max_length, x_cost_mask_last_utterance) 327 | marginal_last_utterance_loglikelihood_list = marginal_last_utterance_loglikelihood_list.reshape((batch['x'].shape[1],max_length), order=(1,0)) 328 | marginal_last_utterance_loglikelihood_list = numpy.sum(marginal_last_utterance_loglikelihood_list, axis=1) 329 | # If we wanted to normalize histogram plots by utterance length, we should enable this: 330 | #words_in_last_utterance = numpy.sum(x_cost_mask_last_utterance, axis=0) 331 | #marginal_last_utterance_loglikelihood_list = marginal_last_utterance_loglikelihood_list / words_in_last_utterance 332 | 333 | # Compute marginal log-likelihood of first utterances in triple by masking the last utterance 334 | x_cost_mask_first_utterances = x_cost_mask - x_cost_mask_last_utterance 335 | marginal_first_utterances_loglikelihood, marginal_first_utterances_loglikelihood_list = eval_batch(x_data, max_length, x_cost_mask_first_utterances) 336 | 337 | marginal_first_utterances_loglikelihood_list = marginal_first_utterances_loglikelihood_list.reshape((batch['x'].shape[1],max_length), order=(1,0)) 338 | marginal_first_utterances_loglikelihood_list = numpy.sum(marginal_first_utterances_loglikelihood_list, axis=1) 339 | 340 | # If we wanted to normalize histogram plots by utterance length, we should enable this: 341 | #words_in_first_utterances = numpy.sum(x_cost_mask_first_utterances, axis=0) 342 | #marginal_first_utterances_loglikelihood_list = marginal_first_utterances_loglikelihood_list / words_in_first_utterances 343 | 344 | 345 | # Compute empirical mutual information and pointwise empirical mutual information 346 | valid_empirical_mutual_information += -c + marginal_first_utterances_loglikelihood + marginal_last_utterance_loglikelihood 347 | valid_pmi_list[(nxt-triples_in_batch):nxt] = (-c_list*words_in_triples + marginal_first_utterances_loglikelihood_list + marginal_last_utterance_loglikelihood_list)[0:triples_in_batch] 348 | 349 | valid_wordpreds_done += batch['num_preds'] 350 | valid_triples_done += batch['x'].shape[1] 351 | 352 | 353 | logger.debug("[VALIDATION END]") 354 | 355 | valid_cost /= valid_wordpreds_done 356 | valid_misclass /= float(valid_wordpreds_done) 357 | valid_empirical_mutual_information /= float(valid_triples_done) 358 | 359 | 360 | if len(timings["valid_cost"]) == 0 or valid_cost < timings["valid_cost"][-1]: 361 | patience = state['patience'] 362 | # Saving model if decrease in validation cost 363 | save(model, timings) 364 | elif valid_cost >= timings["valid_cost"][-1] * state['cost_threshold']: 365 | patience -= 1 366 | 367 | print "** valid cost = %.4f, valid word-perplexity = %.4f, valid mean word-error = %.4f, valid emp. mutual information = %.4f, patience = %d" % (float(valid_cost), float(math.exp(valid_cost)), float(valid_misclass), valid_empirical_mutual_information, patience) 368 | 369 | timings["train_cost"].append(train_cost/train_done) 370 | timings["train_misclass"].append(float(train_misclass)/float(train_done)) 371 | timings["valid_cost"].append(valid_cost) 372 | timings["valid_misclass"].append(valid_misclass) 373 | timings["valid_emi"].append(valid_empirical_mutual_information) 374 | 375 | 376 | # Reset train cost, train misclass and train done 377 | train_cost = 0 378 | train_misclass = 0 379 | train_done = 0 380 | 381 | # Plot histogram over validation costs 382 | try: 383 | pylab.figure() 384 | bins = range(0, 50, 1) 385 | pylab.hist(valid_cost_list, normed=1, histtype='bar') 386 | pylab.savefig(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + 'Valid_WordPerplexities_'+ str(step) + '.png') 387 | except: 388 | pass 389 | 390 | 391 | # Print 5 of 10% validation samples with highest log-likelihood 392 | if state['track_extrema_validation_samples']==True: 393 | print " highest word log-likelihood valid samples: " 394 | np.random.shuffle(valid_lowest_triples) 395 | for i in range(valid_extrema_samples_to_print): 396 | print " Sample: {}".format(" ".join(model.indices_to_words(numpy.ravel(valid_lowest_triples[i,:])))) 397 | 398 | print " lowest word log-likelihood valid samples: " 399 | np.random.shuffle(valid_highest_triples) 400 | for i in range(valid_extrema_samples_to_print): 401 | print " Sample: {}".format(" ".join(model.indices_to_words(numpy.ravel(valid_highest_triples[i,:])))) 402 | 403 | # Plot histogram over empirical pointwise mutual informations 404 | if state['compute_mutual_information'] == True: 405 | try: 406 | pylab.figure() 407 | bins = range(0, 100, 1) 408 | pylab.hist(valid_pmi_list, normed=1, histtype='bar') 409 | pylab.savefig(model.state['save_dir'] + '/' + model.state['run_id'] + "_" + model.state['prefix'] + 'Valid_PMI_'+ str(step) + '.png') 410 | except: 411 | pass 412 | 413 | if 'bleu_evaluation' in state and \ 414 | step % state['valid_freq'] == 0 and step > 1: 415 | 416 | # Compute samples with beam search 417 | logger.debug("Executing beam search to get targets for bleu, jaccard etc.") 418 | samples, costs = beam_sampler.sample(contexts, n_samples=5, ignore_unk=True) 419 | logger.debug("Finished beam search.") 420 | 421 | assert len(samples) == len(contexts) 422 | #print 'samples', samples 423 | 424 | # Bleu evaluation 425 | bleu = bleu_eval.evaluate(samples, targets) 426 | 427 | print "** bleu score = %.4f " % bleu[0] 428 | timings["valid_bleu"].append(bleu[0]) 429 | 430 | # Jaccard evaluation 431 | jaccard = jaccard_eval.evaluate(samples, targets) 432 | 433 | print "** jaccard score = %.4f " % jaccard 434 | timings["valid_jaccard"].append(jaccard) 435 | 436 | # Recall evaluation 437 | recall_at_1 = recall_at_1_eval.evaluate(samples, targets) 438 | 439 | print "** recall@1 score = %.4f " % recall_at_1 440 | timings["valid_recall_at_1"].append(recall_at_1) 441 | 442 | recall_at_5 = recall_at_5_eval.evaluate(samples, targets) 443 | 444 | print "** recall@5 score = %.4f " % recall_at_5 445 | timings["valid_recall_at_5"].append(recall_at_5) 446 | 447 | mrr_at_5 = mrr_at_5_eval.evaluate(samples, targets) 448 | 449 | # MRR evaluation (equivalent to mean average precision) 450 | print "** mrr@5 score = %.4f " % mrr_at_5 451 | timings["valid_mrr_at_5"].append(mrr_at_5) 452 | 453 | # TF-IDF cosine similarity evaluation 454 | tfidf_cs_at_1 = tfidf_cs_at_1_eval.evaluate(samples, targets) 455 | 456 | print "** tfidf-cs@1 score = %.4f " % tfidf_cs_at_1 457 | timings["tfidf_cs_at_1"].append(tfidf_cs_at_1) 458 | 459 | tfidf_cs_at_5 = tfidf_cs_at_5_eval.evaluate(samples, targets) 460 | 461 | print "** tfidf-cs@5 score = %.4f " % tfidf_cs_at_5 462 | timings["tfidf_cs_at_5"].append(tfidf_cs_at_5) 463 | 464 | step += 1 465 | 466 | logger.debug("All done, exiting...") 467 | 468 | def parse_args(): 469 | parser = argparse.ArgumentParser() 470 | parser.add_argument("--resume", type=str, default="", help="Resume training from that state") 471 | parser.add_argument("--prototype", type=str, help="Use the prototype", default='prototype_state') 472 | 473 | args = parser.parse_args() 474 | return args 475 | 476 | if __name__ == "__main__": 477 | args = parse_args() 478 | main(args) 479 | -------------------------------------------------------------------------------- /dialog_encdec.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dialog hierarchical encoder-decoder code. 3 | The code is inspired from nmt encdec code in groundhog 4 | but we do not rely on groundhog infrastructure. 5 | """ 6 | __docformat__ = 'restructedtext en' 7 | __authors__ = ("Alessandro Sordoni") 8 | __contact__ = "Alessandro Sordoni " 9 | 10 | import theano 11 | import theano.tensor as T 12 | import numpy as np 13 | import cPickle 14 | import logging 15 | logger = logging.getLogger(__name__) 16 | 17 | from theano.sandbox.scan import scan 18 | from theano.sandbox.rng_mrg import MRG_RandomStreams 19 | from theano.tensor.nnet.conv3d2d import * 20 | from collections import OrderedDict 21 | 22 | from model import * 23 | from utils import * 24 | 25 | import operator 26 | 27 | # Theano speed-up 28 | theano.config.scan.allow_gc = False 29 | # 30 | 31 | def add_to_params(params, new_param): 32 | params.append(new_param) 33 | return new_param 34 | 35 | class EncoderDecoderBase(): 36 | def __init__(self, state, rng, parent): 37 | self.rng = rng 38 | self.parent = parent 39 | 40 | self.state = state 41 | self.__dict__.update(state) 42 | 43 | self.triple_rec_activation = eval(self.triple_rec_activation) 44 | self.sent_rec_activation = eval(self.sent_rec_activation) 45 | 46 | self.params = [] 47 | 48 | class Encoder(EncoderDecoderBase): 49 | def init_params(self): 50 | """ sent weights """ 51 | self.W_emb = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.idim, self.rankdim), name='W_emb')) 52 | self.W_in = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.rankdim, self.qdim), name='W_in')) 53 | self.W_hh = add_to_params(self.params, theano.shared(value=OrthogonalInit(self.rng, self.qdim, self.qdim), name='W_hh')) 54 | self.b_hh = add_to_params(self.params, theano.shared(value=np.zeros((self.qdim,), dtype='float32'), name='b_hh')) 55 | 56 | if self.sent_step_type == "gated": 57 | self.W_in_r = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.rankdim, self.qdim), name='W_in_r')) 58 | self.W_in_z = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.rankdim, self.qdim), name='W_in_z')) 59 | self.W_hh_r = add_to_params(self.params, theano.shared(value=OrthogonalInit(self.rng, self.qdim, self.qdim), name='W_hh_r')) 60 | self.W_hh_z = add_to_params(self.params, theano.shared(value=OrthogonalInit(self.rng, self.qdim, self.qdim), name='W_hh_z')) 61 | self.b_z = add_to_params(self.params, theano.shared(value=np.zeros((self.qdim,), dtype='float32'), name='b_z')) 62 | self.b_r = add_to_params(self.params, theano.shared(value=np.zeros((self.qdim,), dtype='float32'), name='b_r')) 63 | 64 | """ Context weights """ 65 | self.Ws_in = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.qdim, self.sdim), name='Ws_in')) 66 | self.Ws_hh = add_to_params(self.params, theano.shared(value=OrthogonalInit(self.rng, self.sdim, self.sdim), name='Ws_hh')) 67 | self.bs_hh = add_to_params(self.params, theano.shared(value=np.zeros((self.sdim,), dtype='float32'), name='bs_hh')) 68 | 69 | if self.triple_step_type == "gated": 70 | self.Ws_in_r = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.qdim, self.sdim), name='Ws_in_r')) 71 | self.Ws_in_z = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.qdim, self.sdim), name='Ws_in_z')) 72 | self.Ws_hh_r = add_to_params(self.params, theano.shared(value=OrthogonalInit(self.rng, self.sdim, self.sdim), name='Ws_hh_r')) 73 | self.Ws_hh_z = add_to_params(self.params, theano.shared(value=OrthogonalInit(self.rng, self.sdim, self.sdim), name='Ws_hh_z')) 74 | self.bs_z = add_to_params(self.params, theano.shared(value=np.zeros((self.sdim,), dtype='float32'), name='bs_z')) 75 | self.bs_r = add_to_params(self.params, theano.shared(value=np.zeros((self.sdim,), dtype='float32'), name='bs_r')) 76 | 77 | def plain_sent_step(self, x_t, m_t, h_tm1): 78 | if m_t.ndim >= 1: 79 | m_t = m_t.dimshuffle(0, 'x') 80 | 81 | hr_tm1 = m_t * h_tm1 82 | h_t = self.sent_rec_activation(T.dot(x_t, self.W_in) + T.dot(hr_tm1, self.W_hh) + self.b_hh) 83 | return h_t 84 | 85 | def gated_sent_step(self, x_t, m_t, h_tm1): 86 | if m_t.ndim >= 1: 87 | m_t = m_t.dimshuffle(0, 'x') 88 | 89 | hr_tm1 = m_t * h_tm1 90 | 91 | r_t = T.nnet.sigmoid(T.dot(x_t, self.W_in_r) + T.dot(hr_tm1, self.W_hh_r) + self.b_r) 92 | z_t = T.nnet.sigmoid(T.dot(x_t, self.W_in_z) + T.dot(hr_tm1, self.W_hh_z) + self.b_z) 93 | h_tilde = self.sent_rec_activation(T.dot(x_t, self.W_in) + T.dot(r_t * hr_tm1, self.W_hh) + self.b_hh) 94 | h_t = (np.float32(1.0) - z_t) * hr_tm1 + z_t * h_tilde 95 | 96 | # return both reset state and non-reset state 97 | return h_t, r_t, z_t, h_tilde 98 | 99 | def plain_triple_step(self, h_t, m_t, hs_tm1): 100 | if m_t.ndim >= 1: 101 | m_t = m_t.dimshuffle(0, 'x') 102 | 103 | hs_tilde = self.triple_rec_activation(T.dot(h_t, self.Ws_in) + T.dot(hs_tm1, self.Ws_hh) + self.bs_hh) 104 | hs_t = (m_t) * hs_tm1 + (1 - m_t) * hs_tilde 105 | return hs_t 106 | 107 | def gated_triple_step(self, h_t, m_t, hs_tm1): 108 | rs_t = T.nnet.sigmoid(T.dot(h_t, self.Ws_in_r) + T.dot(hs_tm1, self.Ws_hh_r) + self.bs_r) 109 | zs_t = T.nnet.sigmoid(T.dot(h_t, self.Ws_in_z) + T.dot(hs_tm1, self.Ws_hh_z) + self.bs_z) 110 | hs_tilde = self.triple_rec_activation(T.dot(h_t, self.Ws_in) + T.dot(rs_t * hs_tm1, self.Ws_hh) + self.bs_hh) 111 | hs_update = (np.float32(1.) - zs_t) * hs_tm1 + zs_t * hs_tilde 112 | 113 | if m_t.ndim >= 1: 114 | m_t = m_t.dimshuffle(0, 'x') 115 | 116 | hs_t = (m_t) * hs_tm1 + (1 - m_t) * hs_update 117 | return hs_t, hs_tilde, rs_t, zs_t 118 | 119 | def approx_embedder(self, x): 120 | return self.W_emb[x] 121 | 122 | def build_encoder(self, x, xmask=None, **kwargs): 123 | one_step = False 124 | if len(kwargs): 125 | raise Exception('One step not supported in build encoder') 126 | 127 | # if x.ndim == 2 then 128 | # x = (n_steps, batch_size) 129 | if x.ndim == 2: 130 | batch_size = x.shape[1] 131 | # else x = (word_1, word_2, word_3, ...) 132 | # or x = (last_word_1, last_word_2, last_word_3, ..) 133 | # in this case batch_size is 134 | else: 135 | batch_size = 1 136 | 137 | # if it is not one_step then we initialize everything to 0 138 | if not one_step: 139 | h_0 = T.alloc(np.float32(0), batch_size, self.qdim) 140 | hs_0 = T.alloc(np.float32(0), batch_size, self.sdim) 141 | 142 | xe = self.approx_embedder(x) 143 | if xmask == None: 144 | xmask = T.neq(x, self.eos_sym) 145 | 146 | # Here we roll the mask so we avoid the need for separate 147 | # hr and h. The trick is simple: if the original mask is 148 | # 0 1 1 0 1 1 1 0 0 0 0 0 -- batch is filled with eos_sym 149 | # the rolled mask will be 150 | # 0 0 1 1 0 1 1 1 0 0 0 0 -- roll to the right 151 | # ^ ^ 152 | # two resets 153 | # the first reset will reset h_init = 0 154 | # the second will reset and update given x_t = 155 | if xmask.ndim == 2: 156 | rolled_xmask = T.roll(xmask, 1, axis=0) 157 | 158 | # Gated Encoder 159 | if self.sent_step_type == "gated": 160 | f_enc = self.gated_sent_step 161 | o_enc_info = [h_0, None, None, None] 162 | else: 163 | f_enc = self.plain_sent_step 164 | o_enc_info = [h_0] 165 | 166 | if self.triple_step_type == "gated": 167 | f_hier = self.gated_triple_step 168 | o_hier_info = [hs_0, None, None, None] 169 | else: 170 | f_hier = self.plain_triple_step 171 | o_hier_info = [hs_0] 172 | 173 | # Run through all the sentence (encode everything) 174 | _res, _ = theano.scan(f_enc, 175 | sequences=[xe, rolled_xmask],\ 176 | outputs_info=o_enc_info) 177 | # Get the hidden state sequence 178 | h = _res[0] 179 | 180 | # All hierarchical sentence 181 | # The hs sequence is based on the original mask 182 | _res, _ = theano.scan(f_hier,\ 183 | sequences=[h, xmask],\ 184 | outputs_info=o_hier_info) 185 | 186 | if isinstance(_res, list) or isinstance(_res, tuple): 187 | hs = _res[0] 188 | else: 189 | hs = _res 190 | 191 | return h, hs 192 | 193 | def __init__(self, state, rng, parent): 194 | EncoderDecoderBase.__init__(self, state, rng, parent) 195 | self.init_params() 196 | 197 | class Decoder(EncoderDecoderBase): 198 | NCE = 0 199 | EVALUATION = 1 200 | BEAM_SEARCH = 2 201 | 202 | def __init__(self, state, rng, parent, encoder): 203 | EncoderDecoderBase.__init__(self, state, rng, parent) 204 | # Take as input the encoder instance for the embeddings.. 205 | # To modify in the future 206 | self.encoder = encoder 207 | self.trng = MRG_RandomStreams(self.seed) 208 | self.init_params() 209 | 210 | def init_params(self): 211 | """ Decoder weights """ 212 | self.bd_out = add_to_params(self.params, theano.shared(value=np.zeros((self.idim,), dtype='float32'), name='bd_out')) 213 | self.Wd_emb = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.idim, self.rankdim), name='Wd_emb')) 214 | 215 | self.Wd_hh = add_to_params(self.params, theano.shared(value=OrthogonalInit(self.rng, self.qdim, self.qdim), name='Wd_hh')) 216 | self.bd_hh = add_to_params(self.params, theano.shared(value=np.zeros((self.qdim,), dtype='float32'), name='bd_hh')) 217 | self.Wd_in = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.rankdim, self.qdim), name='Wd_in')) 218 | self.Wd_s_0 = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.sdim, self.qdim), name='Wd_s_0')) 219 | self.bd_s_0 = add_to_params(self.params, theano.shared(value=np.zeros((self.qdim,), dtype='float32'), name='bd_s_0')) 220 | 221 | if self.decoder_bias_type == 'all': 222 | self.Wd_s_q = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.sdim, self.qdim), name='Wd_s_q')) 223 | 224 | if self.sent_step_type == "gated": 225 | self.Wd_in_r = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.rankdim, self.qdim), name='Wd_in_r')) 226 | self.Wd_in_z = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.rankdim, self.qdim), name='Wd_in_z')) 227 | self.Wd_hh_r = add_to_params(self.params, theano.shared(value=OrthogonalInit(self.rng, self.qdim, self.qdim), name='Wd_hh_r')) 228 | self.Wd_hh_z = add_to_params(self.params, theano.shared(value=OrthogonalInit(self.rng, self.qdim, self.qdim), name='Wd_hh_z')) 229 | self.bd_r = add_to_params(self.params, theano.shared(value=np.zeros((self.qdim,), dtype='float32'), name='bd_r')) 230 | self.bd_z = add_to_params(self.params, theano.shared(value=np.zeros((self.qdim,), dtype='float32'), name='bd_z')) 231 | 232 | if self.decoder_bias_type == 'all': 233 | self.Wd_s_z = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.sdim, self.qdim), name='Wd_s_z')) 234 | self.Wd_s_r = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.sdim, self.qdim), name='Wd_s_r')) 235 | 236 | if self.decoder_bias_type == 'selective': 237 | self.bd_sel = add_to_params(self.params, theano.shared(value=np.zeros((self.sdim,), dtype='float32'), name='bd_sel')) 238 | self.Wd_s_q = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.sdim, self.qdim), name='Wd_s_q')) 239 | # s -> g_r 240 | self.Wd_sel_s = add_to_params(self.params, \ 241 | theano.shared(value=NormalInit(self.rng, self.sdim, self.sdim), \ 242 | name='Wd_sel_s')) 243 | # x_{n-1} -> g_r 244 | self.Wd_sel_e = add_to_params(self.params, \ 245 | theano.shared(value=NormalInit(self.rng, self.rankdim, self.sdim), \ 246 | name='Wd_sel_e')) 247 | # h_{n-1} -> g_r 248 | self.Wd_sel_h = add_to_params(self.params, \ 249 | theano.shared(value=NormalInit(self.rng, self.qdim, self.sdim), \ 250 | name='Wd_sel_h')) 251 | 252 | ###################### 253 | # Output layer weights 254 | ###################### 255 | out_target_dim = self.qdim 256 | if not self.maxout_out: 257 | out_target_dim = self.rankdim 258 | 259 | self.Wd_out = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.qdim, out_target_dim), name='Wd_out')) 260 | 261 | # Set up deep output 262 | if self.deep_out: 263 | self.Wd_e_out = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.rankdim, out_target_dim), name='Wd_e_out')) 264 | self.bd_e_out = add_to_params(self.params, theano.shared(value=np.zeros((out_target_dim,), dtype='float32'), name='bd_e_out')) 265 | 266 | if self.decoder_bias_type != 'first': 267 | self.Wd_s_out = add_to_params(self.params, theano.shared(value=NormalInit(self.rng, self.sdim, out_target_dim), name='Wd_s_out')) 268 | 269 | def build_output_layer(self, hs, xd, hd): 270 | pre_activ = T.dot(hd, self.Wd_out) 271 | 272 | if self.deep_out: 273 | pre_activ += T.dot(xd, self.Wd_e_out) + self.bd_e_out 274 | 275 | if self.decoder_bias_type != 'first': 276 | pre_activ += T.dot(hs, self.Wd_s_out) 277 | # ^ if bias all, bias the deep output 278 | 279 | if self.maxout_out: 280 | pre_activ = Maxout(2)(pre_activ) 281 | 282 | return pre_activ 283 | 284 | def build_next_probs_predictor(self, hs, x, prev_hd): 285 | """ 286 | Return output probabilities given prev_words x, hierarchical pass hs, and previous hd 287 | hs should always be the same (and should not be updated). 288 | """ 289 | return self.build_decoder(hs, x, mode=Decoder.BEAM_SEARCH, prev_hd=prev_hd) 290 | 291 | def approx_embedder(self, x): 292 | # Here we use the same embeddings learnt in the encoder.. !!! 293 | return self.encoder.approx_embedder(x) 294 | 295 | def output_softmax(self, pre_activ): 296 | # returns a (timestep, bs, idim) matrix (huge) 297 | return SoftMax(T.dot(pre_activ, self.Wd_emb.T) + self.bd_out) 298 | 299 | def output_nce(self, pre_activ, y, y_hat): 300 | # returns a (timestep, bs, pos + neg) matrix (very small) 301 | target_embedding = self.Wd_emb[y] 302 | # ^ target embedding is (timestep x bs, rankdim) 303 | noise_embedding = self.Wd_emb[y_hat] 304 | # ^ noise embedding is (10, timestep x bs, rankdim) 305 | 306 | # pre_activ is (timestep x bs x rankdim) 307 | pos_scores = (target_embedding * pre_activ).sum(2) 308 | neg_scores = (noise_embedding * pre_activ).sum(3) 309 | 310 | pos_scores += self.bd_out[y] 311 | neg_scores += self.bd_out[y_hat] 312 | 313 | pos_noise = self.parent.t_noise_probs[y] * 10 314 | neg_noise = self.parent.t_noise_probs[y_hat] * 10 315 | 316 | pos_scores = - T.log(T.nnet.sigmoid(pos_scores - T.log(pos_noise))) 317 | neg_scores = - T.log(1 - T.nnet.sigmoid(neg_scores - T.log(neg_noise))).sum(0) 318 | return pos_scores + neg_scores 319 | 320 | def build_decoder(self, hs, x, xmask=None, y=None, y_neg=None, mode=EVALUATION, prev_hd=None, step_num=None): 321 | # Check parameter consistency 322 | if mode == Decoder.EVALUATION or mode == Decoder.NCE: 323 | assert not prev_hd 324 | assert y 325 | else: 326 | assert not y 327 | assert prev_hd 328 | 329 | # if mode == EVALUATION 330 | # xd = (timesteps, batch_size, qdim) 331 | # 332 | # if mode != EVALUATION 333 | # xd = (n_samples, dim) 334 | xd = self.approx_embedder(x) 335 | if not xmask: 336 | xmask = T.neq(x, self.eos_sym) 337 | 338 | # we must zero out the embedding 339 | # i.e. the embedding x_{-1} is the 0 vector 340 | # as well as hd_{-1} which will be reseted in the scan functions 341 | if xd.ndim != 3: 342 | assert mode != Decoder.EVALUATION 343 | xd = (xd.dimshuffle((1, 0)) * xmask).dimshuffle((1, 0)) 344 | else: 345 | assert mode == Decoder.EVALUATION or mode == Decoder.NCE 346 | xd = (xd.dimshuffle((2,0,1)) * xmask).dimshuffle((1,2,0)) 347 | 348 | # Run the decoder 349 | if mode == Decoder.EVALUATION or mode == Decoder.NCE: 350 | hd_init = T.alloc(np.float32(0), x.shape[1], self.qdim) 351 | else: 352 | hd_init = prev_hd 353 | 354 | if self.sent_step_type == "gated": 355 | f_dec = self.gated_step 356 | o_dec_info = [hd_init, None, None, None] 357 | if self.decoder_bias_type == "selective": 358 | o_dec_info += [None, None] 359 | else: 360 | f_dec = self.plain_step 361 | o_dec_info = [hd_init] 362 | if self.decoder_bias_type == "selective": 363 | o_dec_info += [None, None] 364 | 365 | # If the mode of the decoder is EVALUATION 366 | # then we evaluate by default all the sentence 367 | # xd - i.e. xd.ndim == 3, xd = (timesteps, batch_size, qdim) 368 | if mode == Decoder.EVALUATION or mode == Decoder.NCE: 369 | _res, _ = theano.scan(f_dec, 370 | sequences=[xd, xmask, hs],\ 371 | outputs_info=o_dec_info) 372 | # else we evaluate only one step of the recurrence using the 373 | # previous hidden states and the previous computed hierarchical 374 | # states. 375 | else: 376 | _res = f_dec(xd, xmask, hs, prev_hd) 377 | 378 | if isinstance(_res, list) or isinstance(_res, tuple): 379 | hd = _res[0] 380 | else: 381 | hd = _res 382 | 383 | # if we are using selective bias, we should update our hs 384 | # to the step-selective hs 385 | step_hs = hs 386 | if self.decoder_bias_type == "selective": 387 | step_hs = _res[1] 388 | pre_activ = self.build_output_layer(step_hs, xd, hd) 389 | 390 | # EVALUATION : Return target_probs + all the predicted ranks 391 | # target_probs.ndim == 3 392 | if mode == Decoder.EVALUATION: 393 | outputs = self.output_softmax(pre_activ) 394 | target_probs = GrabProbs(outputs, y) 395 | return target_probs, hd, _res, outputs 396 | elif mode == Decoder.NCE: 397 | return self.output_nce(pre_activ, y, y_neg), hd 398 | # BEAM_SEARCH : Return output (the softmax layer) + the new hidden states 399 | elif mode == Decoder.BEAM_SEARCH: 400 | return self.output_softmax(pre_activ), hd 401 | 402 | def gated_step(self, xd_t, m_t, hs_t, hd_tm1): 403 | if m_t.ndim >= 1: 404 | m_t = m_t.dimshuffle(0, 'x') 405 | 406 | hd_tm1 = (m_t) * hd_tm1 + (1 - m_t) * self.sent_rec_activation(T.dot(hs_t, self.Wd_s_0) + self.bd_s_0) 407 | # ^ iff x_{t - 1} = (m_t = 0) then x_{t - 1} = 0 408 | # and hd_{t - 1} = tanh(W_s_0 hs_t + bd_s_0) else hd_{t - 1} is left unchanged (m_t = 1) 409 | 410 | # In the 'selective' decoder bias type each hidden state of the decoder 411 | # RNN receives the hs_t modified by the selective bias -> hsr_t 412 | if self.decoder_bias_type == 'selective': 413 | rd_sel_t = T.nnet.sigmoid(T.dot(xd_t, self.Wd_sel_e) + T.dot(hd_tm1, self.Wd_sel_h) + T.dot(hs_t, self.Wd_sel_s) + self.bd_sel) 414 | hsr_t = rd_sel_t * hs_t 415 | 416 | rd_t = T.nnet.sigmoid(T.dot(xd_t, self.Wd_in_r) + T.dot(hd_tm1, self.Wd_hh_r) + self.bd_r) 417 | zd_t = T.nnet.sigmoid(T.dot(xd_t, self.Wd_in_z) + T.dot(hd_tm1, self.Wd_hh_z) + self.bd_z) 418 | hd_tilde = self.sent_rec_activation(T.dot(xd_t, self.Wd_in) \ 419 | + T.dot(rd_t * hd_tm1, self.Wd_hh) \ 420 | + T.dot(hsr_t, self.Wd_s_q) \ 421 | + self.bd_hh) 422 | 423 | hd_t = (np.float32(1.) - zd_t) * hd_tm1 + zd_t * hd_tilde 424 | output = (hd_t, hsr_t, rd_sel_t, rd_t, zd_t, hd_tilde) 425 | 426 | # In the 'all' decoder bias type each hidden state of the decoder 427 | # RNN receives the hs_t vector as bias without modification 428 | elif self.decoder_bias_type == 'all': 429 | 430 | rd_t = T.nnet.sigmoid(T.dot(xd_t, self.Wd_in_r) + T.dot(hd_tm1, self.Wd_hh_r) + T.dot(hs_t, self.Wd_s_r) + self.bd_r) 431 | zd_t = T.nnet.sigmoid(T.dot(xd_t, self.Wd_in_z) + T.dot(hd_tm1, self.Wd_hh_z) + T.dot(hs_t, self.Wd_s_z) + self.bd_z) 432 | hd_tilde = self.sent_rec_activation(T.dot(xd_t, self.Wd_in) \ 433 | + T.dot(rd_t * hd_tm1, self.Wd_hh) \ 434 | + T.dot(hs_t, self.Wd_s_q) \ 435 | + self.bd_hh) 436 | hd_t = (np.float32(1.) - zd_t) * hd_tm1 + zd_t * hd_tilde 437 | output = (hd_t, rd_t, zd_t, hd_tilde) 438 | 439 | else: 440 | # Do not bias all the decoder (force to store very useful information in the first state) 441 | rd_t = T.nnet.sigmoid(T.dot(xd_t, self.Wd_in_r) + T.dot(hd_tm1, self.Wd_hh_r) + self.bd_r) 442 | zd_t = T.nnet.sigmoid(T.dot(xd_t, self.Wd_in_z) + T.dot(hd_tm1, self.Wd_hh_z) + self.bd_z) 443 | hd_tilde = self.sent_rec_activation(T.dot(xd_t, self.Wd_in) \ 444 | + T.dot(rd_t * hd_tm1, self.Wd_hh) \ 445 | + self.bd_hh) 446 | hd_t = (np.float32(1.) - zd_t) * hd_tm1 + zd_t * hd_tilde 447 | output = (hd_t, rd_t, zd_t, hd_tilde) 448 | return output 449 | 450 | def plain_step(self, xd_t, m_t, hs_t, hd_tm1): 451 | if m_t.ndim >= 1: 452 | m_t = m_t.dimshuffle(0, 'x') 453 | 454 | # We already assume that xd are zeroed out 455 | hd_tm1 = (m_t) * hd_tm1 + (1-m_t) * self.sent_rec_activation(T.dot(hs_t, self.Wd_s_0) + self.bd_s_0) 456 | # ^ iff x_{t - 1} = (m_t = 0) then x_{t-1} = 0 457 | # and hd_{t - 1} = 0 else hd_{t - 1} is left unchanged (m_t = 1) 458 | 459 | if self.decoder_bias_type == 'first': 460 | # Do not bias all the decoder (force to store very useful information in the first state) 461 | hd_t = self.sent_rec_activation( T.dot(xd_t, self.Wd_in) \ 462 | + T.dot(hd_tm1, self.Wd_hh) \ 463 | + self.bd_hh ) 464 | output = (hd_t,) 465 | elif self.decoder_bias_type == 'all': 466 | hd_t = self.sent_rec_activation( T.dot(xd_t, self.Wd_in) \ 467 | + T.dot(hd_tm1, self.Wd_hh) \ 468 | + T.dot(hs_t, self.Wd_s_q) \ 469 | + self.bd_hh ) 470 | output = (hd_t,) 471 | elif self.decoder_bias_type == 'selective': 472 | rd_sel_t = T.nnet.sigmoid(T.dot(xd_t, self.Wd_sel_e) + T.dot(hd_tm1, self.Wd_sel_h) + T.dot(hs_t, self.Wd_sel_s) + self.bd_sel) 473 | hsr_t = rd_sel_t * hs_t 474 | 475 | hd_tilde = self.sent_rec_activation( T.dot(xd_t, self.Wd_in) \ 476 | + T.dot(hd_tm1, self.Wd_hh) \ 477 | + T.dot(hsr_t, self.Wd_s_q) \ 478 | + self.bd_hh ) 479 | output = (hd_t, hsr_t, rd_sel_t) 480 | 481 | return output 482 | #### 483 | 484 | class DialogEncoderDecoder(Model): 485 | def indices_to_words(self, seq): 486 | """ 487 | Converts a list of word ids to a list 488 | of words. Use unk_sym if a word is not 489 | known. 490 | """ 491 | def convert(): 492 | for word_index in seq: 493 | if word_index > len(self.idx_to_str): 494 | raise ValueError('Word index is too large for the model vocabulary!') 495 | yield self.idx_to_str[word_index] 496 | return list(convert()) 497 | 498 | def words_to_indices(self, seq): 499 | """ 500 | Converts a list of words to a list 501 | of word ids. Use unk_sym if a word is not 502 | known. 503 | """ 504 | return [self.str_to_idx.get(word, self.unk_sym) for word in seq] 505 | 506 | def compute_updates(self, training_cost, params): 507 | updates = [] 508 | 509 | grads = T.grad(training_cost, params) 510 | grads = OrderedDict(zip(params, grads)) 511 | 512 | # Clip stuff 513 | c = numpy.float32(self.cutoff) 514 | clip_grads = [] 515 | 516 | norm_gs = T.sqrt(sum(T.sum(g ** 2) for p, g in grads.items())) 517 | normalization = T.switch(T.ge(norm_gs, c), c / norm_gs, np.float32(1.)) 518 | notfinite = T.or_(T.isnan(norm_gs), T.isinf(norm_gs)) 519 | 520 | for p, g in grads.items(): 521 | clip_grads.append((p, T.switch(notfinite, numpy.float32(.1) * p, g * normalization))) 522 | 523 | grads = OrderedDict(clip_grads) 524 | 525 | if self.updater == 'adagrad': 526 | updates = Adagrad(grads, self.lr) 527 | elif self.updater == 'sgd': 528 | raise Exception("Sgd not implemented!") 529 | elif self.updater == 'adadelta': 530 | updates = Adadelta(grads) 531 | elif self.updater == 'rmsprop': 532 | updates = RMSProp(grads, self.lr) 533 | elif self.updater == 'adam': 534 | updates = Adam(grads) 535 | else: 536 | raise Exception("Updater not understood!") 537 | return updates 538 | 539 | def build_train_function(self): 540 | if not hasattr(self, 'train_fn'): 541 | # Compile functions 542 | logger.debug("Building train function") 543 | self.train_fn = theano.function(inputs=[self.x_data, self.x_max_length, self.x_cost_mask], 544 | outputs=self.training_cost, 545 | updates=self.updates, name="train_fn") 546 | return self.train_fn 547 | 548 | def build_nce_function(self): 549 | if not hasattr(self, 'train_fn'): 550 | # Compile functions 551 | logger.debug("Building train function") 552 | self.nce_fn = theano.function(inputs=[self.x_data, self.y_neg, self.x_max_length, self.x_cost_mask], 553 | outputs=self.contrastive_cost, 554 | updates=self.updates, name="train_fn") 555 | return self.nce_fn 556 | 557 | def build_eval_function(self): 558 | if not hasattr(self, 'eval_fn'): 559 | # Compile functions 560 | logger.debug("Building evaluation function") 561 | self.eval_fn = theano.function(inputs=[self.x_data, self.x_max_length, self.x_cost_mask], 562 | outputs=[self.softmax_cost_acc, self.softmax_cost], name="eval_fn") 563 | return self.eval_fn 564 | 565 | def build_eval_misclassification_function(self): 566 | if not hasattr(self, 'eval_misclass_fn'): 567 | # Compile functions 568 | logger.debug("Building misclassification evaluation function") 569 | self.eval_misclass_fn = theano.function(inputs=[self.x_data, self.x_max_length, self.x_cost_mask], 570 | outputs=self.training_misclassification, name="eval_misclass_fn", on_unused_input='ignore') 571 | 572 | return self.eval_misclass_fn 573 | 574 | def build_get_states_function(self): 575 | if not hasattr(self, 'get_states_fn'): 576 | # Compile functions 577 | logger.debug("Building selective function") 578 | 579 | outputs = [self.h, self.hs, self.hd] + [x for x in self.decoder_states] 580 | self.get_states_fn = theano.function(inputs=[self.x_data, self.x_max_length], 581 | outputs=outputs, name="get_states_fn") 582 | return self.get_states_fn 583 | 584 | def build_next_probs_function(self): 585 | if not hasattr(self, 'next_probs_fn'): 586 | outputs, hd = self.decoder.build_next_probs_predictor(self.beam_hs, self.beam_source, prev_hd=self.beam_hd) 587 | self.next_probs_fn = theano.function(inputs=[self.beam_hs, self.beam_source, self.beam_hd], 588 | outputs=[outputs, hd], 589 | name="next_probs_fn") 590 | return self.next_probs_fn 591 | 592 | def build_next_encoder_function(self): 593 | if not hasattr(self, 'next_encoder_fn'): 594 | h, hs = self.encoder.build_encoder(self.beam_source, prev_hs=self.beam_hs, 595 | prev_h=self.beam_h, prev_token=self.beam_prev_source) 596 | self.next_encoder_fn = theano.function(inputs=[self.beam_hs, self.beam_source, self.beam_h, self.beam_prev_source], 597 | outputs = [h, hs], 598 | name='next_encoder_fn') 599 | return self.next_encoder_fn 600 | 601 | def build_encoder_function(self): 602 | if not hasattr(self, 'encoder_fn'): 603 | h, hs = self.encoder.build_encoder(self.aug_x_data) 604 | self.encoder_fn = theano.function(inputs=[self.x_data], 605 | outputs=[h, hs], name="encoder_fn") 606 | return self.encoder_fn 607 | 608 | def __init__(self, state): 609 | Model.__init__(self) 610 | self.state = state 611 | 612 | # Compatibility towards older models 613 | self.__dict__.update(state) 614 | self.rng = numpy.random.RandomState(state['seed']) 615 | 616 | # Load dictionary 617 | raw_dict = cPickle.load(open(self.dictionary, 'r')) 618 | 619 | # Probabilities for each term in the corpus 620 | self.noise_probs = [x[2] for x in sorted(raw_dict, key=operator.itemgetter(1))] 621 | self.noise_probs = numpy.array(self.noise_probs, dtype='float64') 622 | self.noise_probs /= numpy.sum(self.noise_probs) 623 | self.noise_probs = self.noise_probs ** 0.75 624 | self.noise_probs /= numpy.sum(self.noise_probs) 625 | 626 | self.t_noise_probs = theano.shared(self.noise_probs.astype('float32'), 't_noise_probs') 627 | 628 | # Dictionaries to convert str to idx and vice-versa 629 | self.str_to_idx = dict([(tok, tok_id) for tok, tok_id, _ in raw_dict]) 630 | self.idx_to_str = dict([(tok_id, tok) for tok, tok_id, freq in raw_dict]) 631 | 632 | # Extract document (triple) frequency for each word 633 | self.word_freq = dict([(tok_id, freq) for _, tok_id, freq in raw_dict]) 634 | # self.document_freq = dict([(tok_id, df) for _, tok_id, _, df in raw_dict]) 635 | 636 | if '' not in self.str_to_idx \ 637 | or '' not in self.str_to_idx: 638 | raise Exception("Error, malformed dictionary!") 639 | 640 | # Number of words in the dictionary 641 | self.idim = len(self.str_to_idx) 642 | self.state['idim'] = self.idim 643 | logger.debug("idim: " + str(self.idim)) 644 | 645 | logger.debug("Initializing encoder") 646 | self.encoder = Encoder(self.state, self.rng, self) 647 | logger.debug("Initializing decoder") 648 | self.decoder = Decoder(self.state, self.rng, self, self.encoder) 649 | 650 | # Init params 651 | self.params = self.encoder.params + self.decoder.params 652 | assert len(set(self.params)) == (len(self.encoder.params) + len(self.decoder.params)) 653 | 654 | self.y_neg = T.itensor3('y_neg') 655 | self.x_data = T.imatrix('x_data') 656 | self.x_cost_mask = T.matrix('cost_mask') 657 | self.x_max_length = T.iscalar('x_max_length') 658 | 659 | # The training is done with a trick. We append a special at the beginning of the dialog 660 | # so that we can predict also the first sent in the dialog starting from the dialog beginning token (). 661 | self.aug_x_data = T.concatenate([T.alloc(np.int32(self.eos_sym), 1, self.x_data.shape[1]), self.x_data]) 662 | training_x = self.aug_x_data[:self.x_max_length] 663 | training_y = self.aug_x_data[1:self.x_max_length+1] 664 | 665 | # Here we find the end-of-sentence tokens in the minibatch. 666 | training_hs_mask = T.neq(training_x, self.eos_sym) 667 | training_x_cost_mask = self.x_cost_mask[:self.x_max_length].flatten() 668 | 669 | # Backward compatibility 670 | if 'decoder_bias_type' in self.state: 671 | logger.debug("Decoder bias type {}".format(self.decoder_bias_type)) 672 | 673 | logger.debug("Build encoder") 674 | self.h, self.hs = self.encoder.build_encoder(training_x, xmask=training_hs_mask) 675 | 676 | logger.debug("Build decoder (NCE)") 677 | contrastive_cost, self.hd_nce = self.decoder.build_decoder( 678 | self.hs, training_x, y_neg=self.y_neg, y=training_y, xmask=training_hs_mask, mode=Decoder.NCE) 679 | 680 | logger.debug("Build decoder (EVAL)") 681 | 682 | target_probs, self.hd, self.decoder_states, target_probs_full_matrix = self.decoder.build_decoder( 683 | self.hs, training_x, xmask=training_hs_mask, y=training_y, mode=Decoder.EVALUATION) 684 | 685 | # Prediction cost and rank cost 686 | self.contrastive_cost = T.sum(contrastive_cost.flatten() * training_x_cost_mask) 687 | self.softmax_cost = -T.log(target_probs) * training_x_cost_mask 688 | self.softmax_cost_acc = T.sum(self.softmax_cost) 689 | 690 | # Mean squared error 691 | self.training_cost = self.softmax_cost_acc 692 | if self.use_nce: 693 | self.training_cost = self.contrastive_cost 694 | self.updates = self.compute_updates(self.training_cost / training_x.shape[1], self.params) 695 | 696 | # Prediction accuracy 697 | self.training_misclassification = T.sum(T.neq(T.argmax(target_probs_full_matrix, axis=2), training_y).flatten() * training_x_cost_mask) 698 | 699 | # Beam-search variables 700 | self.beam_source = T.lvector("beam_source") 701 | self.beam_prev_source = T.lvector("beam_prev_source") 702 | 703 | self.beam_h = T.matrix("beam_h") 704 | self.beam_hs = T.matrix("beam_hs") 705 | self.beam_hd = T.matrix("beam_hd") 706 | self.beam_step_num = T.lscalar("beam_step_num") 707 | --------------------------------------------------------------------------------