├── datadir_wsj
├── GRU.pyc
├── model.pyc
├── search.pyc
├── stream.pyc
├── train.pyc
├── attention.pyc
├── sampling.pyc
├── checkpoint.pyc
├── SimplePrinting.pyc
├── afterprocess.pyc
├── configurations.pyc
├── match_functions.pyc
├── SequenceGenerator.pyc
├── configurations_base.pyc
├── learning_rate_halver.pyc
├── search_decoder_with_extra_class.pyc
├── SequenceGenerator_forPickTopicWord.pyc
├── ssh.pub
├── configurations.py
├── compute_bleu.py
├── double.py
├── double_image.py
├── README
├── ssh
├── SimplePrinting.py
├── main.py
├── get_valid_status.py
├── visualize_attention.py
├── match_functions.py
├── afterprocess.py
├── GRU.py
├── learning_rate_halver.py
├── checkpoint.py
├── SequenceGenerator.py
├── model.py
├── SequenceGenerator_forPickTopicWord.py
├── search.py
├── search_decoder_with_extra_class.py
├── sampling.py
├── stream.py
├── train.py
└── attention.py

/datadir_wsj:
--------------------------------------------------------------------------------
1 | D:/users/v-qizhou/Data/WSJ/
2 |
--------------------------------------------------------------------------------
/GRU.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/GRU.pyc
--------------------------------------------------------------------------------
/model.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/model.pyc
--------------------------------------------------------------------------------
/search.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/search.pyc
--------------------------------------------------------------------------------
/stream.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/stream.pyc
--------------------------------------------------------------------------------
/train.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/train.pyc
--------------------------------------------------------------------------------
/attention.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/attention.pyc
--------------------------------------------------------------------------------
/sampling.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/sampling.pyc
--------------------------------------------------------------------------------
/checkpoint.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/checkpoint.pyc
--------------------------------------------------------------------------------
/SimplePrinting.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/SimplePrinting.pyc
--------------------------------------------------------------------------------
/afterprocess.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/afterprocess.pyc
--------------------------------------------------------------------------------
/configurations.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/configurations.pyc
--------------------------------------------------------------------------------
/match_functions.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/match_functions.pyc
--------------------------------------------------------------------------------
/SequenceGenerator.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/SequenceGenerator.pyc
--------------------------------------------------------------------------------
/configurations_base.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/configurations_base.pyc
--------------------------------------------------------------------------------
/learning_rate_halver.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/learning_rate_halver.pyc
--------------------------------------------------------------------------------
/search_decoder_with_extra_class.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/search_decoder_with_extra_class.pyc
--------------------------------------------------------------------------------
/SequenceGenerator_forPickTopicWord.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LynetteXing1991/TA-Seq2Seq/HEAD/SequenceGenerator_forPickTopicWord.pyc
--------------------------------------------------------------------------------
/ssh.pub:
--------------------------------------------------------------------------------
1 | ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCxeDwSIRXHv/dCpDJW9WnIV3fza78ELNVdJmr6fTM9q4UHExbATvfzwxTopGnoqt/igcGepETtGgq0mV8AoWNvCR+aMZkCAYYbKboa2JHCWmqQPEN8wcdgx9HPGqNo5iO2Cg658aezuH9WQoj51FxBSmJm/XrGgYsGNUaK1NlDaAhVCrlcUC3wm+M1vws5JjOIARUDKNklw9gjDAOSziTFF7H+B+mG/+tgdzhwKye+bgL24Ohvy7ej1NVotKSEYCynBGN4GCcn/RGv9+a3uU/g7SpOG9EXCi9ID50xHpxf0ngLl+lZ0QlC9ib5AGD8Wt9800nBuHvTMq24F3J1EwdZ v-chxing@microsoft.com
2 |
--------------------------------------------------------------------------------
/configurations.py:
--------------------------------------------------------------------------------
1 | from configurations_base import default
2 |
3 | def wsj():
4 |     config = default()
5 |     config['model_name'] = 'wsj'
6 |     config['saveto'] = 'models/wsj'
7 |     config['step_rule'] = 'AdaGrad'
8 |     config['match_function'] = 'SumMatchFunction'
9 |     config['attention_images'] = config['saveto'] + '/attention_images/'
10 |     config['attention_weights'] = config['saveto'] + '/attention_weights'
11 |     config['val_output_orig'] = config['saveto'] + '/test_output_orig'
12 |     config['val_output_repl'] = config['saveto'] + '/test_output_repl'
13 |     return config
14 |
--------------------------------------------------------------------------------
/compute_bleu.py:
--------------------------------------------------------------------------------
1 | import re
2 | import argparse
3 | import configurations
4 | from subprocess import Popen, PIPE
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument(
8 |     "--proto", default="normal_adagrad",
9 |     help="Prototype config to use for config")
10 | args = parser.parse_args()
11 |
12 | def main(config):
13 |     compbleu_cmd = [config['bleu_script_1'],
14 |                     config['val_set_target'],
15 |                     config['val_output_orig']]
16 |     bleu_subproc = Popen(compbleu_cmd, stdout=PIPE)
17 |     while True:
18 |         line = bleu_subproc.stdout.readline()
19 |         if line != '':
20 |             if 'BLEU' in line:
21 |                 stdout = line
22 |         else:
23 |             break
24 |     bleu_subproc.terminate()
25 |     out_parse = re.match(r'BLEU = [-.0-9]+', stdout)
26 |     bleu_score = float(out_parse.group()[6:])
27 |     print bleu_score
28 |
29 | if __name__ == '__main__':
30 |     config = getattr(configurations, args.proto)()
31 |     main(config)
32 |
--------------------------------------------------------------------------------
/double.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import matplotlib.cm as cm
5 | from progressbar import ProgressBar
6 | from PIL import Image
7 |
8 | dir1 = 'models/normal_adagrad/attention_images/'
9 | dir2 = 'models/rec_adagrad/attention_images/'
10 | outdir = 'normal_vs_rec/'
11 | if not os.path.exists(outdir):
12 |     os.mkdir(outdir)
13 |
14 | candidates = xrange(1082)
15 |
16 | pbar = ProgressBar(max_value=len(candidates)).start()
17 | for i, k in enumerate(candidates):
18 |     pbar.update(i + 1)
19 |     fig = plt.figure(figsize=(40, 20), dpi=80)
20 |     f1 = dir1 + str(k) + '.png'
21 |     f2 = dir2 + str(k) + '.png'
22 |     for j, ff in enumerate([f2, f1]):
23 |         image = Image.open(ff)
24 |         w, h = image.size
25 |         image = image.crop((100, 0, w-100, h))
26 |         arr = np.asarray(image)
27 |         fig.add_subplot(1, 2, j + 1)
28 |         fig.tight_layout()
29 |         plt.imshow(arr, cmap=cm.Greys_r)
30 |         plt.tight_layout()
31 |         plt.axis('off')
32 |     plt.savefig(outdir + str(k) + '.png')
33 |     plt.close()
34 | pbar.finish()
35 |
--------------------------------------------------------------------------------
/double_image.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import matplotlib.cm as cm
5 | from progressbar import ProgressBar
6 | from PIL import Image
7 |
8 | dir1 = 'models/normal_adagrad/attention_images/'
9 | dir2 = 'models/rec_adagrad/attention_images/'
10 | outdir = 'double_images/'
11 | if not os.path.exists(outdir):
12 |     os.mkdir(outdir)
13 |
14 | candidates = range(20)
15 |
16 | pbar = ProgressBar(max_value=len(candidates)).start()
17 | for i, k in enumerate(candidates):
18 |     pbar.update(i + 1)
19 |     fig = plt.figure(figsize=(40, 20), dpi=80)
20 |     f1 = dir1 + str(k) + '.png'
21 |     f2 = dir2 + str(k) + '.png'
22 |     for j, ff in enumerate([f2, f1]):
23 |         image = Image.open(ff).convert('L')
24 |         w, h = image.size
25 |         image = image.crop((0, 0, w-500, h))
26 |         arr = np.asarray(image)
27 |         fig.add_subplot(1, 2, j + 1)
28 |         fig.tight_layout()
29 |         plt.imshow(arr, cmap=cm.Blues)
30 |         plt.tight_layout()
31 |         plt.axis('off')
32 |     plt.savefig(outdir + str(k) + '.png')
33 |     plt.close()
34 | pbar.finish()
35 |
--------------------------------------------------------------------------------
/README:
--------------------------------------------------------------------------------
1 | TA-Seq2Seq
2 |
3 | This project is built on Theano 0.9, Python 2.7 and Blocks (https://github.com/mila-udem/blocks). Please make sure they are installed before running this project.
4 |
5 |
6 | Step 1: preparing the data
7 |
8 |
9 | This project requires 3 vocabularies: the query vocabulary, the response vocabulary and the topic vocabulary. You should build every vocabulary as a dictionary like {'I': 0, 'UNK': 1, 'a': 2, 'student': 3, '': 4} and save it as a pkl file.
10 |
11 | This project also requires a query file, a response file and a topic word file; the query, the response and the attached topic word list of each case are stored on the same line of the three files, respectively.
12 |
13 |
14 | Step 2: checking the configurations
15 |
16 |
17 | Please refer to the function topicAwareJPData() in configurations_base.py as an example of how to write the configuration of your experiment.
18 |
19 | Let me explain some important features:
20 |
21 | 'topic_vocab_input' and 'topic_vocab_output' are both set to the same topic vocabulary built beforehand.
22 | 'topic_embeddings' is the embedding matrix of all topic words, in which the i-th row is the embedding of the i-th word in the topic word vocabulary.
23 | 'topical_word_num' is the number of topic words attached to every query (the number of words in every line of the topic word file).
24 | 'tw_vocab_overlap' is a one-hot matrix that maps topic words to their indices in the response vocabulary. A simple case is as follows:
25 |                     I UNK a student
26 | student(topic word)[[0 0 0 1 0]
27 | a(topic word)       [0 0 1 0 0]]
28 |
29 |
30 | Step 3: Run!
--------------------------------------------------------------------------------
/ssh:
--------------------------------------------------------------------------------
1 | -----BEGIN RSA PRIVATE KEY-----
2 | Proc-Type: 4,ENCRYPTED
3 | DEK-Info: AES-128-CBC,872CA31FCDA11E404801BAB285231B84
4 |
5 | pcN1wLrNRMp0gI+2EHN+IH50wdTIAUYfcvlSccagsfSNUl10XU9vhOc3bQZhs3OS
6 | sfak+hmgmaPYBiSK1fLzTsvr2BrdKf4K/8ZWh6JYhEU9ADc9QkfEj/DAz76s2E2G
7 | Uxs8spEUbaYQGZgMNnpSarXENEFQjE+ysAMsFr35N8udFEfVbxlntq0KCYW9i8zl
8 | Lm3jVjIVHNWyHfcXIKBvZFbsMJaJ0aoiBk7TJBDnvHAtDXNfZJhr+JAz/IdFvSu5
9 | LKJ4FIhlVGIGhFUdpWKd/jwUYKUmy9Zd6+Cyaj3yfolT104IzRBe9v7x1T+zpAd8
10 | W1TjLGkSIDDtOjZ5kBc7+cM2xtUi24ZWWzDN5Mk/SdTcXwo7cgr23QdDboZWow7B
11 | TzRWil8y2n59qTixIlaig6WInrerod1OewaoU0bgSqjp1g45qgzMSuu8cwlixMj7
12 | wmsfK0+FjpklagOQ4bHmrT8EsM82EnfQihXm4JwqM52FFVETun58hbH3aKlh4T3n
13 | iB85tEthbshn2VCUIGTk0vD0IEfq2d+9eZgRo4A3k7OJPa370gBfD4i9iTk2kMss
14 | xUlo8knzc1mRhUDzKLmTwGs7g0YYQbtULsg9jAWlxlg4e6PshfBjpLuG8CgAYDsX
15 | 1fUmi6y/0KthwsGMtsYKrnV568+3b9dfZ8gweIOeukfFJB/8dPXh/9ONW9NwHUp4
16 | dFg/+vcXRRaLZ9RkYz1UvjcEo0HW0In8yDegjqng+HZSQMe4C/V+ABs7e5bgSJ88
17 | OOPuqMsNZIETGS9v8XvipZXLQPbKegaos4/JBo2o1MvOqFy5JzCgrDH0DocB7NL9
18 | PdSyahgZMSGcWYlfTH3K+NivsNfbnlenFW8Tu/9Kra6tbDDZKS1fFPrXbwhoUo36
19 | IGnE1PlqutPkA8wys+HMCbYqdbJ5AEvyNy5Yqt/ZOzIxgl8iUhLUpv66LFBTU3rY
20 | 87NZTw4Q23cl1OO6YnmcW7Uk/fDUy+HCQ/gYRErJhY0VGkzRKJGB4GpjHz8RKi45
21 | rKJBzrfUUG/4hX3j4KsvfU91H9BWledlhCCGHCtrDiyIVamajc+ed+hI1Ujg3Rx7
22 | 9zmHKKJ1rdcGn6XKiUS2JhltMzNapUHmpkXr/ihcoO7llWeYT5YpgbdcplKIZ6WZ
23 | OKTb2zvp8JaDPFYpW0E3u0YIae6QCZoNAO2Q5Pg42a/MUAN14qs/BU8OQ4a4asYK
24 | cOH07vnpY1hzrnAmTI8ruygxCTynDNcq3BdaetuMP5q0GVPLXoebDJrzkneWDbAs
25 | qIb845euH02V18yZS2H5jg/Cod5RusedCUpXf7vdG6gh0/iVp+OtMx8eN/XCaYio
26 | d+uz5gmoQXg+AFgl6dt6/7vQAKUYJsomCVE/++fkDDj2Mxj6XXIcMnpSLjhhDsWD
27 | km4XS1nF0e1P5DLHPDyHZd1mVOPUVgFJVG4EEZKCWtafC3Z+Ro4YN+r5wRruIuya
28 |
kM8+kn1QBaYgFgsbkUG8yKo/dZHR5CXtcGqp9ty1tij+xI9fe/qIsGC8O+tVF7eW 29 | ONRpdFOmuf5d56AlsybWxRZHMEar8mfiENcX/KFwPRUnaqBOs5IgOPYnRnadVwIu 30 | -----END RSA PRIVATE KEY----- 31 | -------------------------------------------------------------------------------- /SimplePrinting.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from toolz import first 3 | from blocks.extensions import SimpleExtension 4 | 5 | 6 | class SimplePrinting(SimpleExtension): 7 | """Prints log messages to the screen.""" 8 | def __init__(self, model_name, **kwargs): 9 | self.model_name = model_name 10 | kwargs.setdefault("before_first_epoch", True) 11 | kwargs.setdefault("on_resumption", True) 12 | kwargs.setdefault("after_training", True) 13 | kwargs.setdefault("after_epoch", True) 14 | kwargs.setdefault("on_interrupt", True) 15 | super(SimplePrinting, self).__init__(**kwargs) 16 | 17 | def _print_attributes(self, attribute_tuples): 18 | for attr, value in sorted(attribute_tuples.items(), key=first): 19 | if not attr.startswith("_"): 20 | print("\t" + "{}: {}".format(attr, value)), 21 | print 22 | sys.stdout.flush() 23 | 24 | def do(self, which_callback, *args): 25 | log = self.main_loop.log 26 | print_status = True 27 | 28 | # print() 29 | # print("".join(79 * "-")) 30 | if which_callback == "before_epoch" and log.status['epochs_done'] == 0: 31 | print("BEFORE FIRST EPOCH") 32 | elif which_callback == "on_resumption": 33 | print("TRAINING HAS BEEN RESUMED") 34 | elif which_callback == "after_training": 35 | print("TRAINING HAS BEEN FINISHED:") 36 | elif which_callback == "after_epoch": 37 | print("AFTER ANOTHER EPOCH") 38 | elif which_callback == "on_interrupt": 39 | print("TRAINING HAS BEEN INTERRUPTED") 40 | print_status = False 41 | # print("".join(79 * "-")) 42 | if print_status: 43 | # print("Training status:") 44 | # self._print_attributes(log.status) 45 | print self.model_name, log.status['iterations_done'], 46 | self._print_attributes(log.current_row) 47 | # print() 48 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Encoder-Decoder with search for machine translation. 2 | 3 | In this demo, encoder-decoder architecture with attention mechanism is used for 4 | machine translation. The attention mechanism is implemented according to 5 | [BCB]_. The training data used is WMT15 Czech to English corpus, which you have 6 | to download, preprocess and put to your 'datadir' in the config file. Note 7 | that, you can use `prepare_data.py` script to download and apply all the 8 | preprocessing steps needed automatically. Please see `prepare_data.py` for 9 | further options of preprocessing. 10 | 11 | .. [BCB] Dzmitry Bahdanau, Kyunghyun Cho and Yoshua Bengio. Neural 12 | Machine Translation by Jointly Learning to Align and Translate. 
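
Usage sketch (the arguments are parsed below; `--proto` selects a
configuration function from configurations_base.py)::

    python main.py --proto topicAwareJPData --mode train
    python main.py --proto topicAwareJPData --mode translate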
13 | """ 14 | 15 | import argparse 16 | import logging 17 | import pprint 18 | 19 | import configurations_base 20 | 21 | from train import main 22 | from afterprocess import afterprocesser 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | # Get the arguments 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument( 29 | "--proto", default="topicAwareJPData", 30 | help="Prototype config to use for config") 31 | parser.add_argument( 32 | "--bokeh", default=False, action="store_true", 33 | help="Use bokeh server for plotting") 34 | parser.add_argument( 35 | "--mode", choices=["train", "translate"], default='translate', 36 | help="The mode to run. In the `train` mode a model is trained." 37 | " In the `translate` mode a trained model is used to translate" 38 | " an input file and generates tokenized translation.") 39 | parser.add_argument( 40 | "--test-file", default='', help="Input test file for `translate` mode") 41 | args = parser.parse_args() 42 | 43 | 44 | if __name__ == "__main__": 45 | # Get configurations for model 46 | config = getattr(configurations_base, args.proto)() 47 | # configuration['test_set'] = args.test_file 48 | # logger.info("Model options:\n{}".format(pprint.pformat(configuration))) 49 | # Get data streams and call main 50 | main(args.mode, config, args.bokeh) 51 | -------------------------------------------------------------------------------- /get_valid_status.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cPickle 4 | import operator 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | model_dir = 'models/' 9 | BLEU = 'validation_bleu' 10 | COST = 'validation_cost' 11 | PLOT_BLEU = True 12 | PLOT_COST = False 13 | 14 | 15 | def get_log(path): 16 | filenames = os.listdir(path) 17 | logs = [f for f in filenames if f.startswith('log')] 18 | if len(logs) == 0: 19 | return None 20 | iterations = [int(l.split('.')[-1]) for l in logs if '.' in l] 21 | if len(iterations) == 0: 22 | return cPickle.load(open(os.path.join(path, logs[0]), 'rb')) 23 | x = max(iterations) 24 | return cPickle.load(open(os.path.join(path, 'log.' 
+ str(x)), 'rb')) 25 | 26 | 27 | lines = [] 28 | names = [] 29 | def main(): 30 | for model_name in os.listdir(model_dir): 31 | if model_name.endswith('bk'): 32 | continue 33 | model_path = os.path.join(model_dir, model_name) 34 | if not os.path.isdir(model_path): 35 | continue 36 | log = get_log(model_path) 37 | if log is None: 38 | continue 39 | log = sorted(log.items(), key=operator.itemgetter(0)) 40 | bleus = [(x, y[BLEU]) for x, y in log if BLEU in y] 41 | costs = [(x, y[COST]) for x, y in log if COST in y] 42 | if len(bleus) == 0: 43 | continue 44 | 45 | if PLOT_BLEU: 46 | line, = plt.plot(*zip(*bleus), 47 | linewidth=2.0, 48 | label=model_name + ' bleu', 49 | marker='+') 50 | names.append(model_name + ' bleu') 51 | 52 | if PLOT_COST: 53 | line, = plt.plot(*zip(*costs), 54 | linewidth=2.0, 55 | label=model_name + 'cost', 56 | marker='+') 57 | names.append(model_name + ' cost') 58 | lines.append(line) 59 | 60 | 61 | plt.legend(lines, names, loc=4) 62 | plt.show() 63 | 64 | 65 | if __name__ == '__main__': 66 | main() 67 | -------------------------------------------------------------------------------- /visualize_attention.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import pickle 4 | import numpy as np 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | from matplotlib.font_manager import FontProperties 8 | from progressbar import ProgressBar 9 | from PIL import Image 10 | import sys 11 | import configurations 12 | import argparse 13 | 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--proto", default="normal_adagrad", 18 | help="Prototype config to use for config") 19 | args = parser.parse_args() 20 | 21 | 22 | def showmat(name, mat, alpha, beta): 23 | sample_dir = 'samples/' 24 | alpha = [a.decode('utf-8') for a in alpha] 25 | beta = [b.decode('utf-8') for b in beta] 26 | 27 | fig = plt.figure(figsize=(20, 20), dpi=80) 28 | plt.clf() 29 | matplotlib.rcParams.update({'font.size': 18}) 30 | ax = fig.add_subplot(111) 31 | ax.set_aspect(1) 32 | ax.xaxis.tick_top() 33 | res = ax.imshow(mat, cmap=plt.cm.Blues, 34 | interpolation='nearest') 35 | 36 | font_prop = FontProperties() 37 | font_prop.set_file('./wqy-zenhei.ttf') 38 | font_prop.set_size('large') 39 | 40 | plt.xticks(range(len(alpha)), alpha, rotation=60, 41 | fontproperties=font_prop) 42 | plt.yticks(range(len(beta)), beta, fontproperties=font_prop) 43 | 44 | cax = plt.axes([0.0, 0.0, 0.0, 0.0]) 45 | plt.colorbar(mappable=res, cax=cax) 46 | plt.savefig(name + '.png', format='png') 47 | plt.close() 48 | 49 | 50 | def main(config): 51 | images_dir = config['attention_images'] 52 | if not os.path.exists(images_dir): 53 | os.mkdir(images_dir) 54 | 55 | source_file = open(config['val_set_source'], 'r').readlines() 56 | target_file = pickle.load(open(config['val_output_repl'] + '.pkl', 'rb')) 57 | weights = pickle.load(open(config['attention_weights'], 'rb')) 58 | 59 | pbar = ProgressBar(max_value=len(source_file)).start() 60 | for i, (source, target, weight) in enumerate( 61 | zip( source_file, target_file, weights)): 62 | pbar.update(i + 1) 63 | source = source.strip().split() 64 | showmat(images_dir + str(i), weight, source, target) 65 | pbar.finish() 66 | 67 | def crop(config): 68 | indir = config['attention_images'] 69 | outdir = config['attention_images'] + '/cropped' 70 | if not os.path.exists(outdir): 71 | os.mkdir(outdir) 72 | for fname in os.listdir(indir): 73 | inpath = os.path.join(indir, fname) 74 | outpath = 
os.path.join(outdir, fname) 75 | if os.path.isdir(inpath): 76 | continue 77 | image = Image.open(inpath) 78 | w, h = image.size 79 | image = image.crop((0, 0, w, h-12)) 80 | image.save(outpath, 'png') 81 | 82 | 83 | if __name__ == '__main__': 84 | config = getattr(configurations, args.proto)() 85 | main(config) 86 | crop(config) 87 | -------------------------------------------------------------------------------- /match_functions.py: -------------------------------------------------------------------------------- 1 | from theano import tensor 2 | from blocks.bricks.base import application, Brick, lazy 3 | from blocks.bricks import (Brick, Initializable, Sequence, 4 | Feedforward, Linear, Tanh) 5 | from blocks.utils import dict_union, dict_subset, pack 6 | 7 | 8 | class ShallowEnergyComputer(Initializable, Feedforward): 9 | """A simple energy computer: first tanh, then weighted sum.""" 10 | @lazy() 11 | def __init__(self, **kwargs): 12 | super(ShallowEnergyComputer, self).__init__(**kwargs) 13 | self.tanh = Tanh() 14 | self.linear = Linear(use_bias=False) 15 | self.children = [self.tanh, self.linear] 16 | 17 | @application 18 | def apply(self, *args): 19 | output = args 20 | output = self.tanh.apply(*pack(output)) 21 | output = self.linear.apply(*pack(output)) 22 | return output 23 | 24 | @property 25 | def input_dim(self): 26 | return self.children[1].input_dim 27 | 28 | @input_dim.setter 29 | def input_dim(self, value): 30 | self.children[1].input_dim = value 31 | 32 | @property 33 | def output_dim(self): 34 | return self.children[1].output_dim 35 | 36 | @output_dim.setter 37 | def output_dim(self, value): 38 | self.children[1].output_dim = value 39 | 40 | 41 | class SumMatchFunction(Initializable, Feedforward): 42 | 43 | @lazy() 44 | def __init__(self, **kwargs): 45 | super(SumMatchFunction, self).__init__(**kwargs) 46 | self.shallow = ShallowEnergyComputer() 47 | self.children = [self.shallow] 48 | 49 | @application 50 | def apply(self, states, attended): 51 | match_vectors = states + attended 52 | energies = self.shallow.apply(*pack(match_vectors)) 53 | energies = energies.reshape( 54 | match_vectors.shape[:-1], ndim=match_vectors.ndim - 1) 55 | return energies 56 | 57 | @property 58 | def input_dim(self): 59 | return self.children[0].input_dim 60 | 61 | @input_dim.setter 62 | def input_dim(self, value): 63 | self.children[0].input_dim = value 64 | 65 | @property 66 | def output_dim(self): 67 | return self.children[0].output_dim 68 | 69 | @output_dim.setter 70 | def output_dim(self, value): 71 | self.children[0].output_dim = value 72 | 73 | 74 | class CatMatchFunction(Initializable, Feedforward): 75 | 76 | @lazy() 77 | def __init__(self, **kwargs): 78 | super(CatMatchFunction, self).__init__(**kwargs) 79 | self.shallow = ShallowEnergyComputer() 80 | self.children = [self.shallow] 81 | 82 | @application 83 | def apply(self, states, attended): 84 | states = tensor.repeat(states[None, :, :], attended.shape[0], axis=0) 85 | match_vectors = tensor.concatenate([states, attended], axis=2) 86 | energies = self.shallow.apply(*pack(match_vectors)) 87 | energies = energies.reshape( 88 | match_vectors.shape[:-1], ndim=match_vectors.ndim - 1) 89 | return energies 90 | 91 | @property 92 | def input_dim(self): 93 | return self.children[0].input_dim 94 | 95 | @input_dim.setter 96 | def input_dim(self, value): 97 | # because we concat to input_dim is match_dim * 2 98 | self.children[0].input_dim = value * 2 99 | 100 | @property 101 | def output_dim(self): 102 | return self.children[0].output_dim 103 | 
104 | @output_dim.setter 105 | def output_dim(self, value): 106 | self.children[0].output_dim = value 107 | 108 | 109 | class DotMatchFunction(Initializable, Feedforward): 110 | 111 | @lazy() 112 | def __init__(self, **kwargs): 113 | super(DotMatchFunction, self).__init__(**kwargs) 114 | 115 | @application 116 | def apply(self, states, attended): 117 | match_vectors = tensor.tensordot( 118 | attended, states, axes=[2, 1])[:, 0, :] 119 | energies = tensor.exp(match_vectors) 120 | return energies 121 | 122 | 123 | class GeneralMatchFunction(Initializable, Feedforward): 124 | 125 | @lazy() 126 | def __init__(self, **kwargs): 127 | super(GeneralMatchFunction, self).__init__(**kwargs) 128 | self.linear = Linear(use_bias=False) 129 | self.children = [self.linear] 130 | 131 | @application 132 | def apply(self, states, attended): 133 | states = self.linear.apply(*pack(states)) 134 | match_vectors = tensor.tensordot( 135 | attended, states, axes=[2, 1])[:, 0, :] 136 | energies = tensor.exp(match_vectors) 137 | return energies 138 | 139 | @property 140 | def input_dim(self): 141 | return self.children[0].input_dim 142 | 143 | @input_dim.setter 144 | def input_dim(self, value): 145 | self.children[0].input_dim = value 146 | self.children[0].output_dim = value 147 | -------------------------------------------------------------------------------- /afterprocess.py: -------------------------------------------------------------------------------- 1 | import cPickle 2 | import numpy as np 3 | import configurations 4 | import argparse 5 | import operator 6 | 7 | class afterprocesser: 8 | 9 | def __init__(self, config): 10 | self.config = config 11 | 12 | def is_unk(self, s): 13 | return s == '' 14 | 15 | def is_dollar(self, s): 16 | return s.startswith('$') 17 | 18 | def is_eol(self, s): 19 | return s == '' 20 | 21 | def process_sent(self, src, sent, weights, 22 | trans_table, repl_table, att_table): 23 | replaced = [] 24 | replaced_cut = [] 25 | for word, ws in zip(sent, weights): 26 | att = np.argmax(ws) 27 | if self.is_unk(word): 28 | if att in repl_table: 29 | mark, repl = repl_table[att] 30 | replaced.append(repl) 31 | replaced_cut.append(repl) 32 | elif att < len(src): 33 | if src[att] in trans_table: 34 | repl, freq = trans_table[src[att]][0] 35 | replaced.append(repl) 36 | replaced_cut.append(repl) 37 | elif src[att] in att_table: 38 | repl, freq = att_table[src[att]][0] 39 | replaced.append(repl) 40 | replaced_cut.append(repl) 41 | else: 42 | replaced.append('$' + src[att].strip()) 43 | else: 44 | replaced.append(word) 45 | elif self.is_dollar(word): 46 | if att in repl_table: 47 | mark, repl = repl_table[att] 48 | replaced.append(repl) 49 | replaced_cut.append(repl) 50 | elif att < len(src): 51 | if src[att] in trans_table: 52 | repl, freq = trans_table[src[att]][0] 53 | replaced.append(repl) 54 | replaced_cut.append(repl) 55 | elif src[att] in att_table: 56 | repl, freq = att_table[src[att]][0] 57 | replaced.append(repl) 58 | replaced_cut.append(repl) 59 | else: 60 | replaced.append('$' + src[att].strip()) 61 | else: 62 | replaced.append(word) 63 | elif not self.is_eol(word): 64 | replaced.append(word) 65 | replaced_cut.append(word) 66 | return replaced, replaced_cut 67 | 68 | 69 | def main(self): 70 | val_set = self.config['val_set_source'] 71 | source_file = open(val_set, 'r').readlines() 72 | original_file = open(self.config['val_output_orig'], 'r').readlines() 73 | replaced_file = open(self.config['val_output_repl'], 'wb') 74 | replaced_pkl = open(self.config['val_output_repl'] + '.pkl', 
'wb') 75 | weights = cPickle.load(open(self.config['attention_weights'], 'rb')) 76 | translation_table = cPickle.load(open(self.config['translation_table'], 'rb')) 77 | replacement_table = cPickle.load(open(self.config['replacement_table'], 'rb')) 78 | 79 | att_table = dict() 80 | ''' 81 | for mat, src, trg in zip(weights, source_file, original_file): 82 | src = src.split() 83 | trg = trg.split() 84 | for line, word in zip(mat.T, src): 85 | line = line / line.sum() 86 | i = line.argmax() 87 | if self.is_unk(trg[i]) or self.is_dollar(trg[i]) or self.is_eol(trg[i]): 88 | continue 89 | if word not in att_table: 90 | att_table[word] = dict() 91 | if trg[i] not in att_table[word]: 92 | att_table[word][trg[i]] = 0 93 | att_table[word][trg[i]] += 1 94 | for key, value in att_table.items(): 95 | value = sorted(value.items(), key=operator.itemgetter(1), reverse=True) 96 | att_table[key] = value 97 | ''' 98 | 99 | all_replaced = [] 100 | for i, (source, sent, weight, repl) in enumerate(zip( 101 | source_file, original_file, weights, replacement_table)): 102 | sent = sent.strip().split(' ') 103 | source = source.strip().split(' ') 104 | replaced, replaced_cut = self.process_sent(source, sent, weight, 105 | translation_table, repl, 106 | att_table) 107 | all_replaced.append(replaced) 108 | replaced_file.write(' '.join(replaced_cut) + '\n') 109 | cPickle.dump(all_replaced, replaced_pkl) 110 | 111 | 112 | if __name__ == '__main__': 113 | config = configurations.normal_adagrad() 114 | ap = afterprocesser(config) 115 | ap.main() 116 | -------------------------------------------------------------------------------- /GRU.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy 3 | from theano import tensor, Variable 4 | from blocks.bricks import Initializable, Logistic, Tanh, Linear, MLP 5 | from blocks.bricks.base import Application, application, Brick, lazy 6 | from blocks.initialization import NdarrayInitialization 7 | from blocks.roles import add_role, WEIGHT, INITIAL_STATE 8 | from blocks.utils import (pack, shared_floatx_nans, shared_floatx_zeros, 9 | dict_union, dict_subset, is_shared_variable) 10 | from blocks.bricks.parallel import Fork 11 | from blocks.bricks.recurrent import recurrent, BaseRecurrent 12 | 13 | 14 | class GRU(BaseRecurrent, Initializable): 15 | @lazy(allocation=['dim']) 16 | def __init__(self, dim, attended_dim, 17 | activation=None, gate_activation=None, 18 | **kwargs): 19 | super(GRU, self).__init__(**kwargs) 20 | self.dim = dim 21 | self.attended_dim = attended_dim 22 | 23 | if not activation: 24 | activation = Tanh() 25 | if not gate_activation: 26 | gate_activation = Logistic() 27 | self.activation = activation 28 | self.gate_activation = gate_activation 29 | 30 | self.initial_transformer = MLP(activations=[Tanh()], 31 | dims=[attended_dim, self.dim], 32 | name='state_initializer') 33 | 34 | self.children = [activation, gate_activation, self.initial_transformer] 35 | 36 | @property 37 | def state_to_state(self): 38 | return self.parameters[0] 39 | 40 | @property 41 | def state_to_gates(self): 42 | return self.parameters[1] 43 | 44 | def get_dim(self, name): 45 | if name == 'mask': 46 | return 0 47 | if name in ['inputs', 'states']: 48 | return self.dim 49 | if name == 'gate_inputs': 50 | return 2 * self.dim 51 | return super(GRU, self).get_dim(name) 52 | 53 | def _allocate(self): 54 | ''' 55 | self.parameters.append(shared_floatx_nans((self.dim, self.dim), 56 | name='state_to_state')) 57 | 
self.parameters.append(shared_floatx_nans((self.dim, 2 * self.dim), 58 | name='state_to_gates')) 59 | self.parameters.append(shared_floatx_zeros((self.dim,), 60 | name="initial_state")) 61 | for i in range(2): 62 | if self.parameters[i]: 63 | add_role(self.parameters[i], WEIGHT) 64 | add_role(self.parameters[2], INITIAL_STATE) 65 | ''' 66 | self.parameters.append(shared_floatx_nans((self.dim, self.dim), 67 | name='state_to_state')) 68 | self.parameters.append(shared_floatx_nans((self.dim, 2 * self.dim), 69 | name='state_to_gates')) 70 | for i in range(2): 71 | if self.parameters[i]: 72 | add_role(self.parameters[i], WEIGHT) 73 | 74 | def _initialize(self): 75 | self.weights_init.initialize(self.state_to_state, self.rng) 76 | state_to_update = self.weights_init.generate( 77 | self.rng, (self.dim, self.dim)) 78 | state_to_reset = self.weights_init.generate( 79 | self.rng, (self.dim, self.dim)) 80 | self.state_to_gates.set_value( 81 | numpy.hstack([state_to_update, state_to_reset])) 82 | 83 | @recurrent(sequences=['mask', 'inputs', 'gate_inputs'], 84 | states=['states'], outputs=['states'], contexts=[]) 85 | def apply(self, inputs, gate_inputs, states, mask=None): 86 | """Apply the gated recurrent transition. 87 | 88 | Parameters 89 | ---------- 90 | states : :class:`~tensor.TensorVariable` 91 | The 2 dimensional matrix of current states in the shape 92 | (batch_size, dim). Required for `one_step` usage. 93 | inputs : :class:`~tensor.TensorVariable` 94 | The 2 dimensional matrix of inputs in the shape (batch_size, 95 | dim) 96 | gate_inputs : :class:`~tensor.TensorVariable` 97 | The 2 dimensional matrix of inputs to the gates in the 98 | shape (batch_size, 2 * dim). 99 | mask : :class:`~tensor.TensorVariable` 100 | A 1D binary array in the shape (batch,) which is 1 if there is 101 | data available, 0 if not. Assumed to be 1-s only if not given. 102 | 103 | Returns 104 | ------- 105 | output : :class:`~tensor.TensorVariable` 106 | Next states of the network. 
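
        Notes
        -----
        A sketch of the update computed below, with the default
        ``Logistic``/``Tanh`` activations, writing ``W_g`` for
        ``state_to_gates`` (update and reset weights stacked side by
        side) and ``W_s`` for ``state_to_state``:

        .. math::

            [z_t, r_t] &= \sigma(h_{t-1} W_g + x^g_t)\\
            \tilde{h}_t &= \tanh((h_{t-1} \odot r_t) W_s + x_t)\\
            h_t &= z_t \odot \tilde{h}_t + (1 - z_t) \odot h_{t-1}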
107 | 108 | """ 109 | gate_values = self.gate_activation.apply( 110 | states.dot(self.state_to_gates) + gate_inputs) 111 | update_values = gate_values[:, :self.dim] 112 | reset_values = gate_values[:, self.dim:] 113 | states_reset = states * reset_values 114 | next_states = self.activation.apply( 115 | states_reset.dot(self.state_to_state) + inputs) 116 | next_states = (next_states * update_values + 117 | states * (1 - update_values)) 118 | if mask: 119 | next_states = (mask[:, None] * next_states + 120 | (1 - mask[:, None]) * states) 121 | return next_states 122 | 123 | @application(outputs=apply.states) 124 | def initial_states(self, batch_size, *args, **kwargs): 125 | ''' 126 | return [tensor.repeat(self.parameters[2][None, :], batch_size, 0)] 127 | ''' 128 | attended = kwargs['attended'] 129 | initial_state = self.initial_transformer.apply( 130 | attended[0, :, -self.attended_dim:]) 131 | return initial_state 132 | -------------------------------------------------------------------------------- /learning_rate_halver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy 3 | import logging 4 | import operator 5 | import cPickle 6 | from blocks.extensions.training import SharedVariableModifier, SimpleExtension 7 | from blocks.serialization import secure_dump, load, BRICK_DELIMITER 8 | from checkpoint import SaveLoadUtils 9 | 10 | BLEU = 'validation_bleu' 11 | COST = 'validation_cost' 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def halver(t, x): 17 | return x / 2 18 | 19 | 20 | class LearningRateHalver(SharedVariableModifier, SaveLoadUtils): 21 | 22 | def __init__(self, record_name, comparator, learning_rate, 23 | patience_default, lower_threshold=0.001): 24 | self.record_name = record_name 25 | self.comparator = comparator 26 | self.learning_rate = learning_rate 27 | self.patience_default = patience_default 28 | self.patience = patience_default 29 | self.lower_threshold = lower_threshold 30 | super(LearningRateHalver, self).__init__(self.learning_rate, halver) 31 | 32 | def do_half_nan(self): 33 | if 'perplexity' in self.main_loop.log.current_row: 34 | pepl = self.main_loop.log.current_row['perplexity'].tolist() 35 | return numpy.isnan(pepl) 36 | 37 | def do_half_patient(self): 38 | logs = sorted(self.main_loop.log.items(), 39 | key=operator.itemgetter(0), 40 | reverse=True) 41 | bleu_values = [y[self.record_name] for x, y in logs 42 | if self.record_name in y] 43 | if len(bleu_values) < 2: 44 | return False 45 | current_value = bleu_values[-1] 46 | previous_value = bleu_values[-2] 47 | if self.comparator(current_value, previous_value): 48 | self.patience -= 1 49 | if self.patience == 0: 50 | self.patience = self.patience_default 51 | return True 52 | else: 53 | self.patience = self.patience_default 54 | self.remove_old_models() 55 | return False 56 | 57 | def reload_parameters(self, path): 58 | params = self.load_parameter_values(path) 59 | self.set_model_parameters(self.main_loop.model, params) 60 | 61 | def reload_iteration_state(self, path): 62 | with open(path, 'rb') as source: 63 | self.main_loop.iteration_state = load(source) 64 | 65 | def reload_log(self, path): 66 | with open(path, 'rb') as source: 67 | self.main_loop.log = cPickle.load(source) 68 | 69 | def reload_previous_model(self, step_back): 70 | paths = sorted(self.main_loop.log.items(), 71 | key=operator.itemgetter(0), 72 | reverse=True) 73 | paths = [y['saved_to'] for x, y in paths if 'saved_to' in y] 74 | paths = [path for path in paths if 
all([os.path.exists(p) for p in path])]
75 |         if len(paths) < 1:
76 |             return
77 |         idx = min(step_back, len(paths) - 1)
78 |         path = paths[idx]
79 |         to_be_removed = paths[:idx] + paths[idx+1:]
80 |         reload_from = path[0].split('.')[-1]
81 |         logger.info('Reloading model from ' + reload_from)
82 |         self.reload_parameters(path[0])
83 |         self.reload_iteration_state(path[1])
84 |         self.reload_log(path[2])
85 |         self.main_loop.log.current_row['reload_from'] = int(reload_from)
86 |         self.remove_models(to_be_removed)
87 |
88 |     def remove_models(self, paths):
89 |         [os.remove(p) for pp in paths for p in pp if os.path.exists(p)]
90 |
91 |     def remove_old_models(self):
92 |         paths = sorted(self.main_loop.log.items(),
93 |                        key=operator.itemgetter(0),
94 |                        reverse=True)
95 |         paths = [y['saved_to'] for x, y in paths if 'saved_to' in y]
96 |         paths = [path for path in paths if all([os.path.exists(p) for p in path])]
97 |         # keep the three most recent saves (the log is sorted newest first)
98 |         to_be_removed = paths[3:]
99 |         self.remove_models(to_be_removed)
100 |
101 |     def do(self, which_callback, *args):
102 |         current_learning_rate = self.learning_rate.get_value().tolist()
103 |         self.main_loop.log.current_row['learning_rate'] = current_learning_rate
104 |         if current_learning_rate < self.lower_threshold:
105 |             self.main_loop.log.current_row['training_finish_requested'] = True
106 |         if self.record_name in self.main_loop.log.current_row:
107 |             if self.do_half_nan():
108 |                 self.reload_previous_model(1)
109 |                 super(LearningRateHalver, self).do(which_callback, *args)
110 |             if self.do_half_patient():
111 |                 self.reload_previous_model(self.patience_default)
112 |                 super(LearningRateHalver, self).do(which_callback, *args)
113 |
114 |
115 | def doubler(t, x):
116 |     # doubles the learning rate (counterpart of halver above)
117 |     return x * 2
118 |
119 |
120 | class LearningRateDoubler(SharedVariableModifier):
121 |
122 |     def __init__(self, record_name, comparator, learning_rate, patience_default):
123 |         self.record_name = record_name
124 |         self.comparator = comparator
125 |         self.learning_rate = learning_rate
126 |         self.patience_default = patience_default
127 |         self.patience = patience_default
128 |         super(LearningRateDoubler, self).__init__(self.learning_rate, doubler)
129 |
130 |     def do_double(self):
131 |         logs = sorted(self.main_loop.log.items(),
132 |                       key=operator.itemgetter(0),
133 |                       reverse=True)
134 |         bleu_values = [y[self.record_name] for x, y in logs
135 |                        if self.record_name in y]
136 |         if len(bleu_values) < 2:
137 |             return False
138 |         current_value = bleu_values[-1]
139 |         previous_value = bleu_values[-2]
140 |         if self.comparator(current_value, previous_value):
141 |             self.patience -= 1
142 |             if self.patience == 0:
143 |                 self.patience = self.patience_default
144 |                 return True
145 |         else:
146 |             self.patience = self.patience_default
147 |         return False
148 |
149 |     def do(self, which_callback, *args):
150 |         self.main_loop.log.current_row['learning_rate'] = \
151 |             self.learning_rate.get_value().tolist()
152 |         if self.record_name in self.main_loop.log.current_row:
153 |             if self.do_double():
154 |                 super(LearningRateDoubler, self).do(which_callback, *args)
155 |
156 |
157 | class OldModelRemover(SimpleExtension):
158 |
159 |     def __init__(self, saveto, **kwargs):
160 |         self.saveto = saveto
161 |         super(OldModelRemover, self).__init__(**kwargs)
162 |
163 |     def remove_old_models(self):
164 |         params_prefix = 'params.npz.'
165 |         states_prefix = 'iterations_state.pkl.'
166 |         logs_prefix = 'log.'
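        # These prefixes must match the suffixed names that
        # CheckpointNMT.enhance_path in checkpoint.py produces by appending
        # the iteration count, e.g. (hypothetical iteration 12000)
        # 'params.npz.12000', 'iterations_state.pkl.12000', 'log.12000'.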
        fnames = os.listdir(self.saveto)
        params = [f for f in fnames if f.startswith(params_prefix)]
        states = [f for f in fnames if f.startswith(states_prefix)]
        logs = [f for f in fnames if f.startswith(logs_prefix)]
        for f in params + states + logs:
            num = int(f.split('.')[-1])
            if self.main_loop.status['iterations_done'] - num > 3000:
                os.remove(os.path.join(self.saveto, f))

    def do(self, which_callback, *args):
        current_row = self.main_loop.log.current_row
        if BLEU in current_row or COST in current_row:
            self.remove_old_models()
--------------------------------------------------------------------------------
/checkpoint.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import numpy
3 | import os
4 | import time
5 |
6 | from contextlib import closing
7 | from six.moves import cPickle
8 |
9 | from blocks.extensions.saveload import SAVED_TO, LOADED_FROM
10 | from blocks.extensions import TrainingExtension, SimpleExtension
11 | from blocks.serialization import secure_dump, load, BRICK_DELIMITER
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | class SaveLoadUtils(object):
17 |
18 |     @property
19 |     def path_to_folder(self):
20 |         return self.folder
21 |
22 |     @property
23 |     def path_to_parameters(self):
24 |         return os.path.join(self.folder, 'params.npz')
25 |
26 |     @property
27 |     def path_to_iter_state(self):
28 |         return os.path.join(self.folder, 'iterations_state.pkl')
29 |
30 |     @property
31 |     def path_to_log(self):
32 |         return os.path.join(self.folder, 'log')
33 |
34 |     def load_parameter_values(self, path):
35 |         with closing(numpy.load(path)) as source:
36 |             param_values = {}
37 |             for name, value in source.items():
38 |                 if name != 'pkl':
39 |                     name_ = name.replace(BRICK_DELIMITER, '/')
40 |                     if not name_.startswith('/'):
41 |                         name_ = '/' + name_
42 |                     param_values[name_] = value
43 |         return param_values
44 |
45 |     def save_parameter_values(self, param_values, path):
46 |         param_values = {name.replace("/", "-"): param
47 |                         for name, param in param_values.items()}
48 |         with open(path, 'wb') as outfile:
49 |             numpy.savez(outfile, **param_values)
50 |
51 |     def set_model_parameters(self, model, params):
52 |         params_this = model.get_parameter_dict()
53 |         missing = set(params_this.keys()) - set(params.keys())
54 |         for pname in params_this.keys():
55 |             if pname in params:
56 |                 val = params[pname]
57 |                 if params_this[pname].get_value().shape != val.shape:
58 |                     logger.warning(
59 |                         " Dimension mismatch {}-{} for {}"
60 |                         .format(params_this[pname].get_value().shape,
61 |                                 val.shape, pname))
62 |
63 |                 params_this[pname].set_value(val)
64 |                 '''
65 |                 logger.info(" Loaded to CG {:15}: {}"
66 |                             .format(val.shape, pname))
67 |                 '''
68 |             else:
69 |                 logger.warning(
70 |                     " Parameter does not exist: {}".format(pname))
71 |         logger.info(
72 |             " Number of parameters loaded for computation graph: {}"
73 |             .format(len(params_this) - len(missing)))
74 |
75 |
76 | class CheckpointNMT(SimpleExtension, SaveLoadUtils):
77 |
78 |     def __init__(self, saveto, model_name, **kwargs):
79 |         self.folder = saveto
80 |         self.model_name = model_name
81 |         kwargs.setdefault("after_training", True)
82 |         super(CheckpointNMT, self).__init__(**kwargs)
83 |
84 |     def enhance_path(self, main_loop, path):
85 |         return path + '.'
+ str(main_loop.status['iterations_done']) 86 | 87 | def dump_parameters(self, main_loop): 88 | params_to_save = main_loop.model.get_parameter_values() 89 | self.save_parameter_values(params_to_save, 90 | self.enhance_path(main_loop, 91 | self.path_to_parameters)) 92 | 93 | def dump_iteration_state(self, main_loop): 94 | secure_dump(main_loop.iteration_state, 95 | self.enhance_path(main_loop, self.path_to_iter_state)) 96 | 97 | def dump_log(self, main_loop): 98 | secure_dump(main_loop.log, 99 | self.enhance_path(main_loop, self.path_to_log), 100 | cPickle.dump) 101 | 102 | def dump(self, main_loop): 103 | if not os.path.exists(self.path_to_folder): 104 | os.mkdir(self.path_to_folder) 105 | print("") 106 | logger.info(" Saving model: " + self.model_name) 107 | start = time.time() 108 | logger.info(" Saving parameters") 109 | self.dump_parameters(main_loop) 110 | logger.info(" Saving iteration state") 111 | self.dump_iteration_state(main_loop) 112 | logger.info(" Saving log") 113 | self.dump_log(main_loop) 114 | logger.info(" Model saved, took {} seconds.".format(time.time()-start)) 115 | 116 | def do(self, callback_name, *args): 117 | try: 118 | self.dump(self.main_loop) 119 | except Exception: 120 | raise 121 | finally: 122 | ''' 123 | already_saved_to = self.main_loop.log.current_row.get(SAVED_TO, ()) 124 | self.main_loop.log.current_row[SAVED_TO] = (already_saved_to + 125 | (self.path_to_folder + 126 | 'params.npz',)) 127 | ''' 128 | self.main_loop.log.current_row[SAVED_TO] = [ 129 | self.enhance_path(self.main_loop, self.path_to_parameters), 130 | self.enhance_path(self.main_loop, self.path_to_iter_state), 131 | self.enhance_path(self.main_loop, self.path_to_log)] 132 | 133 | 134 | class LoadNMT(TrainingExtension, SaveLoadUtils): 135 | 136 | def __init__(self, saveto, **kwargs): 137 | self.folder = saveto 138 | super(LoadNMT, self).__init__(saveto, **kwargs) 139 | 140 | def before_training(self): 141 | if not os.path.exists(self.path_to_folder): 142 | return 143 | self.load_last_model(self.main_loop) 144 | 145 | def load_parameters(self, path): 146 | return self.load_parameter_values(path) 147 | 148 | def load_parameters_default(self): 149 | return self.load_parameter_values(self.path_to_parameters) 150 | 151 | def load_iteration_state(self, path): 152 | with open(path, "rb") as source: 153 | return load(source) 154 | 155 | def load_log(self, path): 156 | with open(path, "rb") as source: 157 | return cPickle.load(source) 158 | 159 | def get_last_save(self, saves, prefix): 160 | if len(saves) == 0: 161 | return None 162 | if prefix in saves: 163 | return prefix 164 | nums = [int(s[len(prefix)+1:]) for s in saves] 165 | return prefix + '.' 
+ str(max(nums))
166 |
167 |     def load_last_model(self, main_loop):
168 |         param_prefix = 'params.npz'
169 |         state_prefix = 'iterations_state.pkl'
170 |         log_prefix = 'log'
171 |
172 |         files = os.listdir(self.folder)
173 |         params = [f for f in files if f.startswith(param_prefix)]
174 |         states = [f for f in files if f.startswith(state_prefix)]
175 |         logs = [f for f in files if f.startswith(log_prefix)]
176 |
177 |         param_name = self.get_last_save(params, param_prefix)
178 |         if param_name is not None:
179 |             logger.info(" Loading params from " + param_name)
180 |             params_all = self.load_parameters(
181 |                 os.path.join(self.path_to_folder, param_name))
182 |             self.set_model_parameters(main_loop.model, params_all)
183 |
184 |         state_name = self.get_last_save(states, state_prefix)
185 |         if state_name is not None:
186 |             logger.info(" Loading state from " + state_name)
187 |             main_loop.iteration_state = self.load_iteration_state(
188 |                 os.path.join(self.path_to_folder, state_name))
189 |
190 |         log_name = self.get_last_save(logs, log_prefix)
191 |         if log_name is not None:
192 |             logger.info(" Loading log from " + log_name)
193 |             main_loop.log = self.load_log(
194 |                 os.path.join(self.path_to_folder, log_name))
195 |
196 |         if param_name is not None and len(param_name) > len(param_prefix):
197 |             main_loop.log.current_row[LOADED_FROM] = \
198 |                 int(param_name[len(param_prefix)+1:])
199 |
--------------------------------------------------------------------------------
/SequenceGenerator.py:
--------------------------------------------------------------------------------
1 | from theano import tensor
2 |
3 | from blocks.bricks import Initializable, Random, Bias, NDimensionalSoftmax
4 | from blocks.bricks.base import application, Brick, lazy
5 | from blocks.bricks.parallel import Fork, Merge
6 | from blocks.bricks.lookup import LookupTable
7 | from blocks.bricks.recurrent import recurrent
8 | from blocks.roles import add_role, COST
9 | from blocks.utils import dict_union, dict_subset
10 | from blocks.bricks.sequence_generators import (
11 |     BaseSequenceGenerator, FakeAttentionRecurrent)
12 | from attention import AttentionRecurrent
13 |
14 |
15 | class SequenceGenerator(BaseSequenceGenerator):
16 |     r"""A more user-friendly interface for :class:`BaseSequenceGenerator`.
17 |
18 |     Parameters
19 |     ----------
20 |     readout : instance of :class:`AbstractReadout`
21 |         The readout component for the sequence generator.
22 |     transition : instance of :class:`.BaseRecurrent`
23 |         The recurrent transition to be used in the sequence generator.
24 |         Will be combined with `attention`, if that one is given.
25 |     attention : object, optional
26 |         The attention mechanism to be added to ``transition``,
27 |         an instance of
28 |         :class:`~blocks.bricks.attention.AbstractAttention`.
29 |     add_contexts : bool
30 |         If ``True``, the
31 |         :class:`.AttentionRecurrent` wrapping the
32 |         `transition` will add additional contexts for the attended and its
33 |         mask.
34 |     \*\*kwargs : dict
35 |         All keyword arguments are passed to the base class. If `fork`
36 |         keyword argument is not provided, :class:`.Fork` is created
37 |         that forks all transition sequential inputs without a "mask"
38 |         substring in them.
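
    Examples
    --------
    A minimal construction sketch (it assumes `readout`, `transition` and
    `attention` bricks configured elsewhere, e.g. as in model.py, and
    hypothetical `target`/`encoded` variables from the data stream and
    encoder)::

        generator = SequenceGenerator(
            readout=readout, transition=transition, attention=attention)
        costs = generator.cost_matrix(target, mask=target_mask,
                                      attended=encoded,
                                      attended_mask=source_mask)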
39 | 40 | """ 41 | def __init__(self, readout, transition, attention=None, 42 | use_step_decay_cost=False, 43 | use_doubly_stochastic=False, lambda_ds=0.001, 44 | use_concentration_cost=False, lambda_ct=10, 45 | use_stablilizer=False, lambda_st=50, 46 | add_contexts=True, **kwargs): 47 | self.use_doubly_stochastic = use_doubly_stochastic 48 | self.use_step_decay_cost = use_step_decay_cost 49 | self.use_concentration_cost = use_concentration_cost 50 | self.use_stablilizer = use_stablilizer 51 | self.lambda_ds = lambda_ds 52 | self.lambda_ct = lambda_ct 53 | self.lambda_st = lambda_st 54 | normal_inputs = [name for name in transition.apply.sequences 55 | if 'mask' not in name] 56 | kwargs.setdefault('fork', Fork(normal_inputs)) 57 | if attention: 58 | transition = AttentionRecurrent( 59 | transition, attention, 60 | add_contexts=add_contexts, name="att_trans") 61 | else: 62 | transition = FakeAttentionRecurrent(transition, 63 | name="with_fake_attention") 64 | super(SequenceGenerator, self).__init__( 65 | readout, transition, **kwargs) 66 | 67 | @application 68 | def cost_matrix(self, application_call, outputs, mask=None, **kwargs): 69 | """Returns generation costs for output sequences. 70 | 71 | See Also 72 | -------- 73 | :meth:`cost` : Scalar cost. 74 | 75 | """ 76 | # We assume the data has axes (time, batch, features, ...) 77 | batch_size = outputs.shape[1] 78 | 79 | # Prepare input for the iterative part 80 | states = dict_subset(kwargs, self._state_names, must_have=False) 81 | # masks in context are optional (e.g. `attended_mask`) 82 | contexts = dict_subset(kwargs, self._context_names, must_have=False) 83 | feedback = self.readout.feedback(outputs) 84 | inputs = self.fork.apply(feedback, as_dict=True) 85 | 86 | # Run the recurrent network 87 | results = self.transition.apply( 88 | mask=mask, return_initial_states=True, as_dict=True, 89 | **dict_union(inputs, states, contexts)) 90 | 91 | # Separate the deliverables. The last states are discarded: they 92 | # are not used to predict any output symbol. The initial glimpses 93 | # are discarded because they are not used for prediction. 94 | # Remember, glimpses are computed _before_ output stage, states are 95 | # computed after. 
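        # A sketch of the slicing below (hypothetical run with three
        # output steps, where s0/g0 are the initial state and glimpse):
        #   results['states']  = [s0, s1, s2, s3] -> states   ([:-1]) = [s0, s1, s2]
        #   results['weights'] = [g0, g1, g2, g3] -> glimpses ([1:])  = [g1, g2, g3]
        # so the readout predicting output t sees state s_{t-1} and glimpse g_t.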
96 | states = {name: results[name][:-1] for name in self._state_names} 97 | glimpses = {name: results[name][1:] for name in self._glimpse_names} 98 | 99 | # Compute the cost 100 | feedback = tensor.roll(feedback, 1, 0) 101 | feedback = tensor.set_subtensor( 102 | feedback[0], 103 | self.readout.feedback(self.readout.initial_outputs(batch_size))) 104 | readouts = self.readout.readout( 105 | feedback=feedback, **dict_union(states, glimpses, contexts)) 106 | costs = self.readout.cost(readouts, outputs) 107 | 108 | if self.use_doubly_stochastic: 109 | # Doubly stochastic cost 110 | # \lambda\sum_{i}(1-\sum_{t}w_{t, i})^2 111 | # the first dimensions of weights returned by transition 112 | # is batch, time 113 | weights = glimpses['weights'] 114 | weights_sum_time = tensor.sum(weights, 0) 115 | penalties = tensor.ones_like(weights_sum_time) - weights_sum_time 116 | penalties_squared = tensor.pow(penalties, 2) 117 | ds_costs = tensor.sum(penalties_squared, 1) 118 | costs += (self.lambda_ds * ds_costs)[None, :] 119 | 120 | def step_decay_cost(states): 121 | # shape is time, batch, features 122 | eta = 0.0001 123 | xi = 100 124 | states_norm = states.norm(2, axis=2) 125 | zz = tensor.zeros([1, states.shape[1]]) 126 | padded_norm = tensor.join(0, zz, states_norm)[:-1, :] 127 | diffs = states_norm - padded_norm 128 | costs = eta * (xi ** diffs) 129 | return costs 130 | 131 | if self.use_step_decay_cost: 132 | costs += step_decay_cost(states['states']) 133 | 134 | def stablilizer_cost(states): 135 | states_norm = states.norm(2, axis=2) 136 | zz = tensor.zeros([1, states.shape[1]]) 137 | padded_norm = tensor.join(0, zz, states_norm)[:-1, :] 138 | diffs = states_norm - padded_norm 139 | costs = tensor.pow(diffs, 2) 140 | return costs 141 | 142 | if self.use_stablilizer: 143 | costs += self.lambda_st * stablilizer_cost(states['states']) 144 | 145 | if self.use_concentration_cost: 146 | # weights has shape [batch, time, source sentence len] 147 | weights = glimpses['weights'] 148 | maxis = tensor.max(weights, axis=2) 149 | lacks = tensor.ones_like(maxis) - maxis 150 | costs += self.lambda_ct * lacks 151 | 152 | if mask is not None: 153 | costs *= mask 154 | 155 | for name, variable in list(glimpses.items()) + list(states.items()): 156 | application_call.add_auxiliary_variable( 157 | variable.copy(), name=name) 158 | 159 | # This variables can be used to initialize the initial states of the 160 | # next batch using the last states of the current batch. 161 | for name in self._state_names: 162 | application_call.add_auxiliary_variable( 163 | results[name][-1].copy(), name=name+"_final_value") 164 | 165 | return costs 166 | 167 | @recurrent 168 | def generate(self, outputs, **kwargs): 169 | """A sequence generation step. 170 | 171 | Parameters 172 | ---------- 173 | outputs : :class:`~tensor.TensorVariable` 174 | The outputs from the previous step. 175 | 176 | Notes 177 | ----- 178 | The contexts, previous states and glimpses are expected as keyword 179 | arguments. 180 | 181 | """ 182 | states = dict_subset(kwargs, self._state_names) 183 | # masks in context are optional (e.g. 
`attended_mask`) 184 | contexts = dict_subset(kwargs, self._context_names, must_have=False) 185 | glimpses = dict_subset(kwargs, self._glimpse_names) 186 | 187 | next_glimpses = self.transition.take_glimpses( 188 | as_dict=True, **dict_union(states, glimpses, contexts)) 189 | next_readouts = self.readout.readout( 190 | feedback=self.readout.feedback(outputs), 191 | **dict_union(states, next_glimpses, contexts)) 192 | next_outputs = self.readout.emit(next_readouts) 193 | next_costs = self.readout.cost(next_readouts, next_outputs) 194 | next_feedback = self.readout.feedback(next_outputs) 195 | next_inputs = (self.fork.apply(next_feedback, as_dict=True) 196 | if self.fork else {'feedback': next_feedback}) 197 | next_states = self.transition.compute_states( 198 | as_list=True, 199 | **dict_union(next_inputs, states, next_glimpses, contexts)) 200 | return (next_states + [next_outputs] + 201 | list(next_glimpses.values()) + [next_costs]) 202 | 203 | @generate.delegate 204 | def generate_delegate(self): 205 | return self.transition.apply 206 | 207 | @generate.property('states') 208 | def generate_states(self): 209 | return self._state_names + ['outputs'] + self._glimpse_names 210 | 211 | @generate.property('outputs') 212 | def generate_outputs(self): 213 | return (self._state_names + ['outputs'] + 214 | self._glimpse_names + ['costs']) 215 | 216 | def get_dim(self, name): 217 | if name in (self._state_names + self._context_names + 218 | self._glimpse_names): 219 | return self.transition.get_dim(name) 220 | elif name == 'outputs': 221 | return self.readout.get_dim(name) 222 | return super(BaseSequenceGenerator, self).get_dim(name) 223 | 224 | @application 225 | def initial_states(self, batch_size, *args, **kwargs): 226 | # TODO: support dict of outputs for application methods 227 | # to simplify this code. 
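        # What the lines below do: fetch the transition's initial states as a
        # dict, add the readout's initial outputs under the 'outputs' key, and
        # order the resulting list to match self.generate.states.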
228 | state_dict = dict( 229 | self.transition.initial_states( 230 | batch_size, as_dict=True, *args, **kwargs), 231 | outputs=self.readout.initial_outputs(batch_size)) 232 | return [state_dict[state_name] 233 | for state_name in self.generate.states] 234 | 235 | @initial_states.property('outputs') 236 | def initial_states_outputs(self): 237 | return self.generate.states 238 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from theano import tensor 2 | from toolz import merge 3 | 4 | from blocks.bricks import (Tanh, Maxout, Linear, FeedforwardSequence, 5 | Bias, Initializable, MLP) 6 | # from blocks.bricks.attention import SequenceContentAttention 7 | from blocks.bricks.base import application 8 | from blocks.bricks.lookup import LookupTable 9 | from blocks.bricks.parallel import Fork 10 | from blocks.bricks.recurrent import GatedRecurrent, Bidirectional 11 | from blocks.bricks.sequence_generators import ( 12 | LookupFeedback, Readout, SoftmaxEmitter) 13 | from blocks.roles import add_role, WEIGHT 14 | from blocks.utils import shared_floatx_nans 15 | 16 | from picklable_itertools.extras import equizip 17 | from attention import SequenceContentAttention 18 | #from SequenceGenerator import SequenceGenerator 19 | from SequenceGenerator_forPickTopicWord import SequenceGenerator 20 | from GRU import GRU 21 | from match_functions import ( 22 | SumMatchFunction, CatMatchFunction, 23 | DotMatchFunction, GeneralMatchFunction) 24 | 25 | 26 | # Helper class 27 | class InitializableFeedforwardSequence(FeedforwardSequence, Initializable): 28 | pass 29 | 30 | 31 | class LookupFeedbackWMT15(LookupFeedback): 32 | """Zero-out initial readout feedback by checking its value.""" 33 | 34 | @application 35 | def feedback(self, outputs): 36 | assert self.output_dim == 0 37 | 38 | shp = [outputs.shape[i] for i in range(outputs.ndim)] 39 | outputs_flat = outputs.flatten() 40 | outputs_flat_zeros = tensor.switch(outputs_flat < 0, 0, 41 | outputs_flat) 42 | 43 | lookup_flat = tensor.switch( 44 | outputs_flat[:, None] < 0, 45 | tensor.alloc(0., outputs_flat.shape[0], self.feedback_dim), 46 | self.lookup.apply(outputs_flat_zeros)) 47 | lookup = lookup_flat.reshape(shp+[self.feedback_dim]) 48 | return lookup 49 | 50 | 51 | class BidirectionalWMT15(Bidirectional): 52 | """Wrap two Gated Recurrents each having separate parameters.""" 53 | 54 | @application 55 | def apply(self, forward_dict, backward_dict): 56 | """Applies forward and backward networks and concatenates outputs.""" 57 | forward = self.children[0].apply(as_list=True, **forward_dict) 58 | backward = [x[::-1] for x in 59 | self.children[1].apply(reverse=True, as_list=True, 60 | **backward_dict)] 61 | return [tensor.concatenate([f, b], axis=2) 62 | for f, b in equizip(forward, backward)] 63 | 64 | 65 | class BidirectionalEncoder(Initializable): 66 | """Encoder of RNNsearch model.""" 67 | 68 | def __init__(self, vocab_size, embedding_dim, state_dim, **kwargs): 69 | super(BidirectionalEncoder, self).__init__(**kwargs) 70 | self.vocab_size = vocab_size 71 | self.embedding_dim = embedding_dim 72 | self.state_dim = state_dim 73 | 74 | self.lookup = LookupTable(name='embeddings') 75 | self.bidir = BidirectionalWMT15( 76 | GatedRecurrent(activation=Tanh(), dim=state_dim)) 77 | self.fwd_fork = Fork( 78 | [name for name in self.bidir.prototype.apply.sequences 79 | if name != 'mask'], prototype=Linear(), name='fwd_fork') 80 | self.back_fork = 
Fork( 81 | [name for name in self.bidir.prototype.apply.sequences 82 | if name != 'mask'], prototype=Linear(), name='back_fork') 83 | 84 | self.children = [self.lookup, self.bidir, 85 | self.fwd_fork, self.back_fork] 86 | 87 | def _push_allocation_config(self): 88 | self.lookup.length = self.vocab_size 89 | self.lookup.dim = self.embedding_dim 90 | 91 | self.fwd_fork.input_dim = self.embedding_dim 92 | self.fwd_fork.output_dims = [self.bidir.children[0].get_dim(name) 93 | for name in self.fwd_fork.output_names] 94 | self.back_fork.input_dim = self.embedding_dim 95 | self.back_fork.output_dims = [self.bidir.children[1].get_dim(name) 96 | for name in self.back_fork.output_names] 97 | 98 | @application(inputs=['source_sentence', 'source_sentence_mask'], 99 | outputs=['representation']) 100 | def apply(self, source_sentence, source_sentence_mask): 101 | # Time as first dimension 102 | source_sentence = source_sentence.T 103 | source_sentence_mask = source_sentence_mask.T 104 | 105 | embeddings = self.lookup.apply(source_sentence) 106 | 107 | representation = self.bidir.apply( 108 | merge(self.fwd_fork.apply(embeddings, as_dict=True), 109 | {'mask': source_sentence_mask}), 110 | merge(self.back_fork.apply(embeddings, as_dict=True), 111 | {'mask': source_sentence_mask}) 112 | ) 113 | return representation 114 | 115 | class topicalq_transformer(Initializable): 116 | 117 | def __init__(self, vocab_size, topical_embedding_dim, state_dim,word_num,batch_size, 118 | **kwargs): 119 | super(topicalq_transformer, self).__init__(**kwargs) 120 | self.vocab_size = vocab_size; 121 | self.word_embedding_dim = topical_embedding_dim; 122 | self.state_dim = state_dim; 123 | self.word_num=word_num; 124 | self.batch_size=batch_size; 125 | self.look_up=LookupTable(name='topical_embeddings'); 126 | self.transformer=MLP(activations=[Tanh()], 127 | dims=[self.word_embedding_dim*self.word_num, self.state_dim], 128 | name='topical_transformer'); 129 | self.children = [self.look_up,self.transformer]; 130 | 131 | def _push_allocation_config(self): 132 | self.look_up.length = self.vocab_size 133 | self.look_up.dim = self.word_embedding_dim 134 | 135 | 136 | # do we have to push_config? 
remain unsure
137 |     @application(inputs=['source_topical_word_sequence'],
138 |                  outputs=['topical_embedding'])
139 |     def apply(self, source_topical_word_sequence):
140 |         # Time as first dimension
141 |         source_topical_word_sequence = source_topical_word_sequence.T;
142 |         word_topical_embeddings = self.look_up.apply(source_topical_word_sequence);
143 |         word_topical_embeddings = word_topical_embeddings.swapaxes(0, 1);
144 |         # requires testing
145 |         concatenated_topical_embeddings = tensor.reshape(word_topical_embeddings, [word_topical_embeddings.shape[0], word_topical_embeddings.shape[1]*word_topical_embeddings.shape[2]]);
146 |         topical_embedding = self.transformer.apply(concatenated_topical_embeddings);
147 |         return topical_embedding
148 | 
149 | class Decoder(Initializable):
150 |     """Decoder of RNNsearch model."""
151 | 
152 |     def __init__(self, vocab_size, topicWord_size, embedding_dim, state_dim, topical_dim,
153 |                  representation_dim, match_function='SumMatchFunction',
154 |                  use_doubly_stochastic=False, lambda_ds=0.001,
155 |                  use_local_attention=False, window_size=10,
156 |                  use_step_decay_cost=False,
157 |                  use_concentration_cost=False, lambda_ct=10,
158 |                  use_stablilizer=False, lambda_st=50,
159 |                  theano_seed=None, **kwargs):
160 |         super(Decoder, self).__init__(**kwargs)
161 |         self.vocab_size = vocab_size
162 |         self.topicWord_size = topicWord_size
163 |         self.embedding_dim = embedding_dim
164 |         self.state_dim = state_dim
165 |         self.representation_dim = representation_dim
166 |         self.theano_seed = theano_seed
167 | 
168 |         # Initialize GRU with special initial state
169 |         self.transition = GRU(
170 |             attended_dim=state_dim, dim=state_dim,
171 |             activation=Tanh(), name='decoder')
172 | 
173 |         self.energy_computer = globals()[match_function](name='energy_comp')
174 | 
175 |         # Initialize the attention mechanism
176 |         self.attention = SequenceContentAttention(
177 |             state_names=self.transition.apply.states,
178 |             attended_dim=representation_dim,
179 |             match_dim=state_dim,
180 |             energy_computer=self.energy_computer,
181 |             use_local_attention=use_local_attention,
182 |             window_size=window_size,
183 |             name="attention")
184 | 
185 |         self.topical_attention = SequenceContentAttention(
186 |             state_names=self.transition.apply.states,
187 |             attended_dim=topical_dim,
188 |             match_dim=state_dim,
189 |             energy_computer=self.energy_computer,
190 |             use_local_attention=use_local_attention,
191 |             window_size=window_size,
192 |             name="topical_attention")  # not sure whether the match dim is correct.
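        # The two attentions above are handed the same energy_computer brick
        # instance, so the match function's parameters are shared between
        # them. Assuming the usual scoring variants behind the imported
        # names (an assumption from the names; see match_functions.py for
        # the actual definitions): SumMatchFunction scores a decoder state
        # s_t against an annotation h_i additively,
        #     e_{t,i} = v^T tanh(W s_t + U h_i),
        # DotMatchFunction uses e_{t,i} = s_t^T h_i, GeneralMatchFunction
        # uses e_{t,i} = s_t^T W h_i, and CatMatchFunction feeds the
        # concatenation [s_t; h_i] through an MLP.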
193 | 
194 | 
195 |         # Initialize the readout, note that SoftmaxEmitter emits -1 for
196 |         # initial outputs which is used by LookupFeedbackWMT15
197 |         readout = Readout(
198 |             source_names=['states', 'feedback',
199 |                           self.attention.take_glimpses.outputs[0]],
200 |             readout_dim=self.vocab_size,
201 |             emitter=SoftmaxEmitter(initial_output=-1, theano_seed=theano_seed),
202 |             feedback_brick=LookupFeedbackWMT15(vocab_size, embedding_dim),
203 |             post_merge=InitializableFeedforwardSequence(
204 |                 [Bias(dim=state_dim, name='maxout_bias').apply,
205 |                  Maxout(num_pieces=2, name='maxout').apply,
206 |                  Linear(input_dim=state_dim / 2, output_dim=embedding_dim,
207 |                         use_bias=False, name='softmax0').apply,
208 |                  Linear(input_dim=embedding_dim, name='softmax1').apply]),
209 |             merged_dim=state_dim,
210 |             name='readout')
211 | 
212 |         # calculate the readout of the topic words:
213 |         # no specific feedback brick, use the trivial feedback brick;
214 |         # no post_merge and merge, use Bias and Linear
215 |         topicWordReadout = Readout(
216 |             source_names=['states', 'feedback',
217 |                           self.attention.take_glimpses.outputs[0]],
218 |             readout_dim=self.topicWord_size,
219 |             emitter=SoftmaxEmitter(initial_output=-1, theano_seed=theano_seed),
220 |             name='twReadout')
221 | 
222 | 
223 |         # Build sequence generator accordingly
224 |         self.sequence_generator = SequenceGenerator(
225 |             readout=readout,
226 |             topicWordReadout=topicWordReadout,
227 |             topic_vector_names=['topicSumVector'],
228 |             transition=self.transition,
229 |             attention=self.attention,
230 |             topical_attention=self.topical_attention,
231 |             q_dim=self.state_dim,
232 |             #q_name='topic_embedding',
233 |             topical_name='topic_embedding',
234 |             content_name='content_embedding',
235 |             use_step_decay_cost=use_step_decay_cost,
236 |             use_doubly_stochastic=use_doubly_stochastic, lambda_ds=lambda_ds,
237 |             use_concentration_cost=use_concentration_cost, lambda_ct=lambda_ct,
238 |             use_stablilizer=use_stablilizer, lambda_st=lambda_st,
239 |             fork=Fork([name for name in self.transition.apply.sequences
240 |                        if name != 'mask'], prototype=Linear())
241 |         )
242 | 
243 |         self.children = [self.sequence_generator]
244 | 
245 |     @application(inputs=['representation', 'source_sentence_mask',
246 |                          'tw_representation', 'tw_mask',
247 |                          'target_sentence', 'target_sentence_mask',
248 |                          'target_topic_sentence', 'target_topic_binary_sentence', 'topic_embedding', 'content_embedding'],
                 outputs=['cost'])
249 |     def cost(self, representation, source_sentence_mask, tw_representation, tw_mask,
250 |              target_sentence, target_sentence_mask, target_topic_sentence, target_topic_binary_sentence, topic_embedding, content_embedding):
251 | 
252 |         source_sentence_mask = source_sentence_mask.T
253 |         target_sentence = target_sentence.T
254 |         target_sentence_mask = target_sentence_mask.T
255 |         target_topic_sentence = target_topic_sentence.T
256 |         target_topic_binary_sentence = target_topic_binary_sentence.T
257 |         # Get the cost matrix
258 |         cost = self.sequence_generator.cost_matrix(**{
259 |             'mask': target_sentence_mask,
260 |             'outputs': target_sentence,
261 |             'attended': representation,
262 |             'attended_mask': source_sentence_mask,
263 |             'topical_attended': tw_representation,
264 |             'topical_attended_mask': tw_mask,
265 |             'topic_embedding': topic_embedding,
266 |             'content_embedding': content_embedding,
267 |             'tw_outputs': target_topic_sentence,
268 |             'tw_binary': target_topic_binary_sentence}
269 |         )
270 | 
271 |         '''
272 |         return (cost * target_sentence_mask).sum() / \
273 |             target_sentence_mask.shape[1]
274 |         '''
275 | 
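        # Normalization note: the commented-out variant above averages the
        # masked cost per sentence (dividing by the batch dimension), while
        # the return below divides the summed negative log-likelihood by the
        # total number of unmasked target tokens, i.e. a per-token cross
        # entropy; exponentiating that value (the commented-out line below)
        # would give the per-token perplexity.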
#return tensor.exp((cost * target_sentence_mask).sum()/tensor.sum(target_sentence_mask)) 276 | return (cost * target_sentence_mask).sum()/tensor.sum(target_sentence_mask) 277 | 278 | @application 279 | def generate(self, source_sentence, representation, tw_vocab_overlap,topic_embedding,**kwargs): 280 | return self.sequence_generator.generate( 281 | tw_vocab_overlap=tw_vocab_overlap, 282 | n_steps=2 * source_sentence.shape[1], 283 | batch_size=source_sentence.shape[0], 284 | attended=representation, 285 | attended_mask=tensor.ones(source_sentence.shape).T, 286 | topic_embedding=topic_embedding, 287 | **kwargs) 288 | -------------------------------------------------------------------------------- /SequenceGenerator_forPickTopicWord.py: -------------------------------------------------------------------------------- 1 | from theano import tensor 2 | 3 | from blocks.bricks import Initializable, Random, Bias, NDimensionalSoftmax 4 | from blocks.bricks.base import application, Brick, lazy 5 | from blocks.bricks.parallel import Fork, Merge 6 | from blocks.bricks.lookup import LookupTable 7 | from blocks.bricks.recurrent import recurrent 8 | from blocks.roles import add_role, COST 9 | from blocks.utils import dict_union, dict_subset 10 | from blocks.bricks.sequence_generators import ( 11 | BaseSequenceGenerator, FakeAttentionRecurrent) 12 | from attention_with_topicalq import AttentionRecurrent 13 | 14 | from blocks.bricks.wrappers import WithExtraDims 15 | 16 | class PickTargetProb(Brick): 17 | """A softmax brick. 18 | 19 | Works with 2-dimensional inputs only. If you need more, 20 | see :class:`NDimensionalSoftmax`. 21 | """ 22 | 23 | @application(inputs=['y', 'x'], outputs=['cost']) 24 | def apply(self, application_call, y, x): 25 | if y.ndim == x.ndim - 1: 26 | print("y.ndim == x.ndim - 1"); 27 | indices = tensor.arange(y.shape[0]) * x.shape[1] + y 28 | cost = x.flatten()[indices] 29 | else: 30 | raise TypeError('rank mismatch between x and y') 31 | return cost 32 | 33 | 34 | class NDPickTargetProb(PickTargetProb): 35 | decorators = [WithExtraDims()] 36 | 37 | class SelectTarget(Random): 38 | 39 | @application 40 | def emit(self, probs): 41 | batch_size = probs.shape[0] 42 | pvals_flat = probs.reshape((batch_size, -1)) 43 | generated = self.theano_rng.multinomial(pvals=pvals_flat) 44 | return generated.reshape(probs.shape).argmax(axis=-1) 45 | 46 | @application 47 | def cost(self, y,x): 48 | indices = tensor.arange(y.shape[0]) * x.shape[1] + y 49 | cost = x.flatten()[indices] 50 | 51 | return cost 52 | 53 | 54 | class SequenceGenerator(BaseSequenceGenerator): 55 | r"""A more user-friendly interface for :class:`BaseSequenceGenerator`. 56 | 57 | Parameters 58 | ---------- 59 | readout : instance of :class:`AbstractReadout` 60 | The readout component for the sequence generator. 61 | transition : instance of :class:`.BaseRecurrent` 62 | The recurrent transition to be used in the sequence generator. 63 | Will be combined with `attention`, if that one is given. 64 | attention : object, optional 65 | The attention mechanism to be added to ``transition``, 66 | an instance of 67 | :class:`~blocks.bricks.attention.AbstractAttention`. 68 | add_contexts : bool 69 | If ``True``, the 70 | :class:`.AttentionRecurrent` wrapping the 71 | `transition` will add additional contexts for the attended and its 72 | mask. 73 | \*\*kwargs : dict 74 | All keywords arguments are passed to the base class. 
If `fork` 75 | keyword argument is not provided, :class:`.Fork` is created 76 | that forks all transition sequential inputs without a "mask" 77 | substring in them. 78 | 79 | """ 80 | def __init__(self, readout,topicWordReadout,topic_vector_names, transition,topical_name,content_name,q_dim,q_name, attention=None,topical_attention=None, 81 | use_step_decay_cost=False, 82 | use_doubly_stochastic=False, lambda_ds=0.001, 83 | use_concentration_cost=False, lambda_ct=10, 84 | use_stablilizer=False, lambda_st=50, 85 | add_contexts=True, **kwargs): 86 | self.use_doubly_stochastic = use_doubly_stochastic 87 | self.use_step_decay_cost = use_step_decay_cost 88 | self.use_concentration_cost = use_concentration_cost 89 | self.use_stablilizer = use_stablilizer 90 | self.lambda_ds = lambda_ds 91 | self.lambda_ct = lambda_ct 92 | self.lambda_st = lambda_st 93 | normal_inputs = [name for name in transition.apply.sequences 94 | if 'mask' not in name] 95 | kwargs.setdefault('fork', Fork(normal_inputs)) 96 | if attention: 97 | transition = AttentionRecurrent( 98 | transition, attention,topical_attention,topical_attended_name='topical_attended',topical_attended_mask_name='topical_attended_mask',content_name=content_name,topical_name=topical_name, 99 | add_contexts=add_contexts, name="att_trans") 100 | else: 101 | transition = FakeAttentionRecurrent(transition, 102 | name="with_fake_attention") 103 | 104 | self.topicWordReadout=topicWordReadout; 105 | self._topic_vector_names=topic_vector_names; 106 | self.probPick=NDPickTargetProb(); 107 | self.sampleTarget=SelectTarget(); 108 | #self._q_names=[q_name]; 109 | self.topical_name=topical_name; 110 | self.content_name=content_name; 111 | self._topical_context_names=['topical_attended','topical_attended_mask']; 112 | super(SequenceGenerator, self).__init__( 113 | readout, transition, **kwargs) 114 | self.children+=[self.topicWordReadout,self.probPick,self.sampleTarget]; 115 | 116 | def _push_allocation_config(self): 117 | 118 | super(SequenceGenerator, self)._push_allocation_config(); 119 | transition_sources = (self._state_names + self._context_names + 120 | self._glimpse_names) 121 | self.topicWordReadout.source_dims = [self.transition.get_dim(name) 122 | if name in transition_sources 123 | else self.readout.get_dim(name) 124 | for name in self.readout.source_names] 125 | self.topicWordReadout.push_allocation_config() 126 | 127 | @application 128 | def cost_matrix(self, application_call, outputs,tw_outputs, tw_binary,mask=None, **kwargs): 129 | """Returns generation costs for output sequences. 130 | 131 | See Also 132 | -------- 133 | :meth:`cost` : Scalar cost. 134 | 135 | """ 136 | # We assume the data has axes (time, batch, features, ...) 137 | batch_size = outputs.shape[1] 138 | 139 | # Prepare input for the iterative part 140 | states = dict_subset(kwargs, self._state_names, must_have=False) 141 | # masks in context are optional (e.g. 
`attended_mask`)
142 |         contexts = dict_subset(kwargs, self._context_names, must_have=False)
143 |         topical_word_contexts = dict_subset(kwargs, self._topical_context_names)
144 |         topical_embeddings = dict_subset(kwargs, [self.topical_name]);
145 |         content_embeddings = dict_subset(kwargs, [self.content_name]);
146 |         #q=dict_subset(kwargs, self._q_names, must_have=True,pop=True);
147 |         feedback = self.readout.feedback(outputs)
148 |         inputs = self.fork.apply(feedback, as_dict=True)
149 | 
150 |         # Run the recurrent network
151 |         results = self.transition.apply(
152 |             mask=mask, return_initial_states=True, as_dict=True,
153 |             **dict_union(inputs, states, contexts, topical_word_contexts, topical_embeddings, content_embeddings))
154 | 
155 |         # Separate the deliverables. The last states are discarded: they
156 |         # are not used to predict any output symbol. The initial glimpses
157 |         # are discarded because they are not used for prediction.
158 |         # Remember, glimpses are computed _before_ the output stage, states are
159 |         # computed after.
160 |         states = {name: results[name][:-1] for name in self._state_names}
161 |         glimpses = {name: results[name][1:] for name in self._glimpse_names}
162 |         glimpses_modified = {'weighted_averages': glimpses['weighted_averages'], 'weights': glimpses['weights']}
163 | 
164 |         # Compute the cost
165 |         feedback = tensor.roll(feedback, 1, 0)
166 |         feedback = tensor.set_subtensor(
167 |             feedback[0],
168 |             self.readout.feedback(self.readout.initial_outputs(batch_size)))
169 |         readouts = self.readout.readout(
170 |             feedback=feedback, **dict_union(states, glimpses_modified, contexts))
171 |         #costs = self.readout.cost(readouts, outputs)
172 | 
173 |         #topicSumVec = dict_subset(kwargs, self._topic_vector_names, must_have=True);
174 |         twReadouts = self.topicWordReadout.readout(feedback=feedback, **dict_union(states, glimpses_modified, contexts));
175 |         twExp = tensor.exp(twReadouts);
176 |         rwExp = tensor.exp(readouts);
177 |         Z = twExp.sum(keepdims=True, axis=2) + rwExp.sum(keepdims=True, axis=2);  # one shared normalizer, i.e. a joint softmax over topic words and the regular vocabulary (keepdims and the axis choice remain uncertain)
178 |         twExp /= Z;
179 |         rwExp /= Z;
180 |         twCost = self.probPick.apply(tw_outputs, twExp, extra_ndim=twExp.ndim - 2);
181 |         rwCost = self.probPick.apply(outputs, rwExp, extra_ndim=rwExp.ndim - 2);
182 |         totalCost = twCost*tw_binary + rwCost;
183 |         costs = -tensor.log(totalCost);
184 | 
185 |         if self.use_doubly_stochastic:
186 |             # Doubly stochastic cost
187 |             # \lambda\sum_{i}(1-\sum_{t}w_{t, i})^2
188 |             # the first dimensions of the weights returned by the transition
189 |             # are time, batch
190 |             weights = glimpses['weights']
191 |             weights_sum_time = tensor.sum(weights, 0)
192 |             penalties = tensor.ones_like(weights_sum_time) - weights_sum_time
193 |             penalties_squared = tensor.pow(penalties, 2)
194 |             ds_costs = tensor.sum(penalties_squared, 1)
195 |             costs += (self.lambda_ds * ds_costs)[None, :]
196 | 
197 |         def step_decay_cost(states):
198 |             # shape is time, batch, features
199 |             eta = 0.0001
200 |             xi = 100
201 |             states_norm = states.norm(2, axis=2)
202 |             zz = tensor.zeros([1, states.shape[1]])
203 |             padded_norm = tensor.join(0, zz, states_norm)[:-1, :]
204 |             diffs = states_norm - padded_norm
205 |             costs = eta * (xi ** diffs)
206 |             return costs
207 | 
208 |         if self.use_step_decay_cost:
209 |             costs += step_decay_cost(states['states'])
210 | 
211 |         def stablilizer_cost(states):
212 |             states_norm = states.norm(2, axis=2)
213 |             zz = tensor.zeros([1, states.shape[1]])
214 |             padded_norm = tensor.join(0, zz, states_norm)[:-1, :]
215 |             diffs = states_norm - padded_norm
216 |             costs = tensor.pow(diffs, 2)
217 |             return costs
218 | 
219 |         if self.use_stablilizer:
220 |             costs += self.lambda_st * stablilizer_cost(states['states'])
221 | 
222 |         if self.use_concentration_cost:
223 |             # weights has shape [time, batch, source sentence len]
224 |             weights = glimpses['weights']
225 |             maxis = tensor.max(weights, axis=2)
226 |             lacks = tensor.ones_like(maxis) - maxis
227 |             costs += self.lambda_ct * lacks
228 | 
229 |         if mask is not None:
230 |             costs *= mask
231 | 
232 |         for name, variable in list(glimpses.items()) + list(states.items()):
233 |             application_call.add_auxiliary_variable(
234 |                 variable.copy(), name=name)
235 | 
236 |         # These variables can be used to initialize the initial states of the
237 |         # next batch using the last states of the current batch.
238 |         for name in self._state_names:
239 |             application_call.add_auxiliary_variable(
240 |                 results[name][-1].copy(), name=name+"_final_value")
241 | 
242 |         return costs
243 | 
244 | 
245 |     @recurrent
246 |     def generate(self, outputs, tw_vocab_overlap, **kwargs):
247 |         """A sequence generation step.
248 | 
249 |         Parameters
250 |         ----------
251 |         outputs : :class:`~tensor.TensorVariable`
252 |             The outputs from the previous step.
253 | 
254 |         Notes
255 |         -----
256 |         The contexts, previous states and glimpses are expected as keyword
257 |         arguments.
258 | 
259 |         """
260 |         states = dict_subset(kwargs, self._state_names)
261 |         # masks in context are optional (e.g. `attended_mask`)
262 |         contexts = dict_subset(kwargs, self._context_names, must_have=False)
263 |         topical_word_contexts = dict_subset(kwargs, self._topical_context_names)
264 |         topical_embeddings = dict_subset(kwargs, [self.topical_name]);
265 |         content_embeddings = dict_subset(kwargs, [self.content_name]);
266 |         glimpses = dict_subset(kwargs, self._glimpse_names)
267 |         next_glimpses = self.transition.take_glimpses(
268 |             as_dict=True,
269 |             **dict_union(
270 |                 states, glimpses, topical_embeddings, content_embeddings, contexts, topical_word_contexts));
271 |         glimpses_modified = {'weighted_averages': next_glimpses['weighted_averages'], 'weights': next_glimpses['weights']}
272 |         next_readouts = self.readout.readout(
273 |             feedback=self.readout.feedback(outputs),
274 |             **dict_union(states, glimpses_modified, contexts))
275 |         next_tw_readouts = self.topicWordReadout.readout(
276 |             feedback=self.readout.feedback(outputs),
277 |             **dict_union(states, glimpses_modified, contexts))
278 |         twExp = tensor.exp(next_tw_readouts);
279 |         rwExp = tensor.exp(next_readouts);
280 |         Z = twExp.sum(keepdims=True, axis=1) + rwExp.sum(keepdims=True, axis=1);  # same joint normalization as in cost_matrix, here over a 2D (batch, readout) array (keepdims and the axis choice remain uncertain)
281 |         twExp /= Z;
282 |         rwExp /= Z;
283 |         probs = tensor.dot(twExp, tw_vocab_overlap) + rwExp;
284 |         next_outputs = self.sampleTarget.emit(probs);
285 |         next_costs = self.sampleTarget.cost(next_outputs, probs)
286 |         next_costs = -tensor.log(next_costs);
287 |         next_feedback = self.readout.feedback(next_outputs)
288 |         next_inputs = (self.fork.apply(next_feedback, as_dict=True)
289 |                        if self.fork else {'feedback': next_feedback})
290 |         next_states = self.transition.compute_states(
291 |             as_list=True,
292 |             **dict_union(next_inputs, states, next_glimpses, contexts, topical_word_contexts))
293 |         return (next_states + [next_outputs] +
294 |                 list(next_glimpses.values()) + [next_costs])
295 | 
296 |     @generate.delegate
297 |     def generate_delegate(self):
298 |         return self.transition.apply
299 | 
300 |     @generate.property('states')
301 |     def generate_states(self):
302 |         return self._state_names + ['outputs'] + self._glimpse_names
303 | 
304 |     @generate.property('outputs')
305 |     def generate_outputs(self):
306 |         return (self._state_names + ['outputs'] +
307 |                 self._glimpse_names + ['costs'])
308 | 
309 |     @generate.property('contexts')
310 |     def generate_contexts(self):
311 |         return (self.transition.apply.contexts + self._topical_context_names + [self.content_name] + [self.topical_name]
312 |                 + ['tw_vocab_overlap'])
313 | 
314 |     def get_dim(self, name):
315 |         if name in (self._state_names + self._context_names +
316 |                     self._glimpse_names):
317 |             return self.transition.get_dim(name)
318 |         elif name == 'outputs':
319 |             return self.readout.get_dim(name)
320 |         return super(BaseSequenceGenerator, self).get_dim(name)
321 | 
322 |     @application
323 |     def initial_states(self, batch_size, *args, **kwargs):
324 |         # TODO: support dict of outputs for application methods
325 |         # to simplify this code.
326 |         state_dict = dict(
327 |             self.transition.initial_states(
328 |                 batch_size, as_dict=True, *args, **kwargs),
329 |             outputs=self.readout.initial_outputs(batch_size))
330 |         return [state_dict[state_name]
331 |                 for state_name in self.generate.states]
332 | 
333 |     @initial_states.property('outputs')
334 |     def initial_states_outputs(self):
335 |         return self.generate.states
336 | 
-------------------------------------------------------------------------------- /search.py: --------------------------------------------------------------------------------
1 | """The beam search module."""
2 | from collections import OrderedDict
3 | from six.moves import range
4 | import numpy
5 | from picklable_itertools.extras import equizip
6 | from theano import config, function, tensor
7 | from blocks.bricks.sequence_generators import BaseSequenceGenerator
8 | from blocks.filter import VariableFilter, get_application_call, get_brick
9 | from blocks.graph import ComputationGraph
10 | from blocks.roles import INPUT, OUTPUT
11 | from blocks.utils import unpack
12 | 
13 | 
14 | class BeamSearch(object):
15 |     """Approximate search for the most likely sequence.
16 | 
17 |     Beam search is an approximate algorithm for finding :math:`y^* =
18 |     argmax_y P(y|c)`, where :math:`y` is an output sequence, :math:`c` are
19 |     the contexts, :math:`P` is the output distribution of a
20 |     :class:`.SequenceGenerator`. At each step it considers :math:`k`
21 |     candidate sequence prefixes. :math:`k` is called the beam size, and the
22 |     sequences are called the beam. The sequences are replaced with their
23 |     :math:`k` most probable continuations, and this is repeated until
24 |     the end-of-line symbol is met.
25 | 
26 |     The beam search compiles quite a few Theano functions under the hood.
27 |     Normally those are compiled at the first :meth:`search` call, but
28 |     you can also explicitly call :meth:`compile`.
29 | 
30 |     Parameters
31 |     ----------
32 |     samples : :class:`~theano.Variable`
33 |         An output of a sampling computation graph built by
34 |         :meth:`~blocks.brick.SequenceGenerator.generate`, the one
35 |         corresponding to sampled sequences.
36 | 
37 |     See Also
38 |     --------
39 |     :class:`.SequenceGenerator`
40 | 
41 |     Notes
42 |     -----
43 |     The sequence generator should use an emitter which has a `probs`
44 |     method, e.g. :class:`SoftmaxEmitter`.
45 | 
46 |     Does not support dummy contexts so far (all the contexts must be used
47 |     in the `generate` method of the sequence generator for the current code
48 |     to work).
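    A minimal usage sketch (illustrative only: every name below is a
    placeholder, and the position of the sampled-outputs variable in the
    tuple returned by ``generate`` depends on the generator's
    configuration)::

        generated = generator.generate(n_steps=n_steps,
                                       batch_size=beam_size,
                                       attended=representation,
                                       attended_mask=attended_mask)
        beam_search = BeamSearch(generated[1])  # the sampled outputs
        outputs, costs, attendeds, weights = beam_search.search(
            {source: numpy.tile(input_ids, (beam_size, 1))},
            eol_symbol=eos_index, max_length=100)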
49 | 50 | """ 51 | 52 | def __init__(self, samples): 53 | # Extracting information from the sampling computation graph 54 | self.cg = ComputationGraph(samples) 55 | self.inputs = self.cg.inputs 56 | self.generator = get_brick(samples) 57 | if not isinstance(self.generator, BaseSequenceGenerator): 58 | raise ValueError 59 | self.generate_call = get_application_call(samples) 60 | if (not self.generate_call.application == 61 | self.generator.generate): 62 | raise ValueError 63 | self.inner_cg = ComputationGraph(self.generate_call.inner_outputs) 64 | 65 | # Fetching names from the sequence generator 66 | self.context_names = self.generator.generate.contexts 67 | self.state_names = self.generator.generate.states 68 | 69 | # Parsing the inner computation graph of sampling scan 70 | self.contexts = [ 71 | VariableFilter(bricks=[self.generator], 72 | name=name, 73 | roles=[INPUT])(self.inner_cg)[0] 74 | for name in self.context_names] 75 | self.input_states = [] 76 | # Includes only those state names that were actually used 77 | # in 'generate' 78 | self.input_state_names = [] 79 | for name in self.generator.generate.states: 80 | var = VariableFilter( 81 | bricks=[self.generator], name=name, 82 | roles=[INPUT])(self.inner_cg) 83 | if var: 84 | self.input_state_names.append(name) 85 | self.input_states.append(var[0]) 86 | 87 | self.compiled = False 88 | 89 | def _compile_initial_state_and_context_computer(self): 90 | initial_states = VariableFilter( 91 | applications=[self.generator.initial_states], 92 | roles=[OUTPUT])(self.cg) 93 | outputs = OrderedDict([(v.tag.name, v) for v in initial_states]) 94 | beam_size = unpack( 95 | VariableFilter(applications=[self.generator.initial_states], 96 | name='batch_size')(self.cg)) 97 | for name, context in equizip(self.context_names, self.contexts): 98 | outputs[name] = context 99 | outputs['beam_size'] = beam_size 100 | self.initial_state_and_context_computer = function( 101 | self.inputs, outputs, on_unused_input='ignore') 102 | 103 | def _compile_next_state_computer(self): 104 | next_states = [VariableFilter(bricks=[self.generator], 105 | name=name, 106 | roles=[OUTPUT])(self.inner_cg)[-1] 107 | for name in self.state_names] 108 | next_outputs = VariableFilter( 109 | applications=[self.generator.readout.emit], roles=[OUTPUT])( 110 | self.inner_cg.variables) 111 | self.next_state_computer = function( 112 | self.contexts + self.input_states + next_outputs, 113 | next_states, 114 | on_unused_input='ignore') 115 | 116 | def _compile_logprobs_computer(self): 117 | # This filtering should return identical variables 118 | # (in terms of computations) variables, and we do not care 119 | # which to use. 120 | probs = VariableFilter( 121 | applications=[self.generator.readout.emitter.probs], 122 | roles=[OUTPUT])(self.inner_cg)[0] 123 | logprobs = -tensor.log(probs) 124 | self.logprobs_computer = function( 125 | self.contexts + self.input_states, logprobs, 126 | on_unused_input='ignore') 127 | 128 | def compile(self): 129 | """Compile all Theano functions used.""" 130 | self._compile_initial_state_and_context_computer() 131 | self._compile_next_state_computer() 132 | self._compile_logprobs_computer() 133 | self.compiled = True 134 | 135 | def compute_initial_states_and_contexts(self, inputs): 136 | """Computes initial states and contexts from inputs. 137 | 138 | Parameters 139 | ---------- 140 | inputs : dict 141 | Dictionary of input arrays. 
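            Each array should already be repeated ``beam_size`` times
            along the batch axis; as in :meth:`search`, this class
            cannot do the tiling itself.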
142 | 
143 |         Returns
144 |         -------
145 |         A tuple containing a {name: :class:`numpy.ndarray`} dictionary of
146 |         contexts ordered like `self.context_names` and a
147 |         {name: :class:`numpy.ndarray`} dictionary of states ordered like
148 |         `self.state_names`.
149 | 
150 |         """
151 |         outputs = self.initial_state_and_context_computer(
152 |             *[inputs[var] for var in self.inputs])
153 |         contexts = OrderedDict((n, outputs.pop(n)) for n in self.context_names)
154 |         beam_size = outputs.pop('beam_size')
155 |         initial_states = outputs
156 |         return contexts, initial_states, beam_size
157 | 
158 |     def compute_logprobs(self, contexts, states):
159 |         """Compute log probabilities of all possible outputs.
160 | 
161 |         Parameters
162 |         ----------
163 |         contexts : dict
164 |             A {name: :class:`numpy.ndarray`} dictionary of contexts.
165 |         states : dict
166 |             A {name: :class:`numpy.ndarray`} dictionary of states.
167 | 
168 |         Returns
169 |         -------
170 |         A :class:`numpy.ndarray` of the (beam size, number of possible
171 |         outputs) shape.
172 | 
173 |         """
174 |         input_states = [states[name] for name in self.input_state_names]
175 |         return self.logprobs_computer(*(list(contexts.values()) +
176 |                                         input_states))
177 | 
178 |     def compute_next_states(self, contexts, states, outputs):
179 |         """Computes next states.
180 | 
181 |         Parameters
182 |         ----------
183 |         contexts : dict
184 |             A {name: :class:`numpy.ndarray`} dictionary of contexts.
185 |         states : dict
186 |             A {name: :class:`numpy.ndarray`} dictionary of states.
187 |         outputs : :class:`numpy.ndarray`
188 |             A :class:`numpy.ndarray` of this step's outputs.
189 | 
190 |         Returns
191 |         -------
192 |         A {name: numpy.array} dictionary of next states.
193 | 
194 |         """
195 |         input_states = [states[name] for name in self.input_state_names]
196 |         next_values = self.next_state_computer(*(list(contexts.values()) +
197 |                                                  input_states + [outputs]))
198 |         return OrderedDict(equizip(self.state_names, next_values))
199 | 
200 |     @staticmethod
201 |     def _smallest(matrix, k, only_first_row=False):
202 |         """Find k smallest elements of a matrix.
203 | 
204 |         Parameters
205 |         ----------
206 |         matrix : :class:`numpy.ndarray`
207 |             The matrix.
208 |         k : int
209 |             The number of smallest elements required.
210 |         only_first_row : bool, optional
211 |             Consider only elements of the first row.
212 | 
213 |         Returns
214 |         -------
215 |         Tuple of ((row numbers, column numbers), values).
216 | 
217 |         """
218 |         if only_first_row:
219 |             flatten = matrix[:1, :].flatten()
220 |         else:
221 |             flatten = matrix.flatten()
222 |         args = numpy.argpartition(flatten, k)[:k]
223 |         args = args[numpy.argsort(flatten[args])]
224 |         return numpy.unravel_index(args, matrix.shape), flatten[args]
225 | 
226 |     def search(self, input_values, eol_symbol, max_length,
227 |                ignore_first_eol=False, as_arrays=False):
228 |         """Performs beam search.
229 | 
230 |         If the beam search was not compiled, it also compiles it.
231 | 
232 |         Parameters
233 |         ----------
234 |         input_values : dict
235 |             A {:class:`~theano.Variable`: :class:`~numpy.ndarray`}
236 |             dictionary of input values. The shapes should be
237 |             the same as if you ran sampling with batch size equal to
238 |             `beam_size`. To put it differently, the user is responsible
239 |             for duplicating inputs the necessary number of times, because
240 |             this class has insufficient information to do it properly.
241 |         eol_symbol : int
242 |             End of sequence symbol, the search stops when the symbol is
243 |             generated.
244 |         max_length : int
245 |             Maximum sequence length, the search stops when it is reached.
246 |         ignore_first_eol : bool, optional
247 |             When ``True``, an end-of-sequence symbol generated at the
248 |             first iteration is ignored. This is useful when the sequence
249 |             generator was trained on data with identical symbols for
250 |             sequence start and sequence end.
251 |         as_arrays : bool, optional
252 |             If ``True``, the internal representation of search results
253 |             is returned, that is a (matrix of outputs, mask,
254 |             costs of all generated outputs) tuple.
255 | 
256 |         Returns
257 |         -------
258 |         outputs : list of lists of ints
259 |             A list of the `beam_size` best sequences found in the order
260 |             of decreasing likelihood.
261 |         costs : list of floats
262 |             A list of the costs for the `outputs`, where cost is the
263 |             negative log-likelihood.
264 | 
265 |         """
266 |         if not self.compiled:
267 |             self.compile()
268 | 
269 |         contexts, states, beam_size = self.compute_initial_states_and_contexts(
270 |             input_values)
271 | 
272 |         # This array will store all generated outputs, including those from
273 |         # previous steps and those from already finished sequences.
274 |         all_outputs = states['outputs'][None, :]
275 |         all_masks = numpy.ones_like(all_outputs, dtype=config.floatX)
276 |         all_costs = numpy.zeros_like(all_outputs, dtype=config.floatX)
277 |         all_attended_source = numpy.zeros_like(all_outputs, dtype=int)
278 |         """
279 |         above arrays all have shape [time, beam_size] where time grows
280 | 
281 |         states[weights] has shape [beam_size, len]
282 |         where len is the length of the source sentence
283 | 
284 |         outputs has shape (beam_size, )
285 |         """
286 | 
287 |         # all weights has shape: [time, beam_size, source_len]; each step's weights are right-padded with zeros to a fixed width of 300 (an assumed upper bound on source length) so they can be stacked over time
288 |         all_weights = numpy.hstack(
289 |             [states['weights'],
290 |              numpy.zeros([beam_size, 300 - states['weights'].shape[1]])]
291 |         )[None, :, :]
292 | 
293 |         for i in range(max_length):
294 | 
295 |             if all_masks[-1].sum() == 0:
296 |                 break
297 | 
298 |             # We carefully hack values of the `logprobs` array to ensure
299 |             # that all finished sequences are continued with `eol_symbol`.
300 |             logprobs = self.compute_logprobs(contexts, states)
301 |             next_costs = (all_costs[-1, :, None] +
302 |                           logprobs * all_masks[-1, :, None])
303 |             (finished,) = numpy.where(all_masks[-1] == 0)
304 |             next_costs[finished, :eol_symbol] = numpy.inf
305 |             next_costs[finished, eol_symbol + 1:] = numpy.inf
306 | 
307 |             # The `i == 0` is required because at the first step the beam
308 |             # size is effectively only 1.
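            # Concretely (numbers assumed for illustration): with beam size
            # k, `next_costs` has shape (k, vocab_size); at i == 0 every one
            # of the k rows is an identical copy of the initial state, so
            # restricting the candidate pool to the first row prevents the
            # beam from filling up with k copies of the same continuation.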
309 | (indexes, outputs), chosen_costs = self._smallest( 310 | next_costs, beam_size, only_first_row=i == 0) 311 | 312 | # Rearrange everything 313 | for name in states: 314 | states[name] = states[name][indexes] 315 | all_outputs = all_outputs[:, indexes] 316 | all_masks = all_masks[:, indexes] 317 | all_costs = all_costs[:, indexes] 318 | all_attended_source = all_attended_source[:, indexes] 319 | all_weights = all_weights[:, indexes, :] 320 | 321 | # Record chosen output and compute new states 322 | states.update(self.compute_next_states(contexts, states, outputs)) 323 | 324 | all_outputs = numpy.vstack([all_outputs, outputs[None, :]]) 325 | all_costs = numpy.vstack([all_costs, chosen_costs[None, :]]) 326 | mask = outputs != eol_symbol 327 | if ignore_first_eol and i == 0: 328 | mask[:] = 1 329 | all_masks = numpy.vstack([all_masks, mask[None, :]]) 330 | _weights = numpy.hstack( 331 | [states['weights'], 332 | numpy.zeros([beam_size, 300 - states['weights'].shape[1]])]) 333 | all_weights = numpy.vstack([all_weights, _weights[None, :, :]]) 334 | max_attended = numpy.argmax(states['weights'], axis=1) 335 | all_attended_source = numpy.vstack( 336 | [all_attended_source, max_attended]) 337 | 338 | all_outputs = all_outputs[1:] 339 | all_attended_source = all_attended_source[1:] 340 | all_weights = all_weights[1:] 341 | all_masks = all_masks[:-1] 342 | all_costs = all_costs[1:] - all_costs[:-1] 343 | result = all_outputs, all_masks, all_costs, all_attended_source 344 | if as_arrays: 345 | return result 346 | return self.result_to_lists(result, all_weights) 347 | 348 | @staticmethod 349 | def result_to_lists(result, weights): 350 | outputs, masks, costs, attendeds = [array.T for array in result] 351 | weights = numpy.swapaxes(weights, 0, 1) 352 | outputs = [list(output[:mask.sum()]) 353 | for output, mask in equizip(outputs, masks)] 354 | attendeds = [list(attended[:mask.sum()]) 355 | for attended, mask in equizip(attendeds, masks)] 356 | weights = [weight[:mask.sum()] 357 | for weight, mask in zip(weights, masks)] 358 | costs = list(costs.T.sum(axis=0)) 359 | return outputs, costs, attendeds, weights 360 | -------------------------------------------------------------------------------- /search_decoder_with_extra_class.py: -------------------------------------------------------------------------------- 1 | """The beam search module.""" 2 | from collections import OrderedDict 3 | from six.moves import range 4 | import numpy 5 | from picklable_itertools.extras import equizip 6 | from theano import config, function, tensor 7 | from blocks.bricks.sequence_generators import BaseSequenceGenerator 8 | from blocks.filter import VariableFilter, get_application_call, get_brick 9 | from blocks.graph import ComputationGraph 10 | from blocks.roles import INPUT, OUTPUT 11 | from blocks.utils import unpack 12 | 13 | 14 | class BeamSearch(object): 15 | """Approximate search for the most likely sequence. 16 | 17 | Beam search is an approximate algorithm for finding :math:`y^* = 18 | argmax_y P(y|c)`, where :math:`y` is an output sequence, :math:`c` are 19 | the contexts, :math:`P` is the output distribution of a 20 | :class:`.SequenceGenerator`. At each step it considers :math:`k` 21 | candidate sequence prefixes. :math:`k` is called the beam size, and the 22 | sequence are called the beam. The sequences are replaced with their 23 | :math:`k` most probable continuations, and this is repeated until 24 | end-of-line symbol is met. 25 | 26 | The beam search compiles quite a few Theano functions under the hood. 
27 | Normally those are compiled at the first :meth:`search` call, but 28 | you can also explicitly call :meth:`compile`. 29 | 30 | Parameters 31 | ---------- 32 | samples : :class:`~theano.Variable` 33 | An output of a sampling computation graph built by 34 | :meth:`~blocks.brick.SequenceGenerator.generate`, the one 35 | corresponding to sampled sequences. 36 | 37 | See Also 38 | -------- 39 | :class:`.SequenceGenerator` 40 | 41 | Notes 42 | ----- 43 | Sequence generator should use an emitter which has `probs` method 44 | e.g. :class:`SoftmaxEmitter`. 45 | 46 | Does not support dummy contexts so far (all the contexts must be used 47 | in the `generate` method of the sequence generator for the current code 48 | to work). 49 | 50 | """ 51 | 52 | def __init__(self, samples): 53 | # Extracting information from the sampling computation graph 54 | self.cg = ComputationGraph(samples) 55 | self.inputs = self.cg.inputs 56 | self.generator = get_brick(samples) 57 | if not isinstance(self.generator, BaseSequenceGenerator): 58 | raise ValueError 59 | self.generate_call = get_application_call(samples) 60 | if (not self.generate_call.application == 61 | self.generator.generate): 62 | raise ValueError 63 | self.inner_cg = ComputationGraph(self.generate_call.inner_outputs) 64 | 65 | # Fetching names from the sequence generator 66 | self.context_names = self.generator.generate.contexts 67 | self.state_names = self.generator.generate.states 68 | 69 | # Parsing the inner computation graph of sampling scan 70 | self.contexts = [ 71 | VariableFilter(bricks=[self.generator], 72 | name=name, 73 | roles=[INPUT])(self.inner_cg)[0] 74 | for name in self.context_names] 75 | self.input_states = [] 76 | # Includes only those state names that were actually used 77 | # in 'generate' 78 | self.input_state_names = [] 79 | for name in self.generator.generate.states: 80 | var = VariableFilter( 81 | bricks=[self.generator], name=name, 82 | roles=[INPUT])(self.inner_cg) 83 | if var: 84 | self.input_state_names.append(name) 85 | self.input_states.append(var[0]) 86 | 87 | self.tv_overlap_name=['tw_vocab_overlap']; 88 | self.tv_overlap = [VariableFilter( 89 | bricks=[self.generator], name=self.tv_overlap_name[0], 90 | roles=[INPUT])(self.inner_cg)[0]] 91 | 92 | def _compile_initial_state_and_context_computer(self): 93 | initial_states = VariableFilter( 94 | applications=[self.generator.initial_states], 95 | roles=[OUTPUT])(self.cg) 96 | outputs = OrderedDict([(v.tag.name, v) for v in initial_states]) 97 | beam_size = unpack( 98 | VariableFilter(applications=[self.generator.initial_states], 99 | name='batch_size')(self.cg)) 100 | for name, context in equizip(self.context_names, self.contexts): 101 | outputs[name] = context 102 | outputs['beam_size'] = beam_size 103 | self.initial_state_and_context_computer = function( 104 | self.inputs, outputs, on_unused_input='ignore') 105 | 106 | def _compile_next_state_computer(self): 107 | next_states = [VariableFilter(bricks=[self.generator], 108 | name=name, 109 | roles=[OUTPUT])(self.inner_cg)[-1] 110 | for name in self.state_names] 111 | next_outputs = VariableFilter( 112 | applications=[self.generator.sampleTarget.emit], roles=[OUTPUT])( 113 | self.inner_cg.variables) 114 | self.next_state_computer = function( 115 | self.contexts + self.input_states + next_outputs, 116 | next_states, 117 | on_unused_input='warn') 118 | 119 | def _compile_logprobs_computer(self): 120 | # This filtering should return identical variables 121 | # (in terms of computations) variables, and we do not care 122 
| # which to use. 123 | probs = VariableFilter( 124 | applications=[self.generator.sampleTarget.emit], 125 | roles=[INPUT])(self.inner_cg)[0] 126 | logprobs = -tensor.log(probs) 127 | self.logprobs_computer = function( 128 | self.contexts + self.input_states, logprobs, 129 | on_unused_input='warn') 130 | 131 | def compile(self): 132 | """Compile all Theano functions used.""" 133 | self._compile_initial_state_and_context_computer() 134 | self._compile_next_state_computer() 135 | self._compile_logprobs_computer() 136 | self.compiled = True 137 | 138 | def compute_initial_states_and_contexts(self, inputs): 139 | """Computes initial states and contexts from inputs. 140 | 141 | Parameters 142 | ---------- 143 | inputs : dict 144 | Dictionary of input arrays. 145 | 146 | Returns 147 | ------- 148 | A tuple containing a {name: :class:`numpy.ndarray`} dictionary of 149 | contexts ordered like `self.context_names` and a 150 | {name: :class:`numpy.ndarray`} dictionary of states ordered like 151 | `self.state_names`. 152 | 153 | """ 154 | outputs = self.initial_state_and_context_computer( 155 | *[inputs[var] for var in self.inputs]) 156 | contexts = OrderedDict((n, outputs.pop(n)) for n in self.context_names) 157 | beam_size = outputs.pop('beam_size') 158 | initial_states = outputs 159 | return contexts, initial_states, beam_size 160 | 161 | def compute_logprobs(self, contexts, states): 162 | """Compute log probabilities of all possible outputs. 163 | 164 | Parameters 165 | ---------- 166 | contexts : dict 167 | A {name: :class:`numpy.ndarray`} dictionary of contexts. 168 | states : dict 169 | A {name: :class:`numpy.ndarray`} dictionary of states. 170 | 171 | Returns 172 | ------- 173 | A :class:`numpy.ndarray` of the (beam size, number of possible 174 | outputs) shape. 175 | 176 | """ 177 | input_states = [states[name] for name in self.input_state_names] 178 | return self.logprobs_computer(*(list(contexts.values()) + 179 | input_states)) 180 | 181 | def compute_next_states(self, contexts, states, outputs): 182 | """Computes next states. 183 | 184 | Parameters 185 | ---------- 186 | contexts : dict 187 | A {name: :class:`numpy.ndarray`} dictionary of contexts. 188 | states : dict 189 | A {name: :class:`numpy.ndarray`} dictionary of states. 190 | outputs : :class:`numpy.ndarray` 191 | A :class:`numpy.ndarray` of this step outputs. 192 | 193 | Returns 194 | ------- 195 | A {name: numpy.array} dictionary of next states. 196 | 197 | """ 198 | input_states = [states[name] for name in self.input_state_names] 199 | next_values = self.next_state_computer(*(list(contexts.values()) + 200 | input_states + [outputs])) 201 | return OrderedDict(equizip(self.state_names, next_values)) 202 | 203 | @staticmethod 204 | def _smallest(matrix, k, only_first_row=False): 205 | """Find k smallest elements of a matrix. 206 | 207 | Parameters 208 | ---------- 209 | matrix : :class:`numpy.ndarray` 210 | The matrix. 211 | k : int 212 | The number of smallest elements required. 213 | only_first_row : bool, optional 214 | Consider only elements of the first row. 215 | 216 | Returns 217 | ------- 218 | Tuple of ((row numbers, column numbers), values). 
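        A worked example (illustrative; exact array formatting depends
        on the numpy version)::

            >>> m = numpy.array([[3., 1.], [0., 2.]])
            >>> BeamSearch._smallest(m, 2)
            ((array([1, 0]), array([0, 1])), array([ 0.,  1.]))

        ``argpartition`` gathers the ``k`` smallest entries in linear
        time, and only those ``k`` are fully sorted afterwards.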
219 | 
220 |         """
221 |         if only_first_row:
222 |             flatten = matrix[:1, :].flatten()
223 |         else:
224 |             flatten = matrix.flatten()
225 |         args = numpy.argpartition(flatten, k)[:k]
226 |         args = args[numpy.argsort(flatten[args])]
227 |         return numpy.unravel_index(args, matrix.shape), flatten[args]
228 | 
229 |     def search(self, input_values, tw_vocab_overlap, eol_symbol, max_length,
230 |                ignore_first_eol=False, as_arrays=False):
231 |         """Performs beam search.
232 | 
233 |         If the beam search was not compiled, it also compiles it.
234 | 
235 |         Parameters
236 |         ----------
237 |         input_values : dict
238 |             A {:class:`~theano.Variable`: :class:`~numpy.ndarray`}
239 |             dictionary of input values. The shapes should be
240 |             the same as if you ran sampling with batch size equal to
241 |             `beam_size`. To put it differently, the user is responsible
242 |             for duplicating inputs the necessary number of times, because
243 |             this class has insufficient information to do it properly.
        tw_vocab_overlap : :class:`numpy.ndarray`
            Matrix mapping topic-word indices onto the target vocabulary;
            passed through to the generator as an extra context.
244 |         eol_symbol : int
245 |             End of sequence symbol, the search stops when the symbol is
246 |             generated.
247 |         max_length : int
248 |             Maximum sequence length, the search stops when it is reached.
249 |         ignore_first_eol : bool, optional
250 |             When ``True``, an end-of-sequence symbol generated at the
251 |             first iteration is ignored. This is useful when the sequence
252 |             generator was trained on data with identical symbols for
253 |             sequence start and sequence end.
254 |         as_arrays : bool, optional
255 |             If ``True``, the internal representation of search results
256 |             is returned, that is a (matrix of outputs, mask,
257 |             costs of all generated outputs) tuple.
258 | 
259 |         Returns
260 |         -------
261 |         outputs : list of lists of ints
262 |             A list of the `beam_size` best sequences found in the order
263 |             of decreasing likelihood.
264 |         costs : list of floats
265 |             A list of the costs for the `outputs`, where cost is the
266 |             negative log-likelihood.
267 | 
268 |         """
269 |         #if not self.compiled:
270 |         self.compile()
271 | 
272 |         contexts, states, beam_size = self.compute_initial_states_and_contexts(
273 |             input_values)
274 | 
275 |         # This array will store all generated outputs, including those from
276 |         # previous steps and those from already finished sequences.
277 |         all_outputs = states['outputs'][None, :]
278 |         all_masks = numpy.ones_like(all_outputs, dtype=config.floatX)
279 |         all_costs = numpy.zeros_like(all_outputs, dtype=config.floatX)
280 |         all_attended_source = numpy.zeros_like(all_outputs, dtype=int)
281 |         """
282 |         above arrays all have shape [time, beam_size] where time grows
283 | 
284 |         states[weights] has shape [beam_size, len]
285 |         where len is the length of the source sentence
286 | 
287 |         outputs has shape (beam_size, )
288 |         """
289 | 
290 |         # all weights has shape: [time, beam_size, source_len]; each step's weights are right-padded with zeros to a fixed width of 300 (an assumed upper bound on source length) so they can be stacked over time
291 |         all_weights = numpy.hstack(
292 |             [states['weights'],
293 |              numpy.zeros([beam_size, 300 - states['weights'].shape[1]])]
294 |         )[None, :, :]
295 | 
296 |         for i in range(max_length):
297 | 
298 |             if all_masks[-1].sum() == 0:
299 |                 break
300 | 
301 |             # We carefully hack values of the `logprobs` array to ensure
302 |             # that all finished sequences are continued with `eol_symbol`.
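            # How the hack works: for a finished row, `all_masks[-1]` is 0,
            # so the additive term `logprobs * all_masks[-1, :, None]`
            # vanishes, and every column except `eol_symbol` is then set to
            # inf below; the only affordable continuation of a finished
            # hypothesis is therefore another EOS at zero extra cost.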
303 | logprobs = self.compute_logprobs(contexts, states) 304 | next_costs = (all_costs[-1, :, None] + 305 | logprobs * all_masks[-1, :, None]) 306 | (finished,) = numpy.where(all_masks[-1] == 0) 307 | next_costs[finished, :eol_symbol] = numpy.inf 308 | next_costs[finished, eol_symbol + 1:] = numpy.inf 309 | 310 | # The `i == 0` is required because at the first step the beam 311 | # size is effectively only 1. 312 | (indexes, outputs), chosen_costs = self._smallest( 313 | next_costs, beam_size, only_first_row=i == 0) 314 | 315 | # Rearrange everything 316 | for name in states: 317 | states[name] = states[name][indexes] 318 | all_outputs = all_outputs[:, indexes] 319 | all_masks = all_masks[:, indexes] 320 | all_costs = all_costs[:, indexes] 321 | all_attended_source = all_attended_source[:, indexes] 322 | all_weights = all_weights[:, indexes, :] 323 | 324 | # Record chosen output and compute new states 325 | states.update(self.compute_next_states(contexts, states, outputs)) 326 | 327 | all_outputs = numpy.vstack([all_outputs, outputs[None, :]]) 328 | all_costs = numpy.vstack([all_costs, chosen_costs[None, :]]) 329 | mask = outputs != eol_symbol 330 | if ignore_first_eol and i == 0: 331 | mask[:] = 1 332 | all_masks = numpy.vstack([all_masks, mask[None, :]]) 333 | _weights = numpy.hstack( 334 | [states['weights'], 335 | numpy.zeros([beam_size, 300 - states['weights'].shape[1]])]) 336 | all_weights = numpy.vstack([all_weights, _weights[None, :, :]]) 337 | max_attended = numpy.argmax(states['weights'], axis=1) 338 | all_attended_source = numpy.vstack( 339 | [all_attended_source, max_attended]) 340 | 341 | all_outputs = all_outputs[1:] 342 | all_attended_source = all_attended_source[1:] 343 | all_weights = all_weights[1:] 344 | all_masks = all_masks[:-1] 345 | all_costs = all_costs[1:] - all_costs[:-1] 346 | result = all_outputs, all_masks, all_costs, all_attended_source 347 | if as_arrays: 348 | return result 349 | return self.result_to_lists(result, all_weights) 350 | 351 | @staticmethod 352 | def result_to_lists(result, weights): 353 | outputs, masks, costs, attendeds = [array.T for array in result] 354 | weights = numpy.swapaxes(weights, 0, 1) 355 | outputs = [list(output[:mask.sum()]) 356 | for output, mask in equizip(outputs, masks)] 357 | attendeds = [list(attended[:mask.sum()]) 358 | for attended, mask in equizip(attendeds, masks)] 359 | weights = [weight[:mask.sum()] 360 | for weight, mask in zip(weights, masks)] 361 | costs = list(costs.T.sum(axis=0)) 362 | return outputs, costs, attendeds, weights 363 | -------------------------------------------------------------------------------- /sampling.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import logging 4 | import numpy 5 | import operator 6 | import os 7 | import re 8 | import signal 9 | import time 10 | import cPickle 11 | 12 | from blocks.extensions import SimpleExtension 13 | from search import BeamSearch 14 | from afterprocess import afterprocesser 15 | 16 | from subprocess import Popen, PIPE 17 | from progressbar import ProgressBar 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class SamplingBase(object): 23 | """Utility class for BleuValidator and Sampler.""" 24 | 25 | def _get_attr_rec(self, obj, attr): 26 | return self._get_attr_rec(getattr(obj, attr), attr) \ 27 | if hasattr(obj, attr) else obj 28 | 29 | def _get_true_length(self, seq, vocab): 30 | try: 31 | return seq.tolist().index(vocab['']) + 1 32 | except ValueError: 
33 | return len(seq) 34 | 35 | def _oov_to_unk(self, seq, vocab_size, unk_idx): 36 | return [x if x < vocab_size else unk_idx for x in seq] 37 | 38 | def _idx_to_sent(self, seq, ivocab): 39 | return " ".join([ivocab.get(idx, "") for idx in seq]) 40 | 41 | def _idx_to_word(self, seq, ivocab): 42 | # return " ".join([ivocab.get(idx, "") for idx in seq]) 43 | return [ivocab.get(idx, "") for idx in seq] 44 | 45 | 46 | class Sampler(SimpleExtension, SamplingBase): 47 | """Random Sampling from model.""" 48 | 49 | def __init__(self, model, data_stream, model_name, hook_samples=1, 50 | src_vocab=None, trg_vocab=None, src_ivocab=None, 51 | trg_ivocab=None, src_vocab_size=None, **kwargs): 52 | super(Sampler, self).__init__(**kwargs) 53 | self.model = model 54 | self.hook_samples = hook_samples 55 | self.data_stream = data_stream 56 | self.model_name = model_name 57 | self.src_vocab = src_vocab 58 | self.trg_vocab = trg_vocab 59 | self.src_ivocab = src_ivocab 60 | self.trg_ivocab = trg_ivocab 61 | self.src_vocab_size = src_vocab_size 62 | self.is_synced = False 63 | self.sampling_fn = model.get_theano_function() 64 | 65 | def do(self, which_callback, *args): 66 | 67 | # Get dictionaries, this may not be the practical way 68 | sources = self._get_attr_rec(self.main_loop, 'data_stream') 69 | 70 | # Load vocabularies and invert if necessary 71 | # WARNING: Source and target indices from data stream 72 | # can be different 73 | if not self.src_vocab: 74 | self.src_vocab = sources.data_streams[0].dataset.dictionary 75 | if not self.trg_vocab: 76 | self.trg_vocab = sources.data_streams[1].dataset.dictionary 77 | if not self.src_ivocab: 78 | self.src_ivocab = {v: k for k, v in self.src_vocab.items()} 79 | if not self.trg_ivocab: 80 | self.trg_ivocab = {v: k for k, v in self.trg_vocab.items()} 81 | if not self.src_vocab_size: 82 | self.src_vocab_size = len(self.src_vocab) 83 | 84 | # Randomly select source samples from the current batch 85 | # WARNING: Source and target indices from data stream 86 | # can be different 87 | batch = args[0] 88 | batch_size = batch['source'].shape[0] 89 | hook_samples = min(batch_size, self.hook_samples) 90 | 91 | # TODO: this is problematic for boundary conditions, eg. 
46 | class Sampler(SimpleExtension, SamplingBase):
47 |     """Random Sampling from model."""
48 | 
49 |     def __init__(self, model, data_stream, model_name, hook_samples=1,
50 |                  src_vocab=None, trg_vocab=None, src_ivocab=None,
51 |                  trg_ivocab=None, src_vocab_size=None, **kwargs):
52 |         super(Sampler, self).__init__(**kwargs)
53 |         self.model = model
54 |         self.hook_samples = hook_samples
55 |         self.data_stream = data_stream
56 |         self.model_name = model_name
57 |         self.src_vocab = src_vocab
58 |         self.trg_vocab = trg_vocab
59 |         self.src_ivocab = src_ivocab
60 |         self.trg_ivocab = trg_ivocab
61 |         self.src_vocab_size = src_vocab_size
62 |         self.is_synced = False
63 |         self.sampling_fn = model.get_theano_function()
64 | 
65 |     def do(self, which_callback, *args):
66 | 
67 |         # Get dictionaries; this may not be the most practical way
68 |         sources = self._get_attr_rec(self.main_loop, 'data_stream')
69 | 
70 |         # Load vocabularies and invert if necessary
71 |         # WARNING: Source and target indices from the data stream
72 |         # can be different
73 |         if not self.src_vocab:
74 |             self.src_vocab = sources.data_streams[0].dataset.dictionary
75 |         if not self.trg_vocab:
76 |             self.trg_vocab = sources.data_streams[1].dataset.dictionary
77 |         if not self.src_ivocab:
78 |             self.src_ivocab = {v: k for k, v in self.src_vocab.items()}
79 |         if not self.trg_ivocab:
80 |             self.trg_ivocab = {v: k for k, v in self.trg_vocab.items()}
81 |         if not self.src_vocab_size:
82 |             self.src_vocab_size = len(self.src_vocab)
83 | 
84 |         # Randomly select source samples from the current batch
85 |         # WARNING: Source and target indices from the data stream
86 |         # can be different
87 |         batch = args[0]
88 |         batch_size = batch['source'].shape[0]
89 |         hook_samples = min(batch_size, self.hook_samples)
90 | 
91 |         # TODO: this is problematic for boundary conditions, eg. last batch
92 |         sample_idx = numpy.random.choice(
93 |             batch_size, hook_samples, replace=False)
94 |         src_batch = batch[self.main_loop.data_stream.mask_sources[0]]
95 |         trg_batch = batch[self.main_loop.data_stream.mask_sources[1]]
96 | 
97 |         input_ = src_batch[sample_idx, :]
98 |         target_ = trg_batch[sample_idx, :]
99 | 
100 |         # Sample
101 |         print()
102 |         for i in range(hook_samples):
103 |             input_length = self._get_true_length(input_[i], self.src_vocab)
104 |             target_length = self._get_true_length(target_[i], self.trg_vocab)
105 | 
106 |             inp = input_[i, :input_length]
107 |             _1, outputs, _2, _3, costs = (self.sampling_fn(inp[None, :]))
108 |             outputs = outputs.flatten()
109 |             costs = costs.T
110 | 
111 |             sample_length = self._get_true_length(outputs, self.trg_vocab)
112 | 
113 |             print("Sampling: " + self.model_name)
114 | 
115 |             print("Input : ", self._idx_to_sent(input_[i][:input_length],
116 |                                                 self.src_ivocab))
117 |             print("Target: ", self._idx_to_sent(target_[i][:target_length],
118 |                                                 self.trg_ivocab))
119 |             print("Sample: ", self._idx_to_sent(outputs[:sample_length],
120 |                                                 self.trg_ivocab))
121 |             print("Sample cost: ", costs[:sample_length].sum())
122 |             print()
123 | 
124 | class pplValidation(SimpleExtension, SamplingBase):
125 |     """Reports the average model cost on the validation stream."""
126 | 
127 |     def __init__(self, model, data_stream, model_name, config,
128 |                  src_vocab=None, n_best=1, track_n_models=1, trg_ivocab=None,
129 |                  patience=10, normalize=True, **kwargs):
130 |         super(pplValidation, self).__init__(**kwargs)
131 |         self.model = model
132 |         self.data_stream = data_stream
133 |         self.model_name = model_name
134 |         self.src_vocab = src_vocab
135 |         self.trg_ivocab = trg_ivocab
136 |         self.is_synced = False
137 |         self.sampling_fn = model
138 | 
139 |         self.config = config
140 |         self.n_best = n_best
141 |         self.normalize = normalize
142 |         self.patience = patience
143 | 
144 | 
145 |     def do(self, which_callback, *args):
146 | 
147 |         print()
148 |         # Evaluate and save if necessary
149 |         cost = self._evaluate_model()
150 |         print("Average validation cost: " + str(cost))
151 | 
152 |     def _evaluate_model(self):
153 | 
154 |         logger.info("Started Validation: ")
155 | 
156 |         ts = self.data_stream.get_epoch_iterator()
157 |         total_cost = 0.0
158 | 
159 |         # pbar = ProgressBar(max_value=len(ts)).start()  # modified
160 |         pbar = ProgressBar(max_value=20036).start()  # hardcoded dev-set size
161 |         for i, (src, src_mask, trg, trg_mask, te, te_mask, tt, tt_mask, tb, tb_mask) in enumerate(ts):
162 |             costs = self.model(*[trg, trg_mask, src, src_mask, te, tt, tb])
163 |             cost = costs.sum()
164 |             total_cost += cost
165 |             pbar.update(i + 1)
166 |         total_cost /= 20036  # average over the hardcoded dev-set size
167 |         pbar.finish()
168 |         self.data_stream.reset()
169 | 
170 |         # run afterprocess
171 |         # self.ap.main()
172 |         self.main_loop.log.current_row['validation_cost'] = total_cost
173 | 
174 |         return total_cost
175 | 
176 | 
177 | class perplexityValidation(SimpleExtension, SamplingBase):
178 |     """Dev-set validation via beam search, reporting the decoding cost."""
179 | 
180 |     def __init__(self, source_sentence, samples, model, data_stream, model_name, config,
181 |                  src_vocab=None, n_best=1, track_n_models=1, trg_ivocab=None,
182 |                  patience=10, normalize=True, **kwargs):
183 |         super(perplexityValidation, self).__init__(**kwargs)
184 |         self.model = model
185 |         self.data_stream = data_stream
186 |         self.model_name = model_name
187 |         self.src_vocab = src_vocab
188 |         self.trg_ivocab = trg_ivocab
189 |         self.is_synced = False
190 |         self.sampling_fn = model.get_theano_function()
191 | 
192 |         self.source_sentence = source_sentence
193 |         self.samples = samples
194 |         self.config = config
195 | 
self.n_best = n_best 196 | self.normalize = normalize 197 | self.patience = patience 198 | 199 | # Helpers 200 | self.vocab = data_stream.dataset.dictionary 201 | self.trg_ivocab = trg_ivocab 202 | self.unk_sym = data_stream.dataset.unk_token 203 | self.eos_sym = data_stream.dataset.eos_token 204 | self.unk_idx = self.vocab[self.unk_sym] 205 | self.eos_idx = self.vocab[self.eos_sym] 206 | self.src_eos_idx = config['src_vocab_size'] - 1 207 | self.beam_search = BeamSearch(samples=samples) 208 | 209 | def do(self, which_callback, *args): 210 | 211 | print() 212 | # Evaluate and save if necessary 213 | cost = self._evaluate_model() 214 | print("Average validation cost: " + str(cost)); 215 | 216 | def _evaluate_model(self): 217 | 218 | logger.info("Started Validation: ") 219 | 220 | if not self.trg_ivocab: 221 | sources = self._get_attr_rec(self.main_loop, 'data_stream') 222 | trg_vocab = sources.data_streams[1].dataset.dictionary 223 | self.trg_ivocab = {v: k for k, v in trg_vocab.items()} 224 | 225 | ts = self.data_stream.get_epoch_iterator() 226 | ftrans_original = open(self.config['val_output_orig'], 'w') 227 | total_cost = 0.0 228 | 229 | pbar = ProgressBar(max_value=len(ts)).start()#modified 230 | for i, line in enumerate(ts): 231 | seq = self._oov_to_unk( 232 | line[0], self.config['src_vocab_size'], self.unk_idx) 233 | input_ = numpy.tile(seq, (self.config['beam_size'], 1)) 234 | 235 | # draw sample, checking to ensure we don't get an empty string back 236 | trans, costs, attendeds, weights = \ 237 | self.beam_search.search( 238 | input_values={self.source_sentence: input_}, 239 | max_length=3*len(seq), eol_symbol=self.src_eos_idx, 240 | ignore_first_eol=True) 241 | 242 | # normalize costs according to the sequence lengths 243 | if self.normalize: 244 | lengths = numpy.array([len(s) for s in trans]) 245 | costs = costs / lengths 246 | 247 | best = numpy.argsort(costs)[0] 248 | try: 249 | total_cost += costs[best] 250 | trans_out = trans[best] 251 | trans_out = self._idx_to_word(trans_out, self.trg_ivocab) 252 | except ValueError: 253 | logger.info( 254 | "Can NOT find a translation for line: {}".format(i+1)) 255 | trans_out = '' 256 | 257 | print(' '.join(trans_out), file=ftrans_original) 258 | pbar.update(i + 1) 259 | 260 | pbar.finish() 261 | ftrans_original.close() 262 | self.data_stream.reset() 263 | 264 | # run afterprocess 265 | # self.ap.main() 266 | self.main_loop.log.current_row['validation_cost'] = total_cost 267 | 268 | return total_cost 269 | 270 | 271 | class BleuValidator(SimpleExtension, SamplingBase): 272 | 273 | def __init__(self, source_sentence, samples, model, data_stream, 274 | config, n_best=1, track_n_models=1, trg_ivocab=None, 275 | patience=10, normalize=True, **kwargs): 276 | super(BleuValidator, self).__init__(**kwargs) 277 | self.source_sentence = source_sentence 278 | self.samples = samples 279 | self.model = model 280 | self.data_stream = data_stream 281 | self.config = config 282 | self.n_best = n_best 283 | self.track_n_models = track_n_models 284 | self.normalize = normalize 285 | self.patience = patience 286 | 287 | # Helpers 288 | self.vocab = data_stream.dataset.dictionary 289 | self.trg_ivocab = trg_ivocab 290 | self.unk_sym = data_stream.dataset.unk_token 291 | self.eos_sym = data_stream.dataset.eos_token 292 | self.unk_idx = self.vocab[self.unk_sym] 293 | self.eos_idx = self.vocab[self.eos_sym] 294 | self.src_eos_idx = config['src_vocab_size'] - 1 295 | self.best_models = [] 296 | self.beam_search = BeamSearch(samples=samples) 297 | 
self.multibleu_cmd = ['perl', self.config['bleu_script'], 298 | self.config['val_set_target'], '<'] 299 | self.compbleu_cmd = [self.config['bleu_script_1'], 300 | self.config['val_set_target'], 301 | self.config['val_output_repl']] 302 | self.ap = afterprocesser(config) 303 | 304 | # Create saving directory if it does not exist 305 | if not os.path.exists(self.config['saveto']): 306 | os.makedirs(self.config['saveto']) 307 | 308 | def do(self, which_callback, *args): 309 | 310 | # Track validation burn in 311 | if self.main_loop.status['iterations_done'] <= \ 312 | self.config['val_burn_in']: 313 | return 314 | 315 | # Evaluate and save if necessary 316 | bleu, cost = self._evaluate_model() 317 | self._save_model(bleu, cost) 318 | self._stop() 319 | 320 | def _stop(self): 321 | def get_last_max(l): 322 | t = 0 323 | r = 0 324 | for i, j in enumerate(l): 325 | if j >= t: 326 | r = i 327 | return r 328 | 329 | def _evaluate_model(self): 330 | 331 | logger.info("Started Validation: ") 332 | 333 | if not self.trg_ivocab: 334 | sources = self._get_attr_rec(self.main_loop, 'data_stream') 335 | trg_vocab = sources.data_streams[1].dataset.dictionary 336 | self.trg_ivocab = {v: k for k, v in trg_vocab.items()} 337 | 338 | ts = self.data_stream.get_epoch_iterator() 339 | rts = open(self.config['val_set_source']).readlines() 340 | ftrans_original = open(self.config['val_output_orig'], 'w') 341 | saved_weights = [] 342 | total_cost = 0.0 343 | 344 | pbar = ProgressBar(max_value=len(rts)).start() 345 | for i, (line, line_raw) in enumerate(zip(ts, rts)): 346 | trans_in = line_raw.split() 347 | seq = self._oov_to_unk( 348 | line[0], self.config['src_vocab_size'], self.unk_idx) 349 | input_ = numpy.tile(seq, (self.config['beam_size'], 1)) 350 | 351 | # draw sample, checking to ensure we don't get an empty string back 352 | trans, costs, attendeds, weights = \ 353 | self.beam_search.search( 354 | input_values={self.source_sentence: input_}, 355 | max_length=3*len(seq), eol_symbol=self.src_eos_idx, 356 | ignore_first_eol=True) 357 | 358 | # normalize costs according to the sequence lengths 359 | if self.normalize: 360 | lengths = numpy.array([len(s) for s in trans]) 361 | costs = costs / lengths 362 | 363 | best = numpy.argsort(costs)[0] 364 | try: 365 | total_cost += costs[best] 366 | trans_out = trans[best] 367 | weight = weights[best][:, :len(trans_in)] 368 | trans_out = self._idx_to_word(trans_out, self.trg_ivocab) 369 | except ValueError: 370 | logger.info( 371 | "Can NOT find a translation for line: {}".format(i+1)) 372 | trans_out = '' 373 | 374 | saved_weights.append(weight) 375 | print(' '.join(trans_out), file=ftrans_original) 376 | pbar.update(i + 1) 377 | 378 | pbar.finish() 379 | ftrans_original.close() 380 | cPickle.dump(saved_weights, open(self.config['attention_weights'], 'wb')) 381 | self.data_stream.reset() 382 | 383 | # run afterprocess 384 | # self.ap.main() 385 | 386 | # calculate bleu 387 | bleu_subproc = Popen(self.compbleu_cmd, stdout=PIPE) 388 | while True: 389 | line = bleu_subproc.stdout.readline() 390 | if line != '': 391 | if 'BLEU' in line: 392 | stdout = line 393 | else: 394 | break 395 | bleu_subproc.terminate() 396 | out_parse = re.match(r'BLEU = [-.0-9]+', stdout) 397 | assert out_parse is not None 398 | 399 | # extract the score 400 | bleu_score = float(out_parse.group()[6:]) * 100 401 | logger.info('BLEU: ' + str(bleu_score)) 402 | self.main_loop.log.current_row['validation_bleu'] = bleu_score 403 | self.main_loop.log.current_row['validation_cost'] = total_cost 404 | 405 
|         return bleu_score, total_cost
406 | 
407 |     def _is_valid_to_save(self, bleu_score):
408 |         if not self.best_models or min(self.best_models,
409 |                                        key=operator.attrgetter('score')).score < bleu_score:
410 |             return True
411 |         return False
412 | 
413 |     def _save_model(self, bleu_score, total_cost):
414 |         if self._is_valid_to_save(bleu_score):
415 |             model = ModelInfo(bleu_score, 'bleu', self.config['saveto'])
416 | 
417 |             # Manage n-best model list first
418 |             if len(self.best_models) >= self.track_n_models:
419 |                 old_model = self.best_models[0]
420 |                 if old_model.path and os.path.isfile(old_model.path):
421 |                     logger.info("Deleting old model %s" % old_model.path)
422 |                     os.remove(old_model.path)
423 |                 self.best_models.remove(old_model)
424 | 
425 |             self.best_models.append(model)
426 |             self.best_models.sort(key=operator.attrgetter('score'))
427 | 
428 |             # Save the model here
429 |             s = signal.signal(signal.SIGINT, signal.SIG_IGN)
430 |             logger.info("Saving new model {}".format(model.path))
431 |             self.dump_parameters(self.main_loop, model.path)
432 |             signal.signal(signal.SIGINT, s)
433 | 
434 |     def dump_parameters(self, main_loop, path):
435 |         params_to_save = main_loop.model.get_parameter_values()
436 |         param_values = {name.replace("/", "-"): param
437 |                         for name, param in params_to_save.items()}
438 |         outfile_path = path + '.' + str(main_loop.status['iterations_done'])
439 |         with open(outfile_path, 'wb') as outfile:
440 |             numpy.savez(outfile, **param_values)
441 | 
442 | 
443 | class ModelInfo:
444 |     """Utility class to keep track of evaluated models."""
445 | 
446 |     def __init__(self, score, name, path=None):
447 |         self.score = score
448 |         self.path = self._generate_path(path, name)
449 | 
450 |     def _generate_path(self, path, name):
451 |         if not path:  # no save directory configured
452 |             return None
453 |         return os.path.join(
454 |             path, name + '_%.2f' % self.score)
455 | 
--------------------------------------------------------------------------------
/stream.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | 
3 | from fuel.datasets import TextFile
4 | from fuel.schemes import ConstantScheme
5 | from fuel.streams import DataStream
6 | from fuel.transformers import (
7 |     Merge, Batch, Filter, Padding, SortMapping, Unpack, Mapping)
8 | 
9 | from six.moves import cPickle
10 | 
11 | 
12 | def _ensure_special_tokens(vocab, bos_idx=0, eos_idx=0, unk_idx=1):
13 |     """Ensures special tokens exist in the dictionary."""
14 | 
15 |     # remove tokens if they exist in some other index
16 |     tokens_to_remove = [k for k, v in vocab.items()
17 |                         if v in [bos_idx, eos_idx, unk_idx]]
18 |     for token in tokens_to_remove:
19 |         vocab.pop(token)
20 |     # put corresponding item
21 |     vocab['<S>'] = bos_idx
22 |     vocab['</S>'] = eos_idx
23 |     vocab['<UNK>'] = unk_idx
24 |     return vocab
25 | 
26 | def _ensure_unk(vocab, unk_idx=1):
27 |     """Ensures the unknown-word token exists in the dictionary."""
28 | 
29 |     # remove tokens if they exist in some other index
30 |     tokens_to_remove = [k for k, v in vocab.items()
31 |                         if v in [unk_idx]]
32 |     for token in tokens_to_remove:
33 |         vocab.pop(token)
34 |     # put corresponding item
35 |     vocab['<UNK>'] = unk_idx
36 |     return vocab
37 | 
38 | def _ensure_unk(vocab, unk_idx=1):  # NOTE: duplicate of the definition above
39 |     """Ensures the unknown-word token exists in the dictionary."""
40 | 
41 |     # remove tokens if they exist in some other index
42 |     tokens_to_remove = [k for k, v in vocab.items()
43 |                         if v in [unk_idx]]
44 |     for token in tokens_to_remove:
45 |         vocab.pop(token)
46 |     # put corresponding item
47 |     vocab['<UNK>'] = unk_idx
48 |     return vocab
49 | 
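# A minimal usage sketch (annotation, not part of the original file): how
# `_ensure_special_tokens` pins '<S>', '</S>' and '<UNK>' to fixed indices.
# The toy vocabulary below is hypothetical.
if __name__ == '__main__':
    _toy_vocab = {'the': 0, 'cat': 1, 'sat': 2}
    _ensure_special_tokens(_toy_vocab, bos_idx=0, eos_idx=29999, unk_idx=1)
    # 'the' and 'cat' were evicted from the reserved indices:
    print(_toy_vocab)  # {'sat': 2, '<S>': 0, '</S>': 29999, '<UNK>': 1}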
50 | def _length(sentence_pair):
51 |     """Assumes the target is the second-to-last element in the tuple."""
52 |     return len(sentence_pair[-2])
53 | 
54 | 
55 | class PaddingWithEOS(Padding):
56 |     """Pads a stream with the given end-of-sequence index."""
57 |     def __init__(self, data_stream, eos_idx, **kwargs):
58 |         kwargs['data_stream'] = data_stream
59 |         self.eos_idx = eos_idx
60 |         super(PaddingWithEOS, self).__init__(**kwargs)
61 | 
62 |     def get_data_from_batch(self, request=None):
63 |         if request is not None:
64 |             raise ValueError
65 |         data = list(next(self.child_epoch_iterator))
66 |         data_with_masks = []
67 |         for i, (source, source_data) in enumerate(
68 |                 zip(self.data_stream.sources, data)):
69 |             if source not in self.mask_sources:
70 |                 data_with_masks.append(source_data)
71 |                 continue
72 | 
73 |             shapes = [numpy.asarray(sample).shape for sample in source_data]
74 |             lengths = [shape[0] for shape in shapes]
75 |             max_sequence_length = max(lengths)
76 |             rest_shape = shapes[0][1:]
77 |             if not all([shape[1:] == rest_shape for shape in shapes]):
78 |                 raise ValueError("All dimensions except length must be equal")
79 |             dtype = numpy.asarray(source_data[0]).dtype
80 | 
81 |             padded_data = numpy.ones(
82 |                 (len(source_data), max_sequence_length) + rest_shape,
83 |                 dtype=dtype) * self.eos_idx[i]
84 |             for j, sample in enumerate(source_data):
85 |                 padded_data[j, :len(sample)] = sample
86 |             data_with_masks.append(padded_data)
87 | 
88 |             mask = numpy.zeros((len(source_data), max_sequence_length),
89 |                                self.mask_dtype)
90 |             for j, sequence_length in enumerate(lengths):
91 |                 mask[j, :sequence_length] = 1
92 |             data_with_masks.append(mask)
93 |         return tuple(data_with_masks)
94 | 
95 | class _oov_to_unk(object):
96 |     """Maps out-of-vocabulary token indices to the unk token index."""
97 |     def __init__(self, src_vocab_size=30000, trg_vocab_size=30000, src_topic_vocab_size=2000, trg_topic_vocab_size=2000,
98 |                  unk_id=1):
99 |         self.src_vocab_size = src_vocab_size
100 |         self.trg_vocab_size = trg_vocab_size
101 |         self.src_topic_vocab_size = src_topic_vocab_size
102 |         self.trg_topic_vocab_size = trg_topic_vocab_size
103 |         self.unk_id = unk_id
104 | 
105 |     def __call__(self, sentence_pair):
106 |         for x in sentence_pair[3]:  # assumes the 5-source topic streams
107 |             if x >= self.trg_topic_vocab_size:
108 |                 print("error!!")
109 |         return ([x if x < self.src_vocab_size else self.unk_id
110 |                  for x in sentence_pair[0]],
111 |                 [x if x < self.trg_vocab_size else self.unk_id
112 |                  for x in sentence_pair[1]],
113 |                 sentence_pair[2], sentence_pair[3], sentence_pair[4])
114 | 
115 | # class _oov_to_unk(object):
116 | #     """Maps out of vocabulary token index to unk token index."""
117 | #     def __init__(self, src_vocab_size=30000, trg_vocab_size=30000,
118 | #                  unk_id=1):
119 | #         self.src_vocab_size = src_vocab_size
120 | #         self.trg_vocab_size = trg_vocab_size
121 | #         self.unk_id = unk_id
122 | #
123 | #     def __call__(self, sentence_pair):
124 | #         return ([x if x < self.src_vocab_size else self.unk_id
125 | #                  for x in sentence_pair[0]],
126 | #                 [x if x < self.trg_vocab_size else self.unk_id
127 | #                  for x in sentence_pair[1]])
128 | 
129 | 
130 | class _too_long(object):
131 |     """Filters sequences longer than the given sequence length."""
132 |     def __init__(self, seq_len=50):
133 |         self.seq_len = seq_len
134 | 
135 |     def __call__(self, sentence_pair):
136 |         return all([len(sentence) <= self.seq_len
137 |                     for sentence in sentence_pair])
138 | 
139 | 
140 | def get_tr_stream(src_vocab, trg_vocab, src_data, trg_data,
141 |                   src_vocab_size=30000, trg_vocab_size=30000, unk_id=1,
142 |                   seq_len=50, batch_size=80, sort_k_batches=12, **kwargs):
143 |     """Prepares the training data stream."""
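    # Usage sketch (annotation, not part of the original file; the file
    # names are hypothetical):
    #
    #     stream = get_tr_stream('vocab.src.pkl', 'vocab.trg.pkl',
    #                            'train.src', 'train.trg',
    #                            batch_size=80, sort_k_batches=12)
    #     for src, src_mask, trg, trg_mask in stream.get_epoch_iterator():
    #         pass  # int batches padded with the EOS index, plus 0/1 masks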
144 | 
145 |     # Load dictionaries and ensure special tokens exist
146 |     src_vocab = _ensure_special_tokens(
147 |         src_vocab if isinstance(src_vocab, dict)
148 |         else cPickle.load(open(src_vocab, 'rb')),
149 |         bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id)
150 |     trg_vocab = _ensure_special_tokens(
151 |         trg_vocab if isinstance(trg_vocab, dict) else
152 |         cPickle.load(open(trg_vocab, 'rb')),
153 |         bos_idx=0, eos_idx=trg_vocab_size - 1, unk_idx=unk_id)
154 | 
155 |     # Get text files from both source and target
156 |     src_dataset = TextFile([src_data], src_vocab, None)
157 |     trg_dataset = TextFile([trg_data], trg_vocab, None)
158 | 
159 |     # Merge them to get a source, target pair
160 |     stream = Merge([src_dataset.get_example_stream(),
161 |                     trg_dataset.get_example_stream()],
162 |                    ('source', 'target'))
163 | 
164 |     # Filter sequences that are too long
165 |     stream = Filter(stream,
166 |                     predicate=_too_long(seq_len=seq_len))
167 | 
168 |     # Replace out of vocabulary tokens with unk token
169 |     stream = Mapping(stream,
170 |                      _oov_to_unk(src_vocab_size=src_vocab_size,
171 |                                  trg_vocab_size=trg_vocab_size,
172 |                                  unk_id=unk_id))
173 | 
174 |     # Build a batched version of stream to read k batches ahead
175 |     stream = Batch(stream,
176 |                    iteration_scheme=ConstantScheme(
177 |                        batch_size*sort_k_batches))
178 | 
179 |     # Sort all samples in the read-ahead batch
180 |     stream = Mapping(stream, SortMapping(_length))
181 | 
182 |     # Convert it into a stream again
183 |     stream = Unpack(stream)
184 | 
185 |     # Construct batches from the stream with specified batch size
186 |     stream = Batch(
187 |         stream, iteration_scheme=ConstantScheme(batch_size))
188 | 
189 |     # Pad sequences that are short
190 |     masked_stream = PaddingWithEOS(
191 |         stream, [src_vocab_size - 1, trg_vocab_size - 1])
192 | 
193 |     return masked_stream
194 | 
195 | def get_tr_stream_with_topicalq(src_vocab, trg_vocab, topical_vocab, src_data, trg_data, topical_data,
196 |                                 src_vocab_size=30000, trg_vocab_size=30000, topical_vocab_size=2000, unk_id=1,
197 |                                 seq_len=50, batch_size=80, sort_k_batches=12, **kwargs):
198 |     """Prepares the training data stream."""
199 | 
200 |     # Load dictionaries and ensure special tokens exist
201 | 
202 |     src_vocab = _ensure_special_tokens(
203 |         src_vocab if isinstance(src_vocab, dict)
204 |         else cPickle.load(open(src_vocab, 'rb')),
205 |         bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id)
206 |     trg_vocab = _ensure_special_tokens(
207 |         trg_vocab if isinstance(trg_vocab, dict) else
208 |         cPickle.load(open(trg_vocab, 'rb')),
209 |         bos_idx=0, eos_idx=trg_vocab_size - 1, unk_idx=unk_id)
210 |     topical_vocab = cPickle.load(open(topical_vocab, 'rb'))  # special tokens are not ensured here
211 | 212 | # Get text files from both source and target 213 | src_dataset = TextFile([src_data], src_vocab, None) 214 | trg_dataset = TextFile([trg_data], trg_vocab, None) 215 | topical_dataset = TextFile([topical_data],topical_vocab,None,None,'10'); 216 | 217 | # Merge them to get a source, target pair 218 | stream = Merge([src_dataset.get_example_stream(), 219 | trg_dataset.get_example_stream(), 220 | topical_dataset.get_example_stream()], 221 | ('source', 'target','source_topical')) 222 | 223 | 224 | # Filter sequences that are too long 225 | stream = Filter(stream, 226 | predicate=_too_long(seq_len=seq_len)) 227 | 228 | # Replace out of vocabulary tokens with unk token 229 | # The topical part are not contained of it, check~ 230 | stream = Mapping(stream, 231 | _oov_to_unk(src_vocab_size=src_vocab_size, 232 | trg_vocab_size=trg_vocab_size, 233 | topical_vocab_size=topical_vocab_size, 234 | unk_id=unk_id)) 235 | 236 | # Build a batched version of stream to read k batches ahead 237 | stream = Batch(stream, 238 | iteration_scheme=ConstantScheme( 239 | batch_size*sort_k_batches)) 240 | 241 | # Sort all samples in the read-ahead batch 242 | stream = Mapping(stream, SortMapping(_length)) 243 | 244 | # Convert it into a stream again 245 | stream = Unpack(stream) 246 | 247 | # Construct batches from the stream with specified batch size 248 | stream = Batch( 249 | stream, iteration_scheme=ConstantScheme(batch_size)) 250 | 251 | # Pad sequences that are short 252 | masked_stream = PaddingWithEOS( 253 | stream, [src_vocab_size - 1,trg_vocab_size - 1, topical_vocab_size - 1]) 254 | 255 | return masked_stream 256 | 257 | def get_tr_stream_with_topic_target(src_vocab, trg_vocab,topic_vocab_input,topic_vocab_output, src_data, trg_data,topical_data, 258 | src_vocab_size=30000, trg_vocab_size=30000,trg_topic_vocab_size=2000,source_topic_vocab_size=2000, unk_id=1, 259 | seq_len=50, batch_size=80, sort_k_batches=12, **kwargs): 260 | """Prepares the training data stream.""" 261 | 262 | # Load dictionaries and ensure special tokens exist 263 | 264 | src_vocab = _ensure_special_tokens( 265 | src_vocab if isinstance(src_vocab, dict) 266 | else cPickle.load(open(src_vocab, 'rb')), 267 | bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id) 268 | trg_vocab = _ensure_special_tokens( 269 | trg_vocab if isinstance(trg_vocab, dict) else 270 | cPickle.load(open(trg_vocab, 'rb')), 271 | bos_idx=0, eos_idx=trg_vocab_size - 1, unk_idx=unk_id) 272 | topic_vocab_input=cPickle.load(open(topic_vocab_input,'rb')); 273 | topic_vocab_output=cPickle.load(open(topic_vocab_output, 'rb'));#already has and in it 274 | topic_binary_vocab={}; 275 | for k,v in topic_vocab_output.items(): 276 | if k=='': 277 | topic_binary_vocab[k]=0; 278 | else: 279 | topic_binary_vocab[k]=1; 280 | 281 | 282 | # Get text files from both source and target 283 | src_dataset = TextFile([src_data], src_vocab, None) 284 | trg_dataset = TextFile([trg_data], trg_vocab, None) 285 | src_topic_input=TextFile([topical_data],topic_vocab_input,None,None,'rt') 286 | trg_topic_dataset = TextFile([trg_data],topic_vocab_output,None); 287 | trg_topic_binary_dataset= TextFile([trg_data],topic_binary_vocab,None); 288 | 289 | # Merge them to get a source, target pair 290 | stream = Merge([src_dataset.get_example_stream(), 291 | trg_dataset.get_example_stream(), 292 | src_topic_input.get_example_stream(), 293 | trg_topic_dataset.get_example_stream(), 294 | trg_topic_binary_dataset.get_example_stream()], 295 | ('source', 
'target','source_topical','target_topic','target_binary_topic')) 296 | 297 | 298 | # Filter sequences that are too long 299 | stream = Filter(stream, 300 | predicate=_too_long(seq_len=seq_len)) 301 | 302 | # Replace out of vocabulary tokens with unk token 303 | # The topical part are not contained of it, check~ 304 | stream = Mapping(stream, 305 | _oov_to_unk(src_vocab_size=src_vocab_size, 306 | trg_vocab_size=trg_vocab_size, 307 | src_topic_vocab_size=source_topic_vocab_size, 308 | trg_topic_vocab_size=trg_topic_vocab_size, 309 | unk_id=unk_id)) 310 | 311 | # Build a batched version of stream to read k batches ahead 312 | stream = Batch(stream, 313 | iteration_scheme=ConstantScheme( 314 | batch_size*sort_k_batches)) 315 | 316 | # Sort all samples in the read-ahead batch 317 | stream = Mapping(stream, SortMapping(_length)) 318 | 319 | # Convert it into a stream again 320 | stream = Unpack(stream) 321 | 322 | # Construct batches from the stream with specified batch size 323 | stream = Batch( 324 | stream, iteration_scheme=ConstantScheme(batch_size)) 325 | 326 | # Pad sequences that are short 327 | masked_stream = PaddingWithEOS( 328 | stream, [src_vocab_size - 1,trg_vocab_size - 1, source_topic_vocab_size-1,trg_topic_vocab_size - 1,trg_topic_vocab_size-1]) 329 | 330 | return masked_stream 331 | 332 | 333 | def get_dev_tr_stream_with_topic_target(val_set_source=None,val_set_target=None, src_vocab=None,trg_vocab=None, src_vocab_size=30000,trg_vocab_size=30000, 334 | trg_topic_vocab_size=2000,source_topic_vocab_size=2000, 335 | topical_dev_set=None,topic_vocab_input=None,topic_vocab_output=None,topical_vocab_size=2000, 336 | unk_id=1, **kwargs): 337 | """Prepares the training data stream.""" 338 | 339 | dev_stream = None 340 | if val_set_source is not None and src_vocab is not None: 341 | src_vocab = _ensure_special_tokens( 342 | src_vocab if isinstance(src_vocab, dict) 343 | else cPickle.load(open(src_vocab, 'rb')), 344 | bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id) 345 | trg_vocab = _ensure_special_tokens( 346 | trg_vocab if isinstance(trg_vocab, dict) else 347 | cPickle.load(open(trg_vocab, 'rb')), 348 | bos_idx=0, eos_idx=trg_vocab_size - 1, unk_idx=unk_id) 349 | topic_vocab_input=cPickle.load(open(topic_vocab_input,'rb')); 350 | topic_vocab_output=cPickle.load(open(topic_vocab_output, 'rb'));#already has and in it 351 | topic_binary_vocab={}; 352 | for k,v in topic_vocab_output.items(): 353 | if k=='': 354 | topic_binary_vocab[k]=0; 355 | else: 356 | topic_binary_vocab[k]=1; 357 | # Get text files from both source and target 358 | src_dataset = TextFile([val_set_source], src_vocab, None) 359 | trg_dataset = TextFile([val_set_target], trg_vocab, None) 360 | src_topic_input=TextFile([topical_dev_set],topic_vocab_input,None,None,'rt') 361 | trg_topic_dataset = TextFile([val_set_target],topic_vocab_output,None); 362 | trg_topic_binary_dataset= TextFile([val_set_target],topic_binary_vocab,None); 363 | 364 | # Merge them to get a source, target pair 365 | dev_stream = Merge([src_dataset.get_example_stream(), 366 | trg_dataset.get_example_stream(), 367 | src_topic_input.get_example_stream(), 368 | trg_topic_dataset.get_example_stream(), 369 | trg_topic_binary_dataset.get_example_stream()], 370 | ('source', 'target','source_topical','target_topic','target_binary_topic')) 371 | stream = Batch( 372 | dev_stream, iteration_scheme=ConstantScheme(1)) 373 | masked_stream = PaddingWithEOS( 374 | stream, [src_vocab_size - 1,trg_vocab_size - 1, source_topic_vocab_size-1,trg_topic_vocab_size - 
1,trg_topic_vocab_size-1]) 375 | 376 | return masked_stream 377 | 378 | def get_dev_stream_with_topicalq(val_set_source=None, src_vocab=None, src_vocab_size=30000,topical_dev_set=None,topic_vocab_input=None, 379 | unk_id=1, **kwargs): 380 | """Setup development set stream if necessary.""" 381 | dev_stream = None 382 | if val_set_source is not None and src_vocab is not None: 383 | src_vocab = _ensure_special_tokens( 384 | src_vocab if isinstance(src_vocab, dict) else 385 | cPickle.load(open(src_vocab, 'rb')), 386 | bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id) 387 | print val_set_source, type(src_vocab) 388 | topical_vocab =cPickle.load(open(topic_vocab_input, 'rb'));#not ensure special token. 389 | topical_dataset = TextFile([topical_dev_set],topical_vocab,None,None,'rt'); 390 | dev_dataset = TextFile([val_set_source], src_vocab, None) 391 | #dev_stream = DataStream(dev_dataset) 392 | # Merge them to get a source, target pair 393 | dev_stream = Merge([dev_dataset.get_example_stream(), 394 | topical_dataset.get_example_stream()], 395 | ('source','source_topical')) 396 | return dev_stream 397 | 398 | 399 | 400 | def get_dev_stream(val_set_source=None, src_vocab=None, src_vocab_size=30000, 401 | unk_id=1, **kwargs): 402 | """Setup development set stream if necessary.""" 403 | dev_stream = None 404 | if val_set_source is not None and src_vocab is not None: 405 | src_vocab = _ensure_special_tokens( 406 | src_vocab if isinstance(src_vocab, dict) else 407 | cPickle.load(open(src_vocab, 'rb')), 408 | bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id) 409 | print val_set_source, type(src_vocab) 410 | dev_dataset = TextFile([val_set_source], src_vocab, None) 411 | dev_stream = DataStream(dev_dataset) 412 | return dev_stream 413 | 414 | def get_tr_stream_unsorted(src_vocab, trg_vocab, src_data, trg_data, 415 | src_vocab_size=30000, trg_vocab_size=30000, unk_id=1, 416 | seq_len=50, batch_size=80, sort_k_batches=12, **kwargs): 417 | """Prepares the training data stream.""" 418 | 419 | # Load dictionaries and ensure special tokens exist 420 | src_vocab = _ensure_special_tokens( 421 | src_vocab if isinstance(src_vocab, dict) 422 | else cPickle.load(open(src_vocab, 'rb')), 423 | bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id) 424 | trg_vocab = _ensure_special_tokens( 425 | trg_vocab if isinstance(trg_vocab, dict) else 426 | cPickle.load(open(trg_vocab, 'rb')), 427 | bos_idx=0, eos_idx=trg_vocab_size - 1, unk_idx=unk_id) 428 | 429 | # Get text files from both source and target 430 | src_dataset = TextFile([src_data], src_vocab, None) 431 | trg_dataset = TextFile([trg_data], trg_vocab, None) 432 | 433 | # Merge them to get a source, target pair 434 | stream = Merge([src_dataset.get_example_stream(), 435 | trg_dataset.get_example_stream()], 436 | ('source', 'target')) 437 | 438 | # Filter sequences that are too long 439 | stream = Filter(stream, 440 | predicate=_too_long(seq_len=seq_len)) 441 | 442 | # Replace out of vocabulary tokens with unk token 443 | stream = Mapping(stream, 444 | _oov_to_unk(src_vocab_size=src_vocab_size, 445 | trg_vocab_size=trg_vocab_size, 446 | unk_id=unk_id)) 447 | 448 | # Build a batched version of stream to read k batches ahead 449 | stream = Batch(stream, 450 | iteration_scheme=ConstantScheme(1)) 451 | 452 | # Pad sequences that are short 453 | masked_stream = PaddingWithEOS( 454 | stream, [src_vocab_size - 1, trg_vocab_size - 1]) 455 | 456 | return masked_stream 457 | -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import logging 4 | import time 5 | import numpy 6 | import os 7 | import cPickle 8 | import string 9 | 10 | from collections import Counter 11 | from theano import tensor, function,shared 12 | from toolz import merge 13 | from progressbar import ProgressBar 14 | 15 | from blocks.algorithms import (GradientDescent, StepClipping, 16 | AdaDelta, AdaGrad, Scale, CompositeRule) 17 | from blocks.extensions import FinishAfter, Printing, Timing 18 | from blocks.extensions.monitoring import TrainingDataMonitoring 19 | from blocks.filter import VariableFilter 20 | from blocks.graph import ComputationGraph, apply_noise, apply_dropout 21 | from blocks.initialization import IsotropicGaussian, Orthogonal, Constant 22 | from blocks.main_loop import MainLoop 23 | from blocks.model import Model 24 | from search_decoder_with_extra_class import BeamSearch 25 | from blocks.select import Selector 26 | 27 | from checkpoint import CheckpointNMT, LoadNMT 28 | from model import BidirectionalEncoder, Decoder, topicalq_transformer 29 | from sampling import BleuValidator, Sampler, SamplingBase, pplValidation 30 | from stream import (get_tr_stream, get_dev_stream, get_tr_stream_with_topic_target,get_dev_stream_with_topicalq, 31 | get_tr_stream_unsorted, _ensure_special_tokens) 32 | from SimplePrinting import SimplePrinting 33 | from learning_rate_halver import (LearningRateHalver, 34 | LearningRateDoubler, 35 | OldModelRemover) 36 | from afterprocess import afterprocesser 37 | 38 | try: 39 | from blocks.extras.extensions.plot import Plot 40 | BOKEH_AVAILABLE = True 41 | except ImportError: 42 | BOKEH_AVAILABLE = False 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | 47 | def main(mode, config, use_bokeh=False): 48 | 49 | # Construct model 50 | logger.info('Building RNN encoder-decoder') 51 | encoder = BidirectionalEncoder( 52 | config['src_vocab_size'], config['enc_embed'], config['enc_nhids']) 53 | topical_transformer=topicalq_transformer(config['source_topic_vocab_size'],config['topical_embedding_dim'], config['enc_nhids'],config['topical_word_num'],config['batch_size']); 54 | decoder = Decoder(vocab_size=config['trg_vocab_size'], 55 | topicWord_size=config['trg_topic_vocab_size'], 56 | embedding_dim=config['dec_embed'], 57 | topical_dim=config['topical_embedding_dim'], 58 | state_dim=config['dec_nhids'], 59 | representation_dim=config['enc_nhids'] * 2, 60 | match_function=config['match_function'], 61 | use_doubly_stochastic=config['use_doubly_stochastic'], 62 | lambda_ds=config['lambda_ds'], 63 | use_local_attention=config['use_local_attention'], 64 | window_size=config['window_size'], 65 | use_step_decay_cost=config['use_step_decay_cost'], 66 | use_concentration_cost=config['use_concentration_cost'], 67 | lambda_ct=config['lambda_ct'], 68 | use_stablilizer=config['use_stablilizer'], 69 | lambda_st=config['lambda_st']) 70 | # here attended dim (representation_dim) of decoder is 2*enc_nhinds 71 | # because the context given by the encoder is a bidirectional context 72 | 73 | if mode == "train": 74 | 75 | # Create Theano variables 76 | logger.info('Creating theano variables') 77 | source_sentence = tensor.lmatrix('source') 78 | source_sentence_mask = tensor.matrix('source_mask') 79 | target_sentence = tensor.lmatrix('target') 80 | target_sentence_mask = tensor.matrix('target_mask') 81 | target_topic_sentence=tensor.lmatrix('target_topic'); 82 | 
target_topic_binary_sentence=tensor.lmatrix('target_binary_topic'); 83 | #target_topic_sentence_mask=tensor.lmatrix('target_topic_mask'); 84 | sampling_input = tensor.lmatrix('input') 85 | source_topical_word=tensor.lmatrix('source_topical') 86 | source_topical_mask=tensor.matrix('source_topical_mask') 87 | 88 | topic_embedding=topical_transformer.apply(source_topical_word); 89 | 90 | 91 | # Get training and development set streams 92 | tr_stream = get_tr_stream_with_topic_target(**config) 93 | #dev_stream = get_dev_tr_stream_with_topic_target(**config) 94 | 95 | # Get cost of the model 96 | representations = encoder.apply(source_sentence, source_sentence_mask) 97 | tw_representation=topical_transformer.look_up.apply(source_topical_word.T); 98 | content_embedding=representations[0,:,(representations.shape[2]/2):]; 99 | cost = decoder.cost(representations, 100 | source_sentence_mask, 101 | tw_representation, 102 | source_topical_mask, 103 | target_sentence, 104 | target_sentence_mask, 105 | target_topic_sentence, 106 | target_topic_binary_sentence, 107 | topic_embedding,content_embedding) 108 | 109 | logger.info('Creating computational graph') 110 | perplexity = tensor.exp(cost) 111 | perplexity.name = 'perplexity' 112 | 113 | cg = ComputationGraph(cost) 114 | costs_computer = function([target_sentence, 115 | target_sentence_mask, 116 | source_sentence, 117 | source_sentence_mask,source_topical_word,target_topic_sentence,target_topic_binary_sentence], (perplexity),on_unused_input='ignore') 118 | 119 | # Initialize model 120 | logger.info('Initializing model') 121 | encoder.weights_init = decoder.weights_init = IsotropicGaussian( 122 | config['weight_scale']) 123 | encoder.biases_init = decoder.biases_init = Constant(0) 124 | encoder.push_initialization_config() 125 | decoder.push_initialization_config() 126 | encoder.bidir.prototype.weights_init = Orthogonal() 127 | decoder.transition.weights_init = Orthogonal() 128 | encoder.initialize() 129 | decoder.initialize() 130 | 131 | topical_transformer.weights_init=IsotropicGaussian( 132 | config['weight_scale']); 133 | topical_transformer.biases_init=Constant(0); 134 | topical_transformer.push_allocation_config();#don't know whether the initialize is for 135 | topical_transformer.look_up.weights_init=Orthogonal(); 136 | topical_transformer.transformer.weights_init=Orthogonal(); 137 | topical_transformer.initialize(); 138 | word_topical_embedding=cPickle.load(open(config['topical_embeddings'], 'rb')); 139 | np_word_topical_embedding=numpy.array(word_topical_embedding,dtype='float32'); 140 | topical_transformer.look_up.W.set_value(np_word_topical_embedding); 141 | topical_transformer.look_up.W.tag.role=[]; 142 | 143 | 144 | # apply dropout for regularization 145 | if config['dropout'] < 1.0: 146 | # dropout is applied to the output of maxout in ghog 147 | logger.info('Applying dropout') 148 | dropout_inputs = [x for x in cg.intermediary_variables 149 | if x.name == 'maxout_apply_output'] 150 | cg = apply_dropout(cg, dropout_inputs, config['dropout']) 151 | 152 | # Apply weight noise for regularization 153 | if config['weight_noise_ff'] > 0.0: 154 | logger.info('Applying weight noise to ff layers') 155 | enc_params = Selector(encoder.lookup).get_params().values() 156 | enc_params += Selector(encoder.fwd_fork).get_params().values() 157 | enc_params += Selector(encoder.back_fork).get_params().values() 158 | dec_params = Selector( 159 | decoder.sequence_generator.readout).get_params().values() 160 | dec_params += Selector( 161 | 
decoder.sequence_generator.fork).get_params().values() 162 | dec_params += Selector(decoder.state_init).get_params().values() 163 | cg = apply_noise( 164 | cg, enc_params+dec_params, config['weight_noise_ff']) 165 | 166 | 167 | # Print shapes 168 | shapes = [param.get_value().shape for param in cg.parameters] 169 | logger.info("Parameter shapes: ") 170 | for shape, count in Counter(shapes).most_common(): 171 | logger.info(' {:15}: {}'.format(shape, count)) 172 | logger.info("Total number of parameters: {}".format(len(shapes))) 173 | 174 | # Print parameter names 175 | enc_dec_param_dict = merge(Selector(encoder).get_parameters(), 176 | Selector(decoder).get_parameters()) 177 | logger.info("Parameter names: ") 178 | for name, value in enc_dec_param_dict.items(): 179 | logger.info(' {:15}: {}'.format(value.get_value().shape, name)) 180 | logger.info("Total number of parameters: {}" 181 | .format(len(enc_dec_param_dict))) 182 | 183 | 184 | # Set up training model 185 | logger.info("Building model") 186 | training_model = Model(cost) 187 | 188 | # Set extensions 189 | logger.info("Initializing extensions") 190 | extensions = [ 191 | FinishAfter(after_n_batches=config['finish_after']), 192 | TrainingDataMonitoring([perplexity], after_batch=True), 193 | CheckpointNMT(config['saveto'], 194 | config['model_name'], 195 | every_n_batches=config['save_freq']) 196 | ] 197 | 198 | # # Set up beam search and sampling computation graphs if necessary 199 | # if config['hook_samples'] >= 1 or config['bleu_script'] is not None: 200 | # logger.info("Building sampling model") 201 | # sampling_representation = encoder.apply( 202 | # sampling_input, tensor.ones(sampling_input.shape)) 203 | # generated = decoder.generate( 204 | # sampling_input, sampling_representation) 205 | # search_model = Model(generated) 206 | # _, samples = VariableFilter( 207 | # bricks=[decoder.sequence_generator], name="outputs")( 208 | # ComputationGraph(generated[1])) 209 | # 210 | # # Add sampling 211 | # if config['hook_samples'] >= 1: 212 | # logger.info("Building sampler") 213 | # extensions.append( 214 | # Sampler(model=search_model, data_stream=tr_stream, 215 | # model_name=config['model_name'], 216 | # hook_samples=config['hook_samples'], 217 | # every_n_batches=config['sampling_freq'], 218 | # src_vocab_size=config['src_vocab_size'])) 219 | # 220 | # # Add early stopping based on bleu 221 | # if False: 222 | # logger.info("Building bleu validator") 223 | # extensions.append( 224 | # BleuValidator(sampling_input, samples=samples, config=config, 225 | # model=search_model, data_stream=dev_stream, 226 | # normalize=config['normalized_bleu'], 227 | # every_n_batches=config['bleu_val_freq'], 228 | # n_best=3, 229 | # track_n_models=6)) 230 | # 231 | # logger.info("Building perplexity validator") 232 | # extensions.append( 233 | # pplValidation( config=config, 234 | # model=costs_computer, data_stream=dev_stream, 235 | # model_name=config['model_name'], 236 | # every_n_batches=config['sampling_freq'])) 237 | 238 | 239 | # Plot cost in bokeh if necessary 240 | if use_bokeh and BOKEH_AVAILABLE: 241 | extensions.append( 242 | Plot('Cs-En', channels=[['decoder_cost_cost']], 243 | after_batch=True)) 244 | 245 | # Reload model if necessary 246 | if config['reload']: 247 | extensions.append(LoadNMT(config['saveto'])) 248 | 249 | initial_learning_rate = config['initial_learning_rate'] 250 | log_path = os.path.join(config['saveto'], 'log') 251 | if config['reload'] and os.path.exists(log_path): 252 | with open(log_path, 'rb') as source: 253 
| log = cPickle.load(source) 254 | last = max(log.keys()) - 1 255 | if 'learning_rate' in log[last]: 256 | initial_learning_rate = log[last]['learning_rate'] 257 | 258 | # Set up training algorithm 259 | logger.info("Initializing training algorithm") 260 | algorithm = GradientDescent( 261 | cost=cost, parameters=cg.parameters, 262 | step_rule=CompositeRule([Scale(initial_learning_rate), 263 | StepClipping(config['step_clipping']), 264 | eval(config['step_rule'])()]), 265 | on_unused_sources='ignore') 266 | 267 | _learning_rate = algorithm.step_rule.components[0].learning_rate 268 | if config['learning_rate_decay']: 269 | extensions.append( 270 | LearningRateHalver(record_name='validation_cost', 271 | comparator=lambda x, y: x > y, 272 | learning_rate=_learning_rate, 273 | patience_default=3)) 274 | else: 275 | extensions.append(OldModelRemover(saveto=config['saveto'])) 276 | 277 | if config['learning_rate_grow']: 278 | extensions.append( 279 | LearningRateDoubler(record_name='validation_cost', 280 | comparator=lambda x, y: x < y, 281 | learning_rate=_learning_rate, 282 | patience_default=3)) 283 | 284 | extensions.append( 285 | SimplePrinting(config['model_name'], after_batch=True)) 286 | 287 | # Initialize main loop 288 | logger.info("Initializing main loop") 289 | main_loop = MainLoop( 290 | model=training_model, 291 | algorithm=algorithm, 292 | data_stream=tr_stream, 293 | extensions=extensions 294 | ) 295 | 296 | # Train! 297 | main_loop.run() 298 | 299 | elif mode == 'translate': 300 | 301 | logger.info('Creating theano variables') 302 | sampling_input = tensor.lmatrix('source') 303 | source_topical_word=tensor.lmatrix('source_topical') 304 | tw_vocab_overlap=tensor.lmatrix('tw_vocab_overlap') 305 | tw_vocab_overlap_matrix=cPickle.load(open(config['tw_vocab_overlap'], 'rb')); 306 | tw_vocab_overlap_matrix=numpy.array(tw_vocab_overlap_matrix,dtype='int32'); 307 | #tw_vocab_overlap=shared(tw_vocab_overlap_matrix); 308 | 309 | topic_embedding=topical_transformer.apply(source_topical_word); 310 | 311 | sutils = SamplingBase() 312 | unk_idx = config['unk_id'] 313 | src_eos_idx = config['src_vocab_size'] - 1 314 | trg_eos_idx = config['trg_vocab_size'] - 1 315 | trg_vocab = _ensure_special_tokens( 316 | cPickle.load(open(config['trg_vocab'], 'rb')), bos_idx=0, 317 | eos_idx=trg_eos_idx, unk_idx=unk_idx) 318 | trg_ivocab = {v: k for k, v in trg_vocab.items()} 319 | 320 | logger.info("Building sampling model") 321 | sampling_representation = encoder.apply( 322 | sampling_input, tensor.ones(sampling_input.shape)) 323 | topic_embedding=topical_transformer.apply(source_topical_word); 324 | tw_representation=topical_transformer.look_up.apply(source_topical_word.T); 325 | content_embedding=sampling_representation[0,:,(sampling_representation.shape[2]/2):]; 326 | generated = decoder.generate(sampling_input,sampling_representation, tw_representation,topical_embedding=topic_embedding,content_embedding=content_embedding); 327 | 328 | _, samples = VariableFilter( 329 | bricks=[decoder.sequence_generator], name="outputs")( 330 | ComputationGraph(generated[1])) # generated[1] is next_outputs 331 | beam_search = BeamSearch(samples=samples) 332 | 333 | logger.info("Loading the model..") 334 | model = Model(generated) 335 | #loader = LoadNMT(config['saveto']) 336 | loader = LoadNMT(config['validation_load']); 337 | loader.set_model_parameters(model, loader.load_parameters_default()) 338 | 339 | logger.info("Started translation: ") 340 | test_stream = get_dev_stream_with_topicalq(**config) 341 | ts = 
test_stream.get_epoch_iterator() 342 | rts = open(config['val_set_source']).readlines() 343 | ftrans_original = open(config['val_output_orig'], 'w') 344 | saved_weights = [] 345 | total_cost = 0.0 346 | 347 | pbar = ProgressBar(max_value=len(rts)).start() 348 | for i, (line, line_raw) in enumerate(zip(ts, rts)): 349 | trans_in = line_raw.split() 350 | seq = sutils._oov_to_unk( 351 | line[0], config['src_vocab_size'], unk_idx) 352 | seq1=line[1]; 353 | input_topical=numpy.tile(seq1,(config['beam_size'],1)) 354 | input_ = numpy.tile(seq, (config['beam_size'], 1)) 355 | 356 | # draw sample, checking to ensure we don't get an empty string back 357 | trans, costs, attendeds, weights = \ 358 | beam_search.search( 359 | input_values={sampling_input: input_,source_topical_word:input_topical,tw_vocab_overlap:tw_vocab_overlap_matrix}, 360 | tw_vocab_overlap=tw_vocab_overlap_matrix, 361 | max_length=3*len(seq), eol_symbol=trg_eos_idx, 362 | ignore_first_eol=True) 363 | 364 | # normalize costs according to the sequence lengths 365 | if config['normalized_bleu']: 366 | lengths = numpy.array([len(s) for s in trans]) 367 | costs = costs / lengths 368 | 369 | best = numpy.argsort(costs)[0] 370 | try: 371 | total_cost += costs[best] 372 | trans_out = trans[best] 373 | weight = weights[best][:, :len(trans_in)] 374 | trans_out = sutils._idx_to_word(trans_out, trg_ivocab) 375 | except ValueError: 376 | logger.info( 377 | "Can NOT find a translation for line: {}".format(i+1)) 378 | trans_out = '' 379 | 380 | saved_weights.append(weight) 381 | print(' '.join(trans_out), file=ftrans_original) 382 | pbar.update(i + 1) 383 | 384 | pbar.finish() 385 | logger.info("Total cost of the test: {}".format(total_cost)) 386 | cPickle.dump(saved_weights, open(config['attention_weights'], 'wb')) 387 | ftrans_original.close() 388 | # ap = afterprocesser(config) 389 | # ap.main() 390 | 391 | elif mode == 'score': 392 | logger.info('Creating theano variables') 393 | source_sentence = tensor.lmatrix('source') 394 | source_sentence_mask = tensor.matrix('source_mask') 395 | target_sentence = tensor.lmatrix('target') 396 | target_sentence_mask = tensor.matrix('target_mask') 397 | target_topic_sentence=tensor.lmatrix('target_topic'); 398 | target_topic_binary_sentence=tensor.lmatrix('target_binary_topic'); 399 | source_topical_word=tensor.lmatrix('source_topical') 400 | 401 | topic_embedding=topical_transformer.apply(source_topical_word); 402 | # Get cost of the model 403 | representations = encoder.apply(source_sentence, source_sentence_mask) 404 | costs = decoder.cost(representations, 405 | source_sentence_mask, 406 | target_sentence, 407 | target_sentence_mask, 408 | target_topic_sentence, 409 | target_topic_binary_sentence, 410 | topic_embedding) 411 | 412 | config['batch_size'] = 1 413 | config['sort_k_batches'] = 1 414 | # Get test set stream 415 | test_stream = get_tr_stream_with_topic_target(**config) 416 | 417 | logger.info("Building sampling model") 418 | 419 | 420 | logger.info("Loading the model..") 421 | model = Model(costs) 422 | loader = LoadNMT(config['validation_load']) 423 | loader.set_model_parameters(model, loader.load_parameters_default()) 424 | 425 | costs_computer = function([target_sentence, 426 | target_sentence_mask, 427 | source_sentence, 428 | source_sentence_mask,source_topical_word,target_topic_sentence,target_topic_binary_sentence], (costs),on_unused_input='ignore') 429 | 430 | iterator = test_stream.get_epoch_iterator() 431 | 432 | scores = [] 433 | att_weights = [] 434 | for i, (src, src_mask, trg, 
trg_mask,te,te_mask,tt,tt_mask,tb,tb_mask) in enumerate(iterator): 435 | costs = costs_computer(*[trg, trg_mask, src, src_mask,te,tt,tb]) 436 | cost = costs.sum() 437 | print(i, cost) 438 | scores.append(cost) 439 | 440 | print(sum(scores)/10007); 441 | 442 | -------------------------------------------------------------------------------- /attention.py: -------------------------------------------------------------------------------- 1 | import theano 2 | from theano import tensor 3 | 4 | from blocks.bricks import (Brick, Initializable, Sequence, 5 | Feedforward, Linear, Tanh) 6 | from blocks.bricks.base import lazy, application 7 | from blocks.bricks.parallel import Parallel, Distribute 8 | from blocks.bricks.recurrent import recurrent, BaseRecurrent 9 | from blocks.utils import dict_union, dict_subset, pack 10 | from blocks.bricks.attention import ( 11 | GenericSequenceAttention, AbstractAttentionRecurrent) 12 | from match_functions import SumMatchFunction 13 | 14 | 15 | class SequenceContentAttention(GenericSequenceAttention, Initializable): 16 | """Attention mechanism that looks for relevant content in a sequence. 17 | 18 | This is the attention mechanism used in [BCB]_. The idea in a nutshell: 19 | 20 | 1. The states and the sequence are transformed independently, 21 | 22 | 2. The transformed states are summed with every transformed sequence 23 | element to obtain *match vectors*, 24 | 25 | 3. A match vector is transformed into a single number interpreted as 26 | *energy*, 27 | 28 | 4. Energies are normalized in softmax-like fashion. The resulting 29 | summing to one weights are called *attention weights*, 30 | 31 | 5. Weighted average of the sequence elements with attention weights 32 | is computed. 33 | 34 | In terms of the :class:`AbstractAttention` documentation, the sequence 35 | is the attended. The weighted averages from 5 and the attention 36 | weights from 4 form the set of glimpses produced by this attention 37 | mechanism. 38 | 39 | Parameters 40 | ---------- 41 | state_names : list of str 42 | The names of the network states. 43 | attended_dim : int 44 | The dimension of the sequence elements. 45 | match_dim : int 46 | The dimension of the match vector. 47 | state_transformer : :class:`.Brick` 48 | A prototype for state transformations. If ``None``, 49 | a linear transformation is used. 50 | attended_transformer : :class:`.Feedforward` 51 | The transformation to be applied to the sequence. If ``None`` an 52 | affine transformation is used. 53 | energy_computer : :class:`.Feedforward` 54 | Computes energy from the match vector. If ``None``, an affine 55 | transformations preceeded by :math:`tanh` is used. 56 | 57 | Notes 58 | ----- 59 | See :class:`.Initializable` for initialization parameters. 60 | 61 | .. [BCB] Dzmitry Bahdanau, Kyunghyun Cho and Yoshua Bengio. Neural 62 | Machine Translation by Jointly Learning to Align and Translate. 
63 | 64 | """ 65 | @lazy(allocation=['match_dim']) 66 | def __init__(self, match_dim, 67 | use_local_attention=False, window_size=10, sigma=None, 68 | state_transformer=None, local_state_transformer=None, 69 | local_predictor=None, attended_transformer=None, 70 | energy_computer=None, **kwargs): 71 | super(SequenceContentAttention, self).__init__(**kwargs) 72 | if not state_transformer: 73 | state_transformer = Linear(use_bias=False, name="state_trans") 74 | if not local_state_transformer: 75 | local_state_transformer = Linear(use_bias=False, 76 | name="local_state_trans") 77 | if not local_predictor: 78 | local_predictor = Linear(use_bias=False, name="local_pred") 79 | if sigma is None: 80 | sigma = window_size * 1.0 / 2 81 | self.use_local_attention = use_local_attention 82 | self.sigma = sigma * sigma 83 | self.match_dim = match_dim 84 | self.state_name = self.state_names[0] 85 | 86 | self.state_transformer = state_transformer 87 | self.local_state_transformer = local_state_transformer 88 | self.local_predictor = local_predictor 89 | 90 | if not attended_transformer: 91 | attended_transformer = Linear(name="preprocess") 92 | if not energy_computer: 93 | energy_computer = SumMatchFunction(name="energy_comp") 94 | self.attended_transformer = attended_transformer 95 | self.energy_computer = energy_computer 96 | 97 | self.children = [self.state_transformer, self.local_state_transformer, 98 | self.local_predictor, self.attended_transformer, 99 | energy_computer] 100 | 101 | def _push_allocation_config(self): 102 | self.state_dim = self.state_dims[0] 103 | 104 | self.state_transformer.input_dim = self.state_dim 105 | self.state_transformer.output_dim = self.match_dim 106 | 107 | self.local_state_transformer.input_dim = self.state_dim 108 | self.local_state_transformer.output_dim = self.match_dim 109 | 110 | self.local_predictor.input_dim = self.state_dim 111 | self.local_predictor.output_dim = 1 112 | 113 | self.attended_transformer.input_dim = self.attended_dim 114 | self.attended_transformer.output_dim = self.match_dim 115 | 116 | self.energy_computer.input_dim = self.match_dim 117 | self.energy_computer.output_dim = 1 118 | 119 | @application 120 | def compute_energies(self, attended, preprocessed_attended, states): 121 | if not preprocessed_attended: 122 | preprocessed_attended = self.preprocess(attended) 123 | _states = states[self.state_name] 124 | transformed_states = self.state_transformer.apply(_states) 125 | # Broadcasting of transformed states should be done automatically 126 | # match_vectors = sum(transformed_states.values(), 127 | # preprocessed_attended) 128 | # energies = self.energy_computer.apply(match_vectors).reshape( 129 | # match_vectors.shape[:-1], ndim=match_vectors.ndim - 1) 130 | energies = self.energy_computer.apply(transformed_states, 131 | preprocessed_attended) 132 | return energies 133 | 134 | @application 135 | def get_local_predition(self, states, attended, attended_mask): 136 | _states = states[self.state_name] 137 | # local_states: [batch, features] 138 | local_states = self.local_state_transformer.apply(_states) 139 | # local_prediction is reshaped to [batch] 140 | local_prediction = self.local_predictor.apply( 141 | tensor.tanh(local_states)).reshape( 142 | local_states.shape[:-1], ndim=local_states.ndim - 1) 143 | local_prediction = tensor.nnet.sigmoid(local_prediction) 144 | # attended_mask is [time, batch] 145 | _attended_mask = tensor.sum(attended_mask, axis=0) 146 | return _attended_mask * local_prediction 147 | 148 | @application 149 | def 
adjust_weights(self, attended_mask, weights, local_prediction): 150 | # weights: [time, batch] 151 | # local_prediction: [batch] 152 | # locations: [time, batch] 153 | locations = tensor.arange( 154 | attended_mask.shape[0]).repeat( 155 | attended_mask.shape[1]).reshape( 156 | attended_mask.shape).astype( 157 | theano.config.floatX) 158 | # diff: [time, batch] 159 | diff = locations - local_prediction 160 | # gauss: [time, batch] 161 | gauss = tensor.pow(diff, 2) / (2 * self.sigma) 162 | gauss = tensor.exp(-gauss) 163 | weights = weights * gauss 164 | return weights 165 | 166 | @application(outputs=['weighted_averages', 'weights']) 167 | def take_glimpses(self, attended, preprocessed_attended=None, 168 | attended_mask=None, **states): 169 | r"""Compute attention weights and produce glimpses. 170 | 171 | Parameters 172 | ---------- 173 | attended : :class:`~tensor.TensorVariable` 174 | The sequence, time is the 1-st dimension. 175 | preprocessed_attended : :class:`~tensor.TensorVariable` 176 | The preprocessed sequence. If ``None``, is computed by calling 177 | :meth:`preprocess`. 178 | attended_mask : :class:`~tensor.TensorVariable` 179 | A 0/1 mask specifying available data. 0 means that the 180 | corresponding sequence element is fake. 181 | \*\*states 182 | The states of the network. 183 | 184 | Returns 185 | ------- 186 | weighted_averages : :class:`~theano.Variable` 187 | Linear combinations of sequence elements with the attention 188 | weights. 189 | weights : :class:`~theano.Variable` 190 | The attention weights. The first dimension is batch, the second 191 | is time. 192 | 193 | """ 194 | energies = self.compute_energies( 195 | attended, preprocessed_attended, states) 196 | # weights has dimensions: [time (src), batch] 197 | weights = self.compute_weights(energies, attended_mask) 198 | if self.use_local_attention: 199 | # local_pred should have dimension: [batch], 200 | # the predicted position for each batch 201 | local_pred = self.get_local_predition( 202 | states, attended, attended_mask) 203 | weights = self.adjust_weights(attended_mask, weights, local_pred) 204 | weighted_averages = self.compute_weighted_averages(weights, attended) 205 | return weighted_averages, weights.T 206 | 207 | @take_glimpses.property('inputs') 208 | def take_glimpses_inputs(self): 209 | return (['attended', 'preprocessed_attended', 'attended_mask'] + 210 | self.state_names) 211 | 212 | @application(outputs=['weighted_averages', 'weights']) 213 | def initial_glimpses(self, batch_size, attended): 214 | return [tensor.zeros((batch_size, self.attended_dim)), 215 | tensor.zeros((batch_size, attended.shape[0]))] 216 | 217 | @application(inputs=['attended'], outputs=['preprocessed_attended']) 218 | def preprocess(self, attended): 219 | """Preprocess the sequence for computing attention weights. 220 | 221 | Parameters 222 | ---------- 223 | attended : :class:`~tensor.TensorVariable` 224 | The attended sequence, time is the 1-st dimension. 225 | 226 | """ 227 | return self.attended_transformer.apply(attended) 228 | 229 | def get_dim(self, name): 230 | if name in ['weighted_averages']: 231 | return self.attended_dim 232 | if name in ['weights']: 233 | return 0 234 | return super(SequenceContentAttention, self).get_dim(name) 235 | 236 | 237 | class AttentionRecurrent(AbstractAttentionRecurrent, Initializable): 238 | """Combines an attention mechanism and a recurrent transition. 239 | 240 | This brick equips a recurrent transition with an attention mechanism. 
class AttentionRecurrent(AbstractAttentionRecurrent, Initializable):
    """Combines an attention mechanism and a recurrent transition.

    This brick equips a recurrent transition with an attention mechanism.
    In order to do this, two more contexts are added: one to be attended and
    a mask for it. It is also possible to use the contexts of the given
    recurrent transition for these purposes and not add any new ones,
    see the `add_contexts` parameter.

    At the beginning of each step the attention mechanism produces glimpses;
    these glimpses together with the current states are used to compute the
    next state and finish the transition. In some cases glimpses from the
    previous steps are also necessary for the attention mechanism, e.g.
    in order to focus on an area close to the one from the previous step.
    This is also supported: such glimpses become states of the new
    transition.

    To let the user control the way glimpses are used, this brick also
    takes a "distribute" brick as a parameter that distributes the
    information from glimpses across the sequential inputs of the wrapped
    recurrent transition.

    Parameters
    ----------
    transition : :class:`.BaseRecurrent`
        The recurrent transition.
    attention : :class:`.Brick`
        The attention mechanism.
    distribute : :class:`.Brick`, optional
        Distributes the information from glimpses across the input
        sequences of the transition. By default a :class:`.Distribute` is
        used, and those inputs containing the "mask" substring in their
        name are not affected.
    add_contexts : bool, optional
        If ``True``, new contexts for the attended and the attended mask
        are added to this transition, otherwise existing contexts of the
        wrapped transition are used. ``True`` by default.
    attended_name : str
        The name of the attended context. If ``None``, "attended"
        or the first context of the recurrent transition is used
        depending on the value of the `add_contexts` flag.
    attended_mask_name : str
        The name of the mask for the attended context. If ``None``,
        "attended_mask" or the second context of the recurrent transition
        is used depending on the value of the `add_contexts` flag.

    Notes
    -----
    See :class:`.Initializable` for initialization parameters.

    Wrapping your recurrent brick with this class makes all the
    states mandatory. If you feel this is a limitation for you, try
    to make it better! This restriction does not apply to sequences
    and contexts: those keep being as optional as they were for
    your brick.

    Those coming to Blocks from Groundhog might recognize that this is
    a `RecurrentLayerWithSearch`, but on steroids :)

    """
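    # Typical wiring (a sketch; `transition` and `attention` as constructed
    # in the example between the two classes above):
    #
    #     att_trans = AttentionRecurrent(
    #         transition, attention, name="att_trans")
    #
    # With the default add_contexts=True this adds the "attended" and
    # "attended_mask" contexts to the wrapped transition.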
    def __init__(self, transition, attention, distribute=None,
                 add_contexts=True,
                 attended_name=None, attended_mask_name=None,
                 **kwargs):
        super(AttentionRecurrent, self).__init__(**kwargs)
        self._sequence_names = list(transition.apply.sequences)
        self._state_names = list(transition.apply.states)
        self._context_names = list(transition.apply.contexts)
        if add_contexts:
            if not attended_name:
                attended_name = 'attended'
            if not attended_mask_name:
                attended_mask_name = 'attended_mask'
            self._context_names += [attended_name, attended_mask_name]
        else:
            attended_name = self._context_names[0]
            attended_mask_name = self._context_names[1]
        if not distribute:
            normal_inputs = [name for name in self._sequence_names
                             if 'mask' not in name]
            distribute = Distribute(normal_inputs,
                                    attention.take_glimpses.outputs[0])

        self.transition = transition
        self.attention = attention
        self.distribute = distribute
        self.add_contexts = add_contexts
        self.attended_name = attended_name
        self.attended_mask_name = attended_mask_name

        self.preprocessed_attended_name = "preprocessed_" + self.attended_name

        self._glimpse_names = self.attention.take_glimpses.outputs
        # We need to determine which glimpses are fed back.
        # Currently we extract it from the `take_glimpses` signature.
        self.previous_glimpses_needed = [
            name for name in self._glimpse_names
            if name in self.attention.take_glimpses.inputs]

        self.children = [self.transition, self.attention, self.distribute]

    def _push_allocation_config(self):
        self.attention.state_dims = self.transition.get_dims(
            self.attention.state_names)
        self.attention.attended_dim = self.get_dim(self.attended_name)
        self.distribute.source_dim = self.attention.get_dim(
            self.distribute.source_name)
        self.distribute.target_dims = self.transition.get_dims(
            self.distribute.target_names)

    @application
    def take_glimpses(self, **kwargs):
        r"""Compute glimpses with the attention mechanism.

        A thin wrapper over `self.attention.take_glimpses`: takes care
        of choosing and renaming the necessary arguments.

        Parameters
        ----------
        \*\*kwargs
            Must contain the attended, previous step states and glimpses.
            Can optionally contain the attended mask and the preprocessed
            attended.

        Returns
        -------
        glimpses : list of :class:`~tensor.TensorVariable`
            Current step glimpses.

        """
        states = dict_subset(kwargs, self._state_names, pop=True)
        glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)
        glimpses_needed = dict_subset(glimpses, self.previous_glimpses_needed)
        result = self.attention.take_glimpses(
            kwargs.pop(self.attended_name),
            kwargs.pop(self.preprocessed_attended_name, None),
            kwargs.pop(self.attended_mask_name, None),
            **dict_union(states, glimpses_needed))
        # At this point kwargs may contain additional items, e.g. the
        # contexts of the wrapped transition
        # (AttentionRecurrent.transition.apply.contexts).
        return result

    @take_glimpses.property('outputs')
    def take_glimpses_outputs(self):
        return self._glimpse_names

    @application
    def compute_states(self, **kwargs):
        r"""Compute current states when glimpses have already been computed.

        Combines an application of the `distribute` brick that alters the
        sequential inputs of the wrapped transition and an application of
        the wrapped transition. All unknown keyword arguments go to
        the wrapped transition.

        Parameters
        ----------
        \*\*kwargs
            Should contain everything that `self.transition` needs
            and in addition the current glimpses.

        Returns
        -------
        current_states : list of :class:`~tensor.TensorVariable`
            Current states computed by `self.transition`.

        """
        # make sure we are not popping the mask
        normal_inputs = [name for name in self._sequence_names
                         if 'mask' not in name]
        sequences = dict_subset(kwargs, normal_inputs, pop=True)
        glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)
        if self.add_contexts:
            kwargs.pop(self.attended_name)
            # attended_mask_name can be optional
            kwargs.pop(self.attended_mask_name, None)

        sequences.update(self.distribute.apply(
            as_dict=True, **dict_subset(dict_union(sequences, glimpses),
                                        self.distribute.apply.inputs)))
        current_states = self.transition.apply(
            iterate=False, as_list=True,
            **dict_union(sequences, kwargs))
        return current_states

    @compute_states.property('outputs')
    def compute_states_outputs(self):
        return self._state_names

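    # Per-step data flow (a sketch of what `do_apply` below does): it first
    # calls `take_glimpses` with the previous states (and, where the
    # attention needs them, the previous glimpses), then feeds the fresh
    # glimpses through `distribute` into the transition via `compute_states`:
    #
    #     glimpses_t = take_glimpses(states_{t-1}, glimpses_{t-1}, attended)
    #     states_t   = compute_states(inputs_t, states_{t-1}, glimpses_t)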
    @recurrent
    def do_apply(self, **kwargs):
        r"""Process a sequence attending the attended context at every step.

        In addition to the original sequence this method also requires
        its preprocessed version, the one computed by the `preprocess`
        method of the attention mechanism. Unknown keyword arguments
        are passed to the wrapped transition.

        Parameters
        ----------
        \*\*kwargs
            Should contain current inputs, previous step states, contexts,
            the preprocessed attended context and previous step glimpses.

        Returns
        -------
        outputs : list of :class:`~tensor.TensorVariable`
            The current step states and glimpses.

        """
        attended = kwargs[self.attended_name]
        preprocessed_attended = kwargs.pop(self.preprocessed_attended_name)
        attended_mask = kwargs.get(self.attended_mask_name)
        sequences = dict_subset(kwargs, self._sequence_names, pop=True,
                                must_have=False)
        states = dict_subset(kwargs, self._state_names, pop=True)
        glimpses = dict_subset(kwargs, self._glimpse_names, pop=True)

        current_glimpses = self.take_glimpses(
            as_dict=True,
            **dict_union(
                states, glimpses,
                {self.attended_name: attended,
                 self.attended_mask_name: attended_mask,
                 self.preprocessed_attended_name: preprocessed_attended}))
        current_states = self.compute_states(
            as_list=True,
            **dict_union(sequences, states, current_glimpses, kwargs))
        return current_states + list(current_glimpses.values())

    @do_apply.property('sequences')
    def do_apply_sequences(self):
        return self._sequence_names

    @do_apply.property('contexts')
    def do_apply_contexts(self):
        return self._context_names + [self.preprocessed_attended_name]

    @do_apply.property('states')
    def do_apply_states(self):
        return self._state_names + self._glimpse_names

    @do_apply.property('outputs')
    def do_apply_outputs(self):
        return self._state_names + self._glimpse_names

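    # Top-level usage (a sketch; the exact argument names depend on the
    # wrapped transition): given `att_trans = AttentionRecurrent(...)` and
    # 3-D source/target variables, something like
    #
    #     states_and_glimpses = att_trans.apply(
    #         inputs=target_embeddings, mask=target_mask,
    #         attended=encoded_source, attended_mask=source_mask)
    #
    # `apply` below preprocesses the attended sequence once and then iterates
    # `do_apply` over time thanks to the @recurrent wrapper.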
    @application
    def apply(self, **kwargs):
        """Process a sequence, attending the attended context at every step.

        Preprocesses the attended context and runs :meth:`do_apply`. See
        the :meth:`do_apply` documentation for further information.

        """
        preprocessed_attended = self.attention.preprocess(
            kwargs[self.attended_name])
        return self.do_apply(
            **dict_union(kwargs,
                         {self.preprocessed_attended_name:
                          preprocessed_attended}))

    @apply.delegate
    def apply_delegate(self):
        # TODO: Nice interface for this trick?
        return self.do_apply.__get__(self, None)

    @apply.property('contexts')
    def apply_contexts(self):
        return self._context_names

    @application
    def initial_states(self, batch_size, **kwargs):
        return (pack(self.transition.initial_states(
            batch_size, **kwargs)) +
            pack(self.attention.initial_glimpses(
                batch_size, kwargs[self.attended_name])))

    @initial_states.property('outputs')
    def initial_states_outputs(self):
        return self.do_apply.states

    def get_dim(self, name):
        if name in self._glimpse_names:
            return self.attention.get_dim(name)
        if name == self.preprocessed_attended_name:
            (original_name,) = self.attention.preprocess.outputs
            return self.attention.get_dim(original_name)
        if self.add_contexts:
            if name == self.attended_name:
                return self.attention.get_dim(
                    self.attention.take_glimpses.inputs[0])
            if name == self.attended_mask_name:
                return 0
        return self.transition.get_dim(name)
--------------------------------------------------------------------------------