├── preprocessor
    ├── __init__.py
    └── reader.py
├── .gitignore
├── requirements.txt
├── download_data_and_model.sh
├── train.py
├── evaluation.py
├── README.md
├── cluster_analysis.py
├── lib.py
└── tpr_rnn_graph.py

/preprocessor/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | logs/
3 | pre_trained/
4 | tasks/
5 | plots/
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.11.0
2 | seaborn==0.9.0
3 | sklearn==0.0
--------------------------------------------------------------------------------
/download_data_and_model.sh:
--------------------------------------------------------------------------------
1 | wget --no-check-certificate -r 'https://docs.google.com/uc?export=download&id=1MclIZs597dnaopRra646-KLQ-7qZnHuw' -O data_and_model.tar.gz
2 | tar xvf data_and_model.tar.gz
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Trains a TPR-RNN model from scratch.
2 | #
3 | # Usage:
4 | # python3 train.py [task_id] [log_subfolder]
5 | # - task_id: 0-20, where 0 is all tasks combined and 1-20 are the 20 bAbI tasks
6 | # - log_subfolder: the folder inside the logs directory
7 | #
8 | # Example:
9 | # python3 train.py 0 default
10 | #
11 | # Result:
12 | # - Starts training on the all-tasks objective.
13 | # - New log folder in logs/default
14 | #   > Will contain tensorflow event files, terminal output, best model checkpoint, etc.
15 | #
16 | #
17 |
18 | from tpr_rnn_graph import *
19 |
20 | train(steps=1000000, bs=32, terminal_log_every=250, validate_every=500)
--------------------------------------------------------------------------------
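train.py reads its two positional arguments inside tpr_rnn_graph.py (sys.argv[1] selects the task, sys.argv[2] the log subfolder), so a programmatic run only needs to set sys.argv before the import. A minimal sketch of a shorter single-task run; the task id 3, the folder name single_task_demo, and the step counts are illustrative choices, not values from the repository:

```python
# Hypothetical driver script placed next to train.py in the repository root.
import sys

# tpr_rnn_graph.py parses these on import: task_id=3, logs go to logs/single_task_demo/3/
sys.argv = [sys.argv[0], "3", "single_task_demo"]

from tpr_rnn_graph import *  # builds the graph, loads tasks/en-valid-10k, opens a session

# a much shorter run than train.py's default one million steps
train(steps=20000, bs=32, terminal_log_every=250, validate_every=500)
```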
/evaluation.py:
--------------------------------------------------------------------------------
1 | # Loads a pre-trained model, evaluates it, and prints a random sample from the test data.
2 | #
3 | # Usage:
4 | # python3 evaluation.py
5 | #
6 | #
7 |
8 | import tensorflow
9 | import sys
10 | import os
11 |
12 | sys.argv = [sys.argv[0], 0, "tmp"] # force parameters for easy use
13 |
14 | from tpr_rnn_graph import * # this import will create the graph
15 |
16 | make_eval_tensors()
17 | print()
18 |
19 | # print test and validation set performance
20 | print("evaluate a random all-tasks model on the validation and test data:")
21 | full_eval() # evaluate the randomly initialized model
22 | print()
23 |
24 | print("restoring a trained all-tasks model ...")
25 | saver.restore(sess, "pre_trained/model.ckpt")
26 | print()
27 |
28 | print("evaluate the trained all-tasks model on the validation and test data:")
29 | full_eval() # evaluate the trained model
30 | print()
31 |
32 | print("evaluate the trained all-tasks model on individual tasks:")
33 | eval_every_task()
34 | print()
35 |
36 | # show an example
37 | print("printing a random sample from the all-tasks data:")
38 | show_random_sample()
39 |
40 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Learning to Reason with Third-Order Tensor Products
2 | This repository contains the code accompanying the paper [*Learning to Reason with Third-Order Tensor Products*](https://papers.nips.cc/paper/8203-learning-to-reason-with-third-order-tensor-products), published at NeurIPS 2018. It contains our implementation of the Tensor Product Representation Recurrent Neural Network (TPR-RNN) applied to the bAbI tasks, on which it achieves state-of-the-art results. A download script for a pre-trained model is provided.
3 |
4 | # Requirements
5 | - Python 3
6 | - Tensorflow==1.11.0
7 | - Seaborn==0.9.0
8 | - Sklearn==0.0
9 |
10 | Make sure to upgrade pip before installing the requirements.
11 | ```bash
12 | pip3 install -r requirements.txt
13 | sh download_data_and_model.sh
14 | ```
15 |
16 | # Usage
17 | Run the pre-trained model.
18 | ```bash
19 | python3 evaluation.py
20 | ```
21 |
22 | Train from scratch on the all-tasks objective. (See train.py for the command-line arguments.)
23 | ```bash
24 | python3 train.py 0 default
25 | ```
26 |
27 | Generate small, hierarchically clustered cosine-similarity matrices for a random set of sentences using different internal representations.
28 | ```bash
29 | python3 cluster_analysis.py
30 | ```
31 |
32 | # Results
33 | With this code we achieved the following error (rounded to two decimal places) when training on all bAbI tasks simultaneously. The appendix of the paper provides a per-task breakdown.
34 |
35 | |task|run-0|run-1|run-2|run-3|run-4|run-5|run-6|run-7|best|mean|std|
36 | |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
37 | |all|1.50|1.69|1.13|1.04|0.78|0.96|1.20|2.40|0.78|1.34|0.52|
38 |
39 | # Citation
40 | ```
41 | @inproceedings{schlag2018tprrnn,
42 |   title={Learning to Reason with Third Order Tensor Products},
43 |   author={Schlag, Imanol and Schmidhuber, J{\"u}rgen},
44 |   booktitle={Advances in Neural Information Processing Systems},
45 |   pages={10002--10013},
46 |   year={2018}
47 | }
48 | ```
49 |
--------------------------------------------------------------------------------
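The best/mean/std columns in the Results table above follow directly from the eight run errors; a quick NumPy check (assuming, as the numbers suggest, that std is the sample standard deviation, i.e. ddof=1) reproduces them:

```python
import numpy as np

# the eight all-tasks run errors from the README table
errors = np.array([1.50, 1.69, 1.13, 1.04, 0.78, 0.96, 1.20, 2.40])

print(errors.min())                   # 0.78 -> "best"
print(round(errors.mean(), 2))        # 1.34 -> "mean"
print(round(errors.std(ddof=1), 2))   # 0.52 -> "std" (sample standard deviation)
```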
/cluster_analysis.py:
--------------------------------------------------------------------------------
1 | # Loads a pre-trained model and generates hierarchically clustered cosine similarity
2 | # matrices for specific internal representations (like the entities (e1, e2) or the
3 | # relations (r1, r2, r3)).
4 | #
5 | # You can generate the matrices for a few sentences or for many.
6 | # Simply vary the number of stories below.
7 | #
8 | # Usage:
9 | # python3 cluster_analysis.py
10 | #
11 | #
12 |
13 | import tensorflow
14 | import sys
15 | import os
16 |
17 | sys.argv = [sys.argv[0], 0, "tmp"] # force parameters for easy use
18 |
19 | from tpr_rnn_graph import * # this import will create the graph
20 |
21 | import matplotlib
22 | matplotlib.use('Agg')
23 | from matplotlib import pyplot as plt
24 | from sklearn.metrics.pairwise import cosine_similarity
25 | import seaborn as sns
26 |
27 | print("restoring a trained model ...")
28 | saver.restore(sess, "pre_trained/model.ckpt")
29 | print()
30 | print("evaluate the trained model:")
31 | full_eval() # evaluate the trained model
32 | print()
33 |
34 |
35 | def cos_sim_clustering(item, number_of_stories=1000):
36 |     idxs = np.random.randint(low=1, high=test_epoch_size, size=number_of_stories)
37 |     batch = [raw_test[0][idxs,:,:],
38 |              raw_test[1][idxs],
39 |              raw_test[2][idxs,:],
40 |              raw_test[3][idxs]]
41 |     # r stands for representations
42 |     all_r, all_stories = sess.run([item, story], get_feed_dic(batch))
43 |     all_r = np.reshape(all_r, (-1, all_r.shape[-1]))
44 |     sentences = np.reshape(all_stories, (-1, all_stories.shape[-1]))
45 |     _, indices = np.unique(sentences, axis=0, return_index=True)
46 |     print("{} unique sentences found in {} random stories.".format(len(indices), number_of_stories))
47 |     sentences = sentences[indices]
48 |     r = all_r[indices]
49 |     C = cosine_similarity(r)
50 |     g = sns.clustermap(C, standard_scale=1, figsize=(20,20))
51 |     return g, sentences
52 |
53 |
54 | def plot_small_random_sample(item, name):
55 |     g, sentences = cos_sim_clustering(item, number_of_stories=5)
56 |     g.savefig("small_plot_{}.png".format(name))
57 |     for idx in g.dendrogram_row.reordered_ind:
58 |         print("{:4}: {}".format(idx, translate(sentences[idx])))
59 |
60 | print("generating cosine similarity matrices for several representations of every sentence.")
61 |
62 | print("\nthe randomly selected sentences for e1:")
63 | plot_small_random_sample(e1, "e1")
64 |
65 | print("\nthe randomly selected sentences for e2:")
66 | plot_small_random_sample(e2, "e2")
67 |
68 | print("\nthe randomly selected sentences for r1:")
69 | plot_small_random_sample(r1, "r1")
70 |
71 | print("\nthe randomly selected sentences for r2:")
72 | plot_small_random_sample(r2, "r2")
73 |
74 | print("\nthe randomly selected sentences for r3:")
75 | plot_small_random_sample(r3, "r3")
76 |
77 | print("\nall image files written.")
78 |
--------------------------------------------------------------------------------
/lib.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import logging
4 | import shutil
5 | import sys
6 | import os
7 |
8 | default_dtype = tf.float32
9 |
10 | ###### Helper classes ---------
11 | class LoopCell(tf.contrib.rnn.RNNCell):
12 |     """
13 |     Dummy cell with an empty call/body function that should be overwritten
14 |     to easily implement a tf.while_loop with dynamic sequence lengths.
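
    Typical use (see tpr_rnn_graph.py): construct the cell with the shape of the
    TPR state, overwrite `cell.call` with the actual update function after
    construction, and pass the cell to tf.nn.dynamic_rnn together with a
    sequence_length tensor so that padded sentences are skipped.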
15 | """ 16 | def __init__(self, h_shape, reuse=None, name=None): 17 | super(LoopCell, self).__init__(_reuse=reuse, name=name) 18 | # dummy output to make it work with dynamic_rnn 19 | #dummy_h = tf.zeros((batch_size, entityC_size, roleC_size, entityC_size)) 20 | dummy_h = tf.zeros(h_shape) 21 | 22 | self.out = [dummy_h] 23 | 24 | @property 25 | def state_size(self): 26 | return 0 # not executed but function has to be implemented 27 | 28 | @property 29 | def output_size(self): 30 | return [o.shape[1:] for o in self.out] 31 | 32 | def build(self, inputs_shape): 33 | # init variables here 34 | self.built = True 35 | 36 | def call(self, inputs, state): 37 | # overwritten in the model part 38 | return self.out, state 39 | 40 | 41 | ###### Helper Functions --------- 42 | def zeros_init(dtype=default_dtype): 43 | return tf.zeros_initializer(dtype=dtype) 44 | 45 | def ones_init(dtype=default_dtype): 46 | return tf.ones_initializer(dtype=dtype) 47 | 48 | def uniform_init(limit, dtype=default_dtype): 49 | return tf.random_uniform_initializer(minval=-limit, maxval=limit, dtype=dtype) 50 | 51 | def uniform_glorot_init(in_size, out_size, dtype=default_dtype): 52 | a = np.sqrt(6.0 / (in_size + out_size)) 53 | return tf.random_uniform_initializer(minval=-a, maxval=a, dtype=dtype) 54 | 55 | def get_affine_vars(prefix, shape, w_initializer, b_initializer=zeros_init()): 56 | weights = tf.get_variable(prefix + "_w", shape=shape, initializer=w_initializer) 57 | bias = tf.get_variable(prefix + "_b", shape=[shape[-1]], initializer=b_initializer) 58 | return weights, bias 59 | 60 | def make_summary(tag, value): 61 | return tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 62 | 63 | def norm(item, active, scope, axes=[-1]): 64 | """ Layernorm """ 65 | if not active: 66 | return item 67 | mean, var = tf.nn.moments(item, axes=axes, keep_dims=True) 68 | normed = (item - mean) / tf.sqrt(var + 0.0000001) 69 | 70 | with tf.variable_scope(scope): 71 | gain = tf.get_variable("LN_gain", [1], initializer=ones_init()) 72 | bias = tf.get_variable("LN_bias", [1], initializer=zeros_init()) 73 | normed = normed * gain + bias 74 | return normed 75 | 76 | def MLP(inputs, n_networks, equation, input_size, hidden_size, output_size, scope): 77 | """ compute the output of n distinct 3-layer MLPs in series. """ 78 | with tf.variable_scope(scope): 79 | outputs = [] 80 | for idx in range(n_networks): 81 | W1_w, W1_b = get_affine_vars("W1_net"+str(idx), shape=[input_size, hidden_size], w_initializer=uniform_glorot_init(input_size, hidden_size)) 82 | W2_w, W2_b = get_affine_vars("W2_net"+str(idx), shape=[hidden_size, output_size], w_initializer=uniform_glorot_init(hidden_size, output_size)) 83 | 84 | hidden = tf.nn.tanh(tf.einsum(equation, inputs, W1_w) + W1_b) 85 | out = tf.nn.tanh(tf.einsum(equation, hidden, W2_w) + W2_b) 86 | 87 | outputs.append(out) 88 | return outputs 89 | 90 | def get_total_trainable_parameters(): 91 | total_parameters = 0 92 | for variable in tf.trainable_variables(): 93 | # shape is an array of tf.Dimension 94 | shape = variable.get_shape() 95 | variable_parametes = 1 96 | for dim in shape: 97 | variable_parametes *= dim.value 98 | total_parameters += variable_parametes 99 | return total_parameters 100 | 101 | def init_logger(log_folder, file_name="output.log"): 102 | if os.path.exists(log_folder): 103 | print("WARNING: The results directory (%s) already exists. Delete previous results directory [y/N]? 
" % log_folder, end="") 104 | var = input() 105 | if var is "y" or var is "Y": 106 | print("removing directory ...") 107 | shutil.rmtree(log_folder, ignore_errors=True) 108 | else: 109 | print("ERROR: The results directory already exists: %s" % log_folder) 110 | sys.exit(1) 111 | os.makedirs(log_folder) 112 | log_file_path = os.path.join(log_folder, file_name) 113 | 114 | logger = logging.getLogger("my_logger") # unable to use a new file handler with the tensorflow logger 115 | logger.setLevel(logging.DEBUG) 116 | logger.addHandler(logging.FileHandler(log_file_path)) 117 | logger.addHandler(logging.StreamHandler()) 118 | 119 | return logger 120 | -------------------------------------------------------------------------------- /preprocessor/reader.py: -------------------------------------------------------------------------------- 1 | """ 2 | bAbI task reader from https://github.com/siddk/entity-network 3 | adapted for python3 and other needs. 4 | 5 | reader.py 6 | 7 | Core script containing preprocessing logic - reads bAbI Task Story, and returns 8 | vectorized forms of the stories, questions, and answers. 9 | """ 10 | import numpy as np 11 | import os 12 | import pickle 13 | import re 14 | 15 | from functools import reduce 16 | 17 | FORMAT_STR = "qa{}_" 18 | PAD_ID = 0 19 | SPLIT_RE = re.compile('(\W+)?') 20 | 21 | def parse(data_path, task_id, word2id=None, bsz=32, DATA_TYPES=['train', 'valid', 'test'], global_sentence_max=0, use_cache=True): 22 | vectorized_data, story_data, global_story_max = [], [], 0 23 | for data_type in DATA_TYPES: 24 | print("read {} ...".format(data_type)) 25 | cache_path = data_path + "-pik/" + FORMAT_STR.format(task_id) + data_type + ".pik" 26 | if os.path.exists(cache_path) and use_cache: 27 | print("accessing cache_path: ", cache_path) 28 | with open(cache_path, 'rb') as f: 29 | vectorized_data.append(pickle.load(f)) 30 | else: 31 | filenames = list(filter(lambda x: FORMAT_STR.format(task_id) in x and data_type in x, os.listdir(data_path))) 32 | if len(filenames) == 0: 33 | print("filename not found for in listdir for {} and {}".format(task_id, data_type)) 34 | print("skipping ... 
") 35 | continue 36 | stories, sentence_max, story_max, word2id = parse_stories(os.path.join(data_path, filenames[0]), word2id) 37 | story_data.append(stories) 38 | global_sentence_max = max(global_sentence_max, sentence_max) 39 | global_story_max = max(global_story_max, story_max) 40 | 41 | if vectorized_data: 42 | return vectorized_data + [vectorized_data[0][4]] 43 | else: 44 | for i, data_type in enumerate(DATA_TYPES): 45 | print("vectorize {} ...".format(data_type)) 46 | cache_path = data_path + "-pik/" + FORMAT_STR.format(task_id) + data_type + ".pik" 47 | S, S_len, Q, A = vectorize_stories(story_data[i], global_sentence_max, global_story_max, word2id, task_id) 48 | n = int((S.shape[0] / bsz) * bsz) 49 | with open(cache_path, 'wb') as f: 50 | pickle.dump((S[:n], S_len[:n], Q[:n], A[:n], word2id), f) 51 | vectorized_data.append((S[:n], S_len[:n], Q[:n], A[:n], word2id)) 52 | return vectorized_data + [word2id] 53 | 54 | def parse_stories(filename, word2id=None): 55 | # Open file, get lines 56 | with open(filename, 'r') as f: 57 | lines = f.readlines() 58 | 59 | # Go through lines, building story sets 60 | print("go through lines") 61 | stories, story = [], [] 62 | for line in lines: 63 | line = line.strip() 64 | nid, line = line.split(' ', 1) 65 | nid = int(nid) 66 | if nid == 1: 67 | story = [] 68 | if '\t' in line: 69 | query, answer, supporting = line.split('\t') 70 | query = tokenize(query) 71 | substory = [x for x in story if x] 72 | stories.append((substory, query, answer.lower())) 73 | story.append('') 74 | else: 75 | sentence = tokenize(line) 76 | story.append(sentence) 77 | 78 | # Build Vocabulary 79 | print("build vocab") 80 | if not word2id: 81 | vocab = set(reduce(lambda x, y: x + y, [q for (_, q, _) in stories])) 82 | print("reduce done!") 83 | for (s, _, _) in stories: 84 | for sentence in s: 85 | vocab.update(sentence) 86 | for (_, _, a) in stories: 87 | vocab.add(a) 88 | id2word = ['PAD_ID'] + list(vocab) 89 | word2id = {w: i for i, w in enumerate(id2word)} 90 | 91 | # Get Maximum Lengths 92 | print("get max lengths") 93 | sentence_max, story_max = 0, 0 94 | for (s, q, _) in stories: 95 | if len(q) > sentence_max: 96 | sentence_max = len(q) 97 | if len(s) > story_max: 98 | story_max = len(s) 99 | for sentence in s: 100 | if len(sentence) > sentence_max: 101 | sentence_max = len(sentence) 102 | 103 | return stories, sentence_max, story_max, word2id 104 | 105 | def vectorize_stories(stories, sentence_max, story_max, word2id, task_id): 106 | # Check Story Max 107 | if task_id == 3: 108 | story_max = min(story_max, 130) 109 | else: 110 | story_max = min(story_max, 70) 111 | 112 | # Allocate Arrays 113 | S = np.zeros([len(stories), story_max, sentence_max], dtype=np.int32) 114 | Q = np.zeros([len(stories), sentence_max], dtype=np.int32) 115 | S_len, A = np.zeros([len(stories)], dtype=np.int32), np.zeros([len(stories)], dtype=np.int32) 116 | 117 | # Fill Arrays 118 | for i, (s, q, a) in enumerate(stories): 119 | # Check S Length => All but Task 3 are limited to 70 sentences 120 | if task_id == 3: 121 | s = s[-130:] 122 | else: 123 | s = s[-70:] 124 | 125 | # Populate story 126 | for j in range(len(s)): 127 | for k in range(len(s[j])): 128 | S[i][j][k] = word2id[s[j][k]] 129 | 130 | # Populate story length 131 | S_len[i] = len(s) 132 | 133 | # Populate Question 134 | for j in range(len(q)): 135 | Q[i][j] = word2id[q[j]] 136 | 137 | # Populate Answer 138 | A[i] = word2id[a] 139 | 140 | return S, S_len, Q, A 141 | 142 | def tokenize(sentence): 143 | """ 144 | Tokenize a string by 
splitting on non-word characters and stripping whitespace. 145 | """ 146 | return [token.strip().lower() for token in re.split(SPLIT_RE, sentence) if token.strip()] -------------------------------------------------------------------------------- /tpr_rnn_graph.py: -------------------------------------------------------------------------------- 1 | # This file loads the data, builds the graph, and provides several other functionality. 2 | # It is made with the use of a jupyter-console/notebook in mind for experimentation. 3 | # The current configuration is the all-task model.# 4 | # Use one of the other files in order to train from scratch or analyse a trained model. 5 | # 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | import pprint 10 | import types 11 | import copy 12 | import time 13 | import sys 14 | import os 15 | 16 | from preprocessor.reader import parse 17 | from lib import * 18 | 19 | 20 | ###### Hyper Parameters ------------------ 21 | c = types.SimpleNamespace() 22 | # user input 23 | c.task_id = int(sys.argv[1]) 24 | c.log_keyword = str(sys.argv[2]) 25 | 26 | # data loading (necessary for task specific symbol_size parameter) 27 | c.data_path = "tasks/en-valid" + "-10k" 28 | raw_train, raw_valid, raw_test, word2id = parse(c.data_path, c.task_id) 29 | id2word = {word2id[k]:k for k in word2id.keys()} 30 | c.vocab_size = len(word2id) 31 | 32 | # model parameters 33 | c.symbol_size = c.vocab_size 34 | c.entity_size = 90 35 | c.hidden_size = 40 36 | c.role_size = 20 37 | c.init_limit = 0.10 38 | c.LN = True 39 | 40 | # optimizer 41 | c.learning_rate = 0.001 42 | c.beta1 = 0.9 43 | c.beta2 = 0.999 44 | c.max_gradient_norm = 5.0 45 | c.do_warm_up = True 46 | c.warm_up_steps = 50 47 | c.warm_up_factor = 0.1 48 | c.do_decay = True 49 | c.decay_thresh = 0.1 50 | c.decay_factor = 0.5 51 | 52 | # other 53 | c.log_folder = "logs/{}/{}/".format(c.log_keyword, str(c.task_id)) 54 | 55 | 56 | ###### Experiment Setup ------------------ 57 | logger = init_logger(c.log_folder) # will not log tensorflow outputs like GPU infos and warnings 58 | log = lambda *x: logger.debug((x[0].replace('{','{{').replace('}','}}') + "{} " * (len(x)-1)).format(*x[1:])) 59 | 60 | 61 | ###### Load Data ------------------ 62 | batch_size = tf.placeholder(tf.int64) # dynamic batch_size 63 | 64 | p = np.random.permutation(len(raw_train[0])) # random permutation of the train_data 65 | train_data = tf.data.Dataset.from_tensor_slices((raw_train[0][p],raw_train[1][p],raw_train[2][p],raw_train[3][p])).cache().repeat().batch(batch_size) 66 | valid_data = tf.data.Dataset.from_tensor_slices((raw_valid[0],raw_valid[1],raw_valid[2],raw_valid[3])).cache().repeat().batch(batch_size) 67 | test_data = tf.data.Dataset.from_tensor_slices((raw_test[0],raw_test[1],raw_test[2],raw_test[3])).cache().repeat().batch(batch_size) 68 | 69 | train_iterator = train_data.make_initializable_iterator() 70 | valid_iterator = valid_data.make_initializable_iterator() 71 | test_iterator = test_data.make_initializable_iterator() 72 | 73 | train_batch = train_iterator.get_next() 74 | valid_batch = valid_iterator.get_next() 75 | test_batch = test_iterator.get_next() 76 | 77 | train_epoch_size = raw_train[0].shape[0] 78 | valid_epoch_size = raw_valid[0].shape[0] 79 | test_epoch_size = raw_test[0].shape[0] 80 | 81 | # some task specific data attributes 82 | max_story_length = np.max(raw_train[1]) 83 | max_sentences = raw_train[0].shape[1] 84 | max_sentence_length = raw_train[0].shape[2] 85 | max_query_length = raw_train[2].shape[1] 86 | 87 | # full 
valid and test data requires too much memory for a single batch 88 | valid_steps = 73 89 | valid_batch_size = valid_epoch_size / 73 # 274 90 | test_steps = 20 91 | test_batch_size = test_epoch_size / 20 # 1000 92 | 93 | 94 | ###### Print Run Config --------- 95 | log("Configuration:") 96 | log(pprint.pformat(c.__dict__)) 97 | log("") 98 | 99 | 100 | ###### Graph Structure --------- 101 | with tf.variable_scope("hyper_params", reuse=None, dtype=tf.float32): 102 | # we have dynamic hyper parameters 103 | _learning_rate = tf.get_variable("learning_rate", shape=[], trainable=False) 104 | _beta1 = tf.get_variable("beta1", shape=[], trainable=False) 105 | _beta2 = tf.get_variable("beta2", shape=[], trainable=False) 106 | 107 | # op to set hyper parameters 108 | hyper_param_init = [ 109 | _learning_rate.assign(c.learning_rate), 110 | _beta1.assign(c.beta1), 111 | _beta2.assign(c.beta2) 112 | ] 113 | 114 | with tf.variable_scope("inputs"): 115 | # story shape: [batch_size, max_sentences, max_sentence_length] 116 | story = tf.placeholder(dtype=tf.int32, shape=[None, None, None], name='story') # [batch_size, sentences, words] 117 | story_length = tf.placeholder(dtype=tf.int32, shape=[None], name='story_length') # [batch_size] 118 | query = tf.placeholder(dtype=tf.int32, shape=[None, None], name='query') # [batch_size, words] 119 | answer = tf.placeholder(dtype=tf.int32, shape=[None], name='answer') # [batch_size] 120 | 121 | _batch_size = tf.shape(story)[0] 122 | sentence_length = tf.shape(story)[2] 123 | 124 | with tf.variable_scope("variables"): # Note, the MLP weights are created in the MLP function. 125 | # initialize the embeddings 126 | word_embedding = tf.get_variable(name="word_embedding", 127 | shape=[c.vocab_size, c.symbol_size], 128 | initializer=uniform_init(c.init_limit)) 129 | position = tf.get_variable(name="story_position_embedding", 130 | shape=[max_sentence_length, c.symbol_size], # [words, symbol_size] 131 | initializer=ones_init()) / max_sentence_length 132 | Z = tf.get_variable(name="output_embedding", 133 | shape=[c.entity_size, c.vocab_size], 134 | initializer=uniform_glorot_init(c.entity_size, c.vocab_size)) # output projection Z, final transformation 135 | 136 | # initial state of the TPR 137 | TPR_init = tf.zeros((_batch_size, c.entity_size, c.role_size, c.entity_size)) 138 | 139 | # we use a cell and dynamic_rnn instead of a tf.while_loop due to its dynamic squence_length capability 140 | loopCell = LoopCell((_batch_size, c.entity_size, c.role_size, c.entity_size)) 141 | 142 | with tf.variable_scope("model"): 143 | with tf.variable_scope("update_module"): 144 | # we had problems with embedding_lookup and the optimizer implementation. 145 | # [batch_size, sentences, words, embedding_size] 146 | #sentence_emb = tf.nn.embedding_lookup(params=word_embedding, ids=story) 147 | sentence_emb = tf.einsum('bswv,ve->bswe', tf.one_hot(story, depth=c.vocab_size), word_embedding) 148 | # [batch_size, words, embedding_size] 149 | #query_emb = tf.nn.embedding_lookup(params=word_embedding, ids=query) 150 | query_emb = tf.einsum('bwv,ve->bwe', tf.one_hot(query, depth=c.vocab_size), word_embedding) 151 | 152 | # summing over the words of a sentence into sentence representations 153 | # [batch_size, sentences, embedding_size] 154 | sentence_sum = tf.einsum('bswe,we->bse', sentence_emb, position) # eq. 
5 for a normal sentence 155 | # [batch_size, embedding_size] 156 | query_sum = tf.einsum('bwe,we->be', query_emb, position) # eq.5 for the question sentence 157 | 158 | # Five MLPs that extract the entity and relation representations 159 | e1, e2 = MLP(sentence_sum, n_networks=2, equation='bse,er->bsr', input_size=c.symbol_size, 160 | hidden_size=c.hidden_size, output_size=c.entity_size, scope="story_entity") 161 | r1, r2, r3 = MLP(sentence_sum, n_networks=3, equation='bse,er->bsr', input_size=c.symbol_size, 162 | hidden_size=c.hidden_size, output_size=c.role_size, scope="story_roles") 163 | 164 | # compute part of the tensor update outside the loop for efficency 165 | # (b)atch, (s)tory, (r)ole, (f)iller 166 | partial_add_W = tf.einsum('bsr,bsf->bsrf', r1, e2) # part of eq.10 167 | partial_add_B = tf.einsum('bsr,bsf->bsrf', r3, e1) # part of eq.14 168 | 169 | # perform loop operation using dynamic_rnn (so we can exploit dynamic sequence lengths) 170 | def body(inputs, TPR): 171 | e1, r1, partial_add_W, e2, r2, partial_add_B, r3 = inputs 172 | # e1 and e2 are [batch_size, entity_size] 173 | # r1 and r2 are [batch_size, role_size] 174 | # TPR is [batch_size, entity_size, role_size, entity_size] 175 | 176 | w_hat = (tf.einsum('be,br,berf->bf', e1, r1, TPR)) # eq. 9 177 | partial_remove_W = tf.einsum('br,bf->brf', r1, w_hat) # part of eq.10 178 | 179 | m_hat = (tf.einsum('be,br,berf->bf', e1, r2, TPR)) # eq. 11 180 | partial_remove_M = tf.einsum('br,bf->brf', r2, m_hat) # part of eq.12 181 | 182 | partial_add_M = tf.einsum('br,bf->brf', r2, w_hat) # part of eq.12 183 | 184 | b_hat = (tf.einsum('be,br,berf->bf', e2, r3, TPR)) # eq. 13 185 | partial_remove_B = tf.einsum('br,bf->brf', r3, b_hat) # part of eq.14 186 | 187 | # tensor product obeys a distributive law with the direct sum operation 188 | # this allows for a more efficient implementation 189 | # we first add the ops before we go from order 2 to order 3 190 | write_op = partial_add_W - partial_remove_W 191 | move_op = partial_add_M - partial_remove_M 192 | backlink_op = partial_add_B - partial_remove_B 193 | delta_F = tf.einsum('be,brf->berf', e1, write_op + move_op) \ 194 | + tf.einsum('be,brf->berf', e2, backlink_op) # eq. 6 195 | 196 | # direct sum of the old state with the new ones. Removes old associations and adds new ones. 197 | TPR += delta_F # eq. 4 198 | 199 | return [delta_F], TPR 200 | 201 | # we set the body of our empty loop cell and make use of 202 | # the dynamic sequence_length capability of dynamic_rnn 203 | loopCell.call = body 204 | inputs = (e1, r1, partial_add_W, e2, r2, partial_add_B, r3) # all input tensors are already batch major 205 | _, TPR = tf.nn.dynamic_rnn(loopCell, inputs, initial_state=TPR_init, sequence_length=story_length) 206 | 207 | with tf.variable_scope("inference_module"): 208 | # for the question we use the same sentence encoding but different MLPs (these are used in the inference module) 209 | q_e1, q_e2 = MLP(query_sum, n_networks=2, equation='be,er->br', input_size=c.symbol_size, 210 | hidden_size=c.hidden_size, output_size=c.entity_size, scope="query_entity") 211 | q_r1, q_r2, q_r3 = MLP(query_sum, n_networks=3, equation='be,er->br', input_size=c.symbol_size, 212 | hidden_size=c.hidden_size, output_size=c.role_size, scope="query_roles") 213 | 214 | ## compute question answer ((b)atch, (e)ntity, (r)ole, (f)iller, (q)ueries) 215 | # simple association 216 | one_step_raw = tf.einsum('be,br,berf->bf', q_e1, q_r1, TPR) 217 | i_1 = norm(one_step_raw, active=c.LN, scope="one_step") # eq. 
17 218 | 219 | # transitive inference 220 | two_step_raw = tf.einsum('be,br,berf->bf', i_1, q_r2, TPR) 221 | i_2 = norm(two_step_raw, active=c.LN, scope="two_step") # eq. 18 222 | 223 | # third step 224 | three_step_raw = tf.einsum('be,br,berf->bf', i_2, q_r3, TPR) 225 | i_3 = norm(three_step_raw, active=c.LN, scope="three_step") # eq. 19 226 | 227 | # it is possible to do some gating but it doesn't give any improvement. 228 | step_sum = i_1 + i_2 + i_3 229 | logits = tf.einsum('bf,fl->bl', step_sum, Z, name="logits") # projection into symbol space, eq. 20 230 | 231 | with tf.variable_scope("outputs"): 232 | costs = tf.losses.sparse_softmax_cross_entropy(labels=answer, logits=logits, reduction='none') # [batch_size, queries] 233 | cost = tf.reduce_mean(costs) # scalar 234 | 235 | predictions = tf.argmax(logits, axis=-1, output_type=tf.int32, name="predictions") # [batch_size] 236 | 237 | correct = tf.cast(tf.equal(answer, predictions), tf.float32, name="correct") 238 | accuracy = tf.reduce_mean(correct, name="accuracy") # [] 239 | 240 | log("all trainable tensorflow variables:") 241 | log(pprint.pformat(tf.trainable_variables())) 242 | log("total number of trainable parameters: ", get_total_trainable_parameters()) 243 | log("") 244 | 245 | ###### Optimizer --------- 246 | with tf.variable_scope("optimizer"): 247 | optimizer = tf.contrib.opt.NadamOptimizer(_learning_rate, beta1=_beta1, beta2=_beta2) 248 | trainable_vars = tf.trainable_variables() 249 | gradients = tf.gradients(cost, trainable_vars) 250 | global_norm = tf.global_norm(gradients) # compute global norm 251 | clipped_global_norm = tf.where(tf.is_nan(global_norm), c.max_gradient_norm, global_norm) # clip NaN gradients to max norm 252 | clipped_gradients, gradient_norm = tf.clip_by_global_norm(t_list=gradients, 253 | clip_norm=c.max_gradient_norm, 254 | use_norm=clipped_global_norm) # uses this norm instead of computing it 255 | train_op = optimizer.apply_gradients(zip(clipped_gradients, trainable_vars)) 256 | 257 | 258 | ###### Session and Other--------- 259 | with tf.variable_scope("session_and_other"): 260 | saver = tf.train.Saver() 261 | merged_summaries = tf.summary.merge_all() 262 | writer = tf.summary.FileWriter(c.log_folder, graph=tf.get_default_graph()) 263 | config = tf.ConfigProto() 264 | config.gpu_options.allow_growth=True 265 | #config.operation_timeout_in_ms=60000 266 | sess = tf.Session(config=config) 267 | sess.run(tf.global_variables_initializer()) 268 | sess.run(hyper_param_init) 269 | 270 | # it seems like we have to initialize the iterators here but they are reinitialized when train() or eval() is called 271 | sess.run(train_iterator.initializer, {batch_size: 1}) 272 | sess.run(test_iterator.initializer, {batch_size: 1}) 273 | sess.run(valid_iterator.initializer, {batch_size: 1}) 274 | 275 | 276 | ###### Helper Functions --------- 277 | _total_steps = 0 278 | _start_time = 0 279 | _best_valid_acc = 0.0 280 | _decay_done = False 281 | 282 | # train and performance functions 283 | def get_feed_dic(batch): 284 | feed_dic = { 285 | story: batch[0], 286 | story_length: batch[1], 287 | query: batch[2], 288 | answer: batch[3], 289 | } 290 | return feed_dic 291 | 292 | def train(steps=1000000, bs=128, terminal_log_every=50, validate_every=200): 293 | global _total_steps 294 | global _start_time 295 | global _decay_done 296 | # --- 297 | # reinitializing the iterator resets it to the first batch 298 | sess.run(train_iterator.initializer, {batch_size: bs}) 299 | 300 | _start_time = time.time() 301 | _prev_time = 
_start_time 302 | _prev_step = 0 303 | 304 | acc_sum, cost_sum, valid_acc, valid_cost = 0, 0, 0, 999 305 | 306 | for i in range(steps): 307 | # reduce learning rate for warm up period 308 | if i < c.warm_up_steps and c.do_warm_up: 309 | sess.run(_learning_rate.assign(c.learning_rate * c.warm_up_factor)) 310 | 311 | if i > c.warm_up_steps and c.do_warm_up: 312 | sess.run(_learning_rate.assign(c.learning_rate)) 313 | 314 | # decay learning once 315 | if valid_cost < c.decay_thresh and c.do_decay and not _decay_done: 316 | saver.save(sess, os.path.join(c.log_folder, "preReduction.ckpt")) 317 | sess.run(_learning_rate.assign(c.learning_rate * c.decay_factor)) 318 | _decay_done = True # prevent further decays 319 | 320 | # get a batch and run a step 321 | batch = sess.run(train_batch) # batch size is defined by the iterator initialization 322 | feed_dic = get_feed_dic(batch) 323 | query_dic = { 324 | _learning_rate.name: _learning_rate, 325 | train_op.name: train_op, 326 | cost.name: cost, 327 | accuracy.name: accuracy, 328 | gradient_norm.name: gradient_norm 329 | } 330 | result = sess.run(query_dic, feed_dic) 331 | cost_sum += result[cost.name] 332 | acc_sum += result[accuracy.name] 333 | 334 | # if we have NaN values during the warm_up phase we restart training 335 | if np.isnan(result[cost.name]) and i < c.warm_up_steps: 336 | log("NaN cost during warmup, reinitializing.") 337 | _total_steps = 0 338 | i = 0 339 | sess.run(tf.global_variables_initializer()) 340 | sess.run(tf.local_variables_initializer()) 341 | sess.run(hyper_param_init) 342 | continue 343 | 344 | # print and log the current training performance 345 | if (_total_steps % terminal_log_every == 0) or np.isnan(result[cost.name]) or _total_steps < c.warm_up_steps: 346 | # track speed 347 | epochs_seen = (_total_steps * bs) / train_epoch_size 348 | time_done = (time.time() - _prev_time) / 60.0 # in minutes 349 | step_speed = (i - _prev_step) / time_done 350 | stories_speed = step_speed * bs 351 | _prev_time = time.time() 352 | _prev_step = i 353 | 354 | log("{:4}: cost={:6.4f}, accuracy={:7.4f}, norm={:06.3f}, lr={:.4f} (epochs={:.1f}, steps/min={:2.0f}, stories/min={:2.0f})".format( 355 | _total_steps, result[cost.name], result[accuracy.name], result[gradient_norm.name], result[_learning_rate.name], 356 | epochs_seen, step_speed, stories_speed)) 357 | 358 | writer.add_summary(make_summary(tag="train/cost", value=result[cost.name]), _total_steps) 359 | writer.add_summary(make_summary(tag="train/accuracy", value=result[accuracy.name]), _total_steps) 360 | 361 | # print and log the current validation and test performance 362 | if _total_steps % validate_every == 0 and _total_steps != 0: 363 | log("") 364 | log("task: ", c.task_id) 365 | valid_cost, valid_acc = eval(prefix="valid", batch_source=valid_batch, steps=valid_steps, bs=valid_batch_size) 366 | test_cost, test_acc = eval(prefix="test", batch_source=test_batch, steps=test_steps, bs=test_batch_size) 367 | 368 | writer.add_summary(make_summary(tag="valid/cost", value=valid_cost), _total_steps) 369 | writer.add_summary(make_summary(tag="valid/accuracy", value=valid_acc), _total_steps) 370 | writer.add_summary(make_summary(tag="test/cost", value=test_cost), _total_steps) 371 | writer.add_summary(make_summary(tag="test/accuracy", value=test_acc), _total_steps) 372 | 373 | log("Total time passed: {:.1f} min".format((time.time() - _start_time) / 60.0)) 374 | log("") 375 | 376 | if valid_acc >= _best_valid_acc: 377 | saver.save(sess, os.path.join(c.log_folder, "model.ckpt")) 378 | 379 
| _total_steps += 1 380 | 381 | return valid_acc, valid_cost, train_acc, train_cost 382 | 383 | def eval(steps=1, prefix="", batch_source=valid_batch, bs=valid_epoch_size): 384 | acc_sum, cost_sum = 0.0, 0.0 385 | etime = time.time() 386 | 387 | sess.run(valid_iterator.initializer, {batch_size: bs}) 388 | sess.run(test_iterator.initializer, {batch_size: bs}) 389 | 390 | for j in range(steps): 391 | batch = sess.run(batch_source) 392 | feed_dic = get_feed_dic(batch) 393 | 394 | query_dic = { 395 | cost.name: cost, 396 | accuracy.name: accuracy, 397 | } 398 | result = sess.run(query_dic, feed_dic) 399 | cost_sum += result[cost.name] 400 | acc_sum += result[accuracy.name] 401 | 402 | # the following fetch is only to trigger caching and have the cache warning go away 403 | try: 404 | sess.run(batch_source) 405 | except tf.errors.OutOfRangeError: # single-task eval doesn't have repeat() 406 | pass 407 | 408 | n_stories = steps * bs 409 | eval_time = (time.time() - etime) / 60.0 410 | log("{} evaluation: cost={:.4f}, accuracy={:.4f} ({} stories in {:.1f} min)".format( 411 | prefix, cost_sum / steps, acc_sum / steps, n_stories, eval_time)) 412 | return cost_sum / steps, acc_sum / steps 413 | 414 | def full_eval(): 415 | eval(prefix="valid", batch_source=valid_batch, steps=valid_steps, bs=valid_batch_size) 416 | eval(prefix="test", batch_source=test_batch, steps=test_steps, bs=test_batch_size) 417 | 418 | 419 | # print functions 420 | def translate(nparr, id2word=id2word): 421 | assert (type(nparr) == np.ndarray), "You can only translate numpy arrays" 422 | old_shape = nparr.shape 423 | arr = np.reshape(nparr, (-1)) 424 | arr = np.asarray([id2word[x] for x in arr]) 425 | arr = np.reshape(arr, old_shape) 426 | as_string = np.apply_along_axis(lambda x: " ".join(x), axis=-1, arr=arr) 427 | return as_string 428 | 429 | def show_random_sample(): 430 | idx = np.random.randint(2000) 431 | batch = [raw_test[0][np.newaxis,idx,:,:], 432 | raw_test[1][np.newaxis,idx], 433 | raw_test[2][np.newaxis,idx,:], 434 | raw_test[3][np.newaxis,idx]] 435 | feed_dic = get_feed_dic(batch) 436 | 437 | tensors = [story, query, answer, predictions] 438 | query_dic = {t.name:t for t in tensors} 439 | res_dic = sess.run(query_dic, feed_dic) 440 | for t in tensors: 441 | print("{}: ".format(t.name)) 442 | print(pprint.pformat(translate(res_dic[t.name]))) 443 | print() 444 | 445 | 446 | # functions to evaluate on all tasks 447 | eval_valid_test = [] 448 | 449 | def transform_task(task_id): 450 | _train, _valid, _test, old_dic = parse(c.data_path, task_id) 451 | new_train = copy.deepcopy(_train) 452 | new_valid = copy.deepcopy(_valid) 453 | new_test = copy.deepcopy(_test) 454 | 455 | for key in sorted(list(old_dic.keys())): 456 | for i in [0,2,3]: 457 | new_train[i][ _train[i] == old_dic[key]] = word2id[key] 458 | new_valid[i][ _valid[i] == old_dic[key]] = word2id[key] 459 | new_test[i][ _test[i] == old_dic[key]] = word2id[key] 460 | 461 | def pad(A): 462 | out = [] 463 | _maxS = 130 464 | _maxW = 12 465 | out.append(np.zeros((A[0].shape[0], _maxS, _maxW))) 466 | out[0][:, :A[0].shape[1], :A[0].shape[2]] = A[0] 467 | 468 | out.append(A[1]) 469 | 470 | out.append(np.zeros((A[2].shape[0], _maxW))) 471 | out[2][:, :A[2].shape[1]] = A[2] 472 | 473 | out.append(A[3]) 474 | return out 475 | 476 | sets = [] 477 | sets.append(pad(new_train)) 478 | sets.append(pad(new_valid)) 479 | sets.append(pad(new_test)) 480 | sets.append(word2id) 481 | return sets 482 | 483 | def make_eval_tensors(): 484 | global eval_valid_test 485 | eval_valid_test 
= [] 486 | for i in range(1,21): 487 | _, raw_valid, raw_test, _ = transform_task(i) 488 | valid_epoch_size = raw_valid[0].shape[0] 489 | test_epoch_size = raw_test[0].shape[0] 490 | 491 | valid_data = tf.data.Dataset.from_tensor_slices((raw_valid[0],raw_valid[1],raw_valid[2],raw_valid[3])).batch(valid_epoch_size) 492 | test_data = tf.data.Dataset.from_tensor_slices((raw_test[0],raw_test[1],raw_test[2],raw_test[3])).batch(test_epoch_size) 493 | 494 | valid_batch = valid_data.make_one_shot_iterator().get_next() 495 | test_batch = test_data.make_one_shot_iterator().get_next() 496 | 497 | eval_valid_test.append((valid_batch, valid_epoch_size, test_batch, test_epoch_size)) 498 | 499 | def eval_every_task(): 500 | for idx, (valid_batch, valid_epoch_size, test_batch, test_epoch_size) in enumerate(eval_valid_test): 501 | prefix = "valid_task_" + str(idx + 1) 502 | eval(prefix=prefix, batch_source=valid_batch, steps=1, bs=valid_epoch_size) 503 | print() 504 | 505 | for idx, (valid_batch, valid_epoch_size, test_batch, test_epoch_size) in enumerate(eval_valid_test): 506 | prefix = "test_task_" + str(idx + 1) 507 | eval(prefix=prefix, batch_source=test_batch, steps=1, bs=test_epoch_size) 508 | print() 509 | 510 | --------------------------------------------------------------------------------
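For readers studying the update module above: the loop body `body()` in tpr_rnn_graph.py implements eqs. 4-14 purely with einsum contractions, so the per-sentence TPR update can be reproduced in plain NumPy. A toy-sized sketch; the dimensions and random vectors below are illustrative only (the model uses entity_size=90 and role_size=20 and MLP-derived entities/relations):

```python
import numpy as np

# Toy sizes: batch b, entity/filler dimension E, role dimension R.
b, E, R = 2, 5, 3
rng = np.random.default_rng(0)
e1, e2 = rng.normal(size=(b, E)), rng.normal(size=(b, E))
r1, r2, r3 = (rng.normal(size=(b, R)) for _ in range(3))
TPR = np.zeros((b, E, R, E))  # order-3 tensor product representation, per batch element

# One per-sentence update (in the model this runs once for every sentence of a story).
# eq. 9/10: retrieve the current filler for (e1, r1) and replace it with e2
w_hat = np.einsum('be,br,berf->bf', e1, r1, TPR)
write_op = np.einsum('br,bf->brf', r1, e2) - np.einsum('br,bf->brf', r1, w_hat)

# eq. 11/12: move the old filler w_hat into the (e1, r2) slot
m_hat = np.einsum('be,br,berf->bf', e1, r2, TPR)
move_op = np.einsum('br,bf->brf', r2, w_hat) - np.einsum('br,bf->brf', r2, m_hat)

# eq. 13/14: backlink from e2 through r3 to e1
b_hat = np.einsum('be,br,berf->bf', e2, r3, TPR)
backlink_op = np.einsum('br,bf->brf', r3, e1) - np.einsum('br,bf->brf', r3, b_hat)

# eq. 6 and eq. 4: lift the order-2 updates to order 3 and add them to the TPR
TPR += np.einsum('be,brf->berf', e1, write_op + move_op) \
     + np.einsum('be,brf->berf', e2, backlink_op)
```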