├── README.md ├── code ├── models │ ├── blocks.py │ ├── estimator.py │ ├── estimator_c.py │ ├── network.py │ ├── network_c.py │ ├── replay_buffer.py │ ├── replay_buffer_c.py │ └── segment_tree.py ├── run_model.py ├── run_model_c.py ├── run_rpc_server.py ├── setup.py └── utils │ ├── analytics.py │ ├── data_processing.py │ ├── io_utils.py │ ├── io_utils_c.py │ ├── rpc_client.py │ ├── tree_navigation.py │ └── tree_navigation_c.py └── requirements.txt /README.md: -------------------------------------------------------------------------------- 1 | # Learning to Search in Long Documents Using Document Structure 2 | This repo contains the code used in our paper [https://arxiv.org/abs/1806.03529](https://arxiv.org/abs/1806.03529). 3 | The code includes a framework for training and evaluation of the DocQN and DQN models on TriviaQA-NoP, our version of [TriviaQA](http://nlp.cs.washington.edu/triviaqa/) where documents are represented as tree objects. 4 | The data is available for download [here](https://www.cs.tau.ac.il/~taunlp/triviaqa-nop/triviaqa-nop.gz). 5 | 6 | There are two code versions included, corresponding to the two model variants in the paper - full and coupled. 7 | The full models leverage RaSoR predictions during navigation, while the coupled models do not. 8 | All files ending with `_c.py` belong to the coupled version. 9 | 10 | 11 | ## Setup 12 | The code requires Python >= 3.5, TensorFlow 1.3, and several other supporting libraries. 13 | TensorFlow should be installed separately, following its documentation. To install the other dependencies, use: 14 | ```bash 15 | $ pip install -r requirements.txt 16 | ``` 17 | Once the environment is set, you can download and extract the data by running the setup script: 18 | ```bash 19 | $ python setup.py 20 | ``` 21 | 22 | Loading the data into memory requires at least 34GB of RAM, and training requires an additional amount that depends on the replay memory size. To allow memory-efficient execution and to support multiple executions in parallel, we run an RPC server that holds a single copy of the data in memory. 23 | Running the RPC server is required for the full models and optional for the coupled models. To use it, [RabbitMQ](https://www.rabbitmq.com/install-debian.html) must be installed. 24 | 25 | The code can run on both GPU and CPU devices. 26 | 27 | 28 | ## Data 29 | TriviaQA-NoP comprises dataset files and preprocessed files that are needed for code execution. 30 | Running the setup script, as described above, will download and extract all files into the `data` folder. 31 | 32 | ### TriviaQA-NoP dataset 33 | The raw data is compressed in the `triviaqa-nop.gz` file, which comprises raw evidence files without the preface section and their corresponding tree objects. In addition, it contains the train, dev and test sets of TriviaQA (`json` files), 34 | updated to match the evidence files in TriviaQA-NoP. 35 | 36 | ### Preprocessed files 37 | These include vocabulary and word embeddings based on [GloVe](https://nlp.stanford.edu/projects/glove/), per-paragraph [RaSoR](https://github.com/shimisalant/RaSoR) predictions, and an "evidence dictionary" of question-evidence pairs that holds the data (tree objects, tokenized evidence files, etc.) to be loaded into memory during training and evaluation. 38 | The `.exp.pkl` files under `data/qa` are an expanded version of the datasets (`json` files), where each sample of a question with multiple evidence documents is broken into multiple question-evidence pairs.
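For a quick sanity check of the expanded data, the sketch below loads one question-evidence pair. The file name is hypothetical, and the record layout is an assumption inferred from the fields accessed in `code/models/estimator_c.py` (`Question`, `QuestionTokens`, `OrigEvidenceIdx`, `AnswerLineIdx`); adjust it to the files actually produced by `setup.py`.
```python
# Illustrative sketch only -- the file name and record layout are assumptions,
# based on the fields read in code/models/estimator_c.py (get_sample_info).
import pickle

with open('data/qa/dev.exp.pkl', 'rb') as f:  # hypothetical file name
    samples = pickle.load(f)

sample = samples[0]                       # one question-evidence pair
print(sample['Question'])                 # question text
print(sample['QuestionTokens'][:10])      # tokenized question (prefix)
print(sample['OrigEvidenceIdx'])          # index of the original evidence document
print(sample['AnswerLineIdx'])            # answer line indices within the evidence
```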
39 | 40 | 41 | ## Running the RPC server 42 | This step is a requirement for training and evaluation of DocQN/DQN models that use RaSoR predictions during navigation (i.e. the full models). For the coupled models, it is optional. To run the RPC server, execute the following command: 43 | ```bash 44 | $ python run_rpc_server.py 45 | ``` 46 | 47 | This will start the server, which will keep running until it is shut down (with Ctrl+C). 48 | 49 | 50 | ## Training 51 | 52 | Use `run_model[_c].py` for training as follows: 53 | 54 | ```bash 55 | $ PYTHONHASHSEED=[seed] python run_model[_c].py --train 56 | ``` 57 | 58 | Here, [seed] is an integer to which Python's hash seed will be fixed. We set the PYTHONHASHSEED environment variable in this way because the code uses Python's built-in hash function. Fixing PYTHONHASHSEED guarantees a consistent hash function across different executions and machines. 59 | 60 | To use the RPC server in the coupled version, add the flag `--use_rpc`. There are many configuration options, which can be listed with the `--help` menu. One important argument is `--train_protocol`, which controls the tree sampling method during training. Specifically, to train DocQN, run: 61 | ```bash 62 | $ PYTHONHASHSEED=[seed] python run_model[_c].py --train --train_protocol combined_ans_radius 63 | ``` 64 | and to train DQN, run: 65 | ```bash 66 | $ PYTHONHASHSEED=[seed] python run_model[_c].py --train --train_protocol sequential 67 | ``` 68 | 69 | During training, navigation performance metrics will be output, including navigation accuracy ('avg_acc'). 70 | Model checkpoints and logs will be stored under the `models` and `logs` folders, respectively, and a unique id is generated for every model. 71 | 72 | It is possible to resume training by using the `--resume` argument, together with `--model_id` and `--model_step`. Note that the replay memory will be re-initialized in this case. 73 | 74 | 75 | ## Evaluation 76 | Use `run_model[_c].py` for evaluation as follows: 77 | 78 | ```bash 79 | $ PYTHONHASHSEED=[seed] python run_model[_c].py --evaluate --model_id [id] --model_best 80 | ``` 81 | 82 | For evaluation of a specific checkpoint, use `--model_step [step]` instead of `--model_best`. 83 | This will evaluate the model on the development set of TriviaQA-NoP, and output two files: 84 | * `logs/[model_id]_[model_step]_dev_output.json` - contains the selected paragraph for every question-evidence pair 85 | in [SQuAD format](https://rajpurkar.github.io/SQuAD-explorer/), which can be given as input to RaSoR (or any other reading comprehension model). 86 | * `logs/[model_id]_[model_step]_dev_dbg.log` - a full navigation log, containing a description of the steps 87 | performed by the model for every question-evidence pair. 88 | 89 | To obtain predictions for the test set, run: 90 | ```bash 91 | $ PYTHONHASHSEED=[seed] python run_model[_c].py --test --model_id [id] --model_best 92 | ``` 93 | or: 94 | ```bash 95 | $ PYTHONHASHSEED=[seed] python run_model[_c].py --test --model_id [id] --model_step [step] 96 | ``` 97 | 98 | Final answer predictions per question were obtained by running a version of [this implementation of RaSoR](https://github.com/shimisalant/RaSoR) on the model's output, and aggregating the predictions across the multiple evidence documents of each question (an illustrative aggregation sketch is shown below). Currently, we are not publishing this version of RaSoR. 99 | 100 | Please feel welcome to contact us for further details and resources.
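To illustrate the aggregation step mentioned above (this is not the exact procedure used in the paper), per-evidence predictions can be reduced to a single answer per question by keeping the highest-scoring prediction. The record keys below (`qid`, `answer`, `score`) are assumptions for the sketch.
```python
# Illustrative sketch only -- one simple way to aggregate per-evidence
# reading-comprehension predictions into a single answer per question.
# The aggregation actually used in the paper is not published here.
from collections import defaultdict

def aggregate_by_max_score(predictions):
    """predictions: iterable of dicts with (assumed) keys 'qid', 'answer', 'score'."""
    best = defaultdict(lambda: (None, float('-inf')))
    for p in predictions:
        if p['score'] > best[p['qid']][1]:
            best[p['qid']] = (p['answer'], p['score'])
    return {qid: answer for qid, (answer, _) in best.items()}

print(aggregate_by_max_score([
    {'qid': 'q1', 'answer': 'Paris', 'score': 0.91},
    {'qid': 'q1', 'answer': 'Lyon', 'score': 0.42},
]))  # -> {'q1': 'Paris'}
```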
101 | 102 | 103 | ## Pre-Trained Models 104 | We release four pre-trained models: 105 | * DocQN - 1524410969.1593015 106 | * DQN - 1524411144.512193 107 | * DocQN coupled - 1517547376.1149364 108 | * DQN coupled - 1518010594.2258544 109 | 110 | The models can be downloaded from [this](https://www.cs.tau.ac.il/~taunlp/triviaqa-nop/triviaqa-nop-pretrained-models.gz) link, and should be extracted to the `models` folder in the root directory. 111 | Training and evaluation of these models were initiated with `PYTHONHASHSEED=1618033988`. 112 | -------------------------------------------------------------------------------- /code/models/blocks.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Basic Network Building Blocks 3 | # 4 | # Some of the functions are taken from the open source repositories 5 | # of TensorFlow and OpenAI, as well as from Danijar Hafner's blog. 6 | # - https://github.com/openai/baselines 7 | # - https://danijar.com/variable-sequence-lengths-in-tensorflow/ 8 | # 9 | ###################################################################### 10 | 11 | 12 | ################################### 13 | # Imports 14 | # 15 | 16 | import tensorflow as tf 17 | import numpy as np 18 | 19 | 20 | ################################### 21 | # Functions 22 | # 23 | 24 | def seqlen(sequence): 25 | """ 26 | Get the true length of sequences (without padding), and a mask for true-length in max-length. 27 | Input of shape: (batch_size, max_seq_length, hidden_dim) 28 | Output shapes: 29 | length: (batch_size) 30 | mask: (batch_size, max_seq_length, 1) 31 | """ 32 | with tf.name_scope('seq_len'): 33 | populated = tf.sign(tf.abs(sequence)) 34 | length = tf.cast(tf.reduce_sum(populated, axis=1), tf.int32) 35 | mask = tf.cast(tf.expand_dims(populated, -1), tf.float32) 36 | 37 | return length, mask 38 | 39 | 40 | def pathlen(path): 41 | with tf.name_scope('path_len'): 42 | populated = tf.sign(tf.reduce_sum(tf.abs(path), 2)) 43 | length = tf.cast(tf.reduce_sum(populated, 1), tf.int32) 44 | mask = tf.cast(populated, tf.float32) 45 | 46 | return length, mask 47 | 48 | 49 | def get_padded(seq, seq_len, val): 50 | return np.pad(seq, (0, seq_len-len(seq)), 'constant', constant_values=(val,)) 51 | 52 | 53 | def last_relevant(output, length): 54 | with tf.name_scope('last_relevant'): 55 | batch_size = tf.shape(output)[0] 56 | max_length = tf.shape(output)[1] 57 | index = tf.range(0, batch_size) * max_length + (length - 1) 58 | flat = tf.reshape(output, [-1, 1]) 59 | relevant = tf.squeeze(tf.gather(flat, index)) 60 | 61 | return relevant 62 | 63 | 64 | def mask_loss(total_loss, mask): 65 | with tf.name_scope('mask_loss'): 66 | loss = tf.reduce_sum(total_loss * mask, axis=1) 67 | 68 | return loss 69 | 70 | 71 | def huber_loss(x, delta=1.0): 72 | """Reference: https://en.wikipedia.org/wiki/Huber_loss""" 73 | return tf.where( 74 | tf.abs(x) < delta, 75 | tf.square(x) * 0.5, 76 | delta * (tf.abs(x) - 0.5 * delta) 77 | ) 78 | 79 | 80 | def minimize_and_clip(optimizer, objective, var_list, clip_val=10): 81 | """Minimize `objective` using `optimizer` w.r.t.
variables in 82 | `var_list`, while ensuring the norm of the gradients for each 83 | variable is clipped to `clip_val`. 84 | """ 85 | with tf.variable_scope("minimize_and_clip"): 86 | gradients = optimizer.compute_gradients(objective, var_list=var_list) 87 | for i, (grad, var) in enumerate(gradients): 88 | if grad is not None: 89 | gradients[i] = (tf.clip_by_norm(grad, clip_val), var) 90 | grad_norm = tf.global_norm([g[0] for g in gradients]) 91 | return optimizer.apply_gradients(gradients), grad_norm 92 | 93 | 94 | def birnn(inputs, dim, keep_prob, seq_len, name): 95 | with tf.name_scope(name): 96 | with tf.variable_scope('forward' + name): 97 | cell_fwd = tf.contrib.rnn.LSTMCell(num_units=dim) 98 | # cell_fwd = tf.nn.rnn_cell.DropoutWrapper(cell_fwd, output_keep_prob=keep_prob) 99 | with tf.variable_scope('backward' + name): 100 | cell_bwd = tf.contrib.rnn.LSTMCell(num_units=dim) 101 | # cell_bwd = tf.nn.rnn_cell.DropoutWrapper(cell_bwd, output_keep_prob=keep_prob) 102 | 103 | outputs, states = tf.nn.bidirectional_dynamic_rnn(cell_fw=cell_fwd, cell_bw=cell_bwd, inputs=inputs, 104 | sequence_length=seq_len, dtype=tf.float32, scope=name) 105 | return outputs, states 106 | 107 | 108 | def rnn(inputs, dim, keep_prob, seq_lens, name): 109 | with tf.name_scope(name): 110 | cell_fwd = tf.contrib.rnn.LSTMCell(num_units=dim) 111 | # cell_fwd = tf.nn.rnn_cell.DropoutWrapper(cell_fwd, output_keep_prob=keep_prob) 112 | outputs, state = tf.nn.dynamic_rnn(cell=cell_fwd, inputs=inputs, 113 | sequence_length=seq_lens, dtype=tf.float32, scope=name) 114 | 115 | return outputs, state 116 | 117 | 118 | # based on code from: 119 | # https://github.com/google/seq2seq/blob/master/seq2seq/decoders/attention.py 120 | def attention_layer(query, keys, values, values_length, num_units, reuse=False): 121 | """ Computes attention scores and outputs. 122 | Args: 123 | query: The query used to calculate attention scores. 124 | In seq2seq this is typically the current state of the decoder. 125 | A tensor of shape `[B, ...]` 126 | keys: The keys used to calculate attention scores. In seq2seq, these 127 | are typically the outputs of the encoder and equivalent to `values`. 128 | A tensor of shape `[B, T, ...]` where each element in the `T` 129 | dimension corresponds to the key for that value. 130 | values: The elements to compute attention over. In seq2seq, this is 131 | typically the sequence of encoder outputs. 132 | A tensor of shape `[B, T, input_dim]`. 133 | values_length: An int32 tensor of shape `[B]` defining the sequence 134 | length of the attention values. 135 | Returns: 136 | A tuple `(scores, context)`. 137 | `scores` is a vector of length `T` where each element is the 138 | normalized "score" of the corresponding `inputs` element. 139 | `context` is the final attention layer output corresponding to 140 | the weighted inputs. 141 | A tensor of shape `[B, input_dim]`.
142 | """ 143 | with tf.variable_scope('attention_layer', reuse=reuse): 144 | values_depth = values.get_shape().as_list()[-1] 145 | 146 | # Fully connected layers to transform both keys and query 147 | # into a tensor with 'num_units' units 148 | att_keys = tf.contrib.layers.fully_connected( 149 | inputs=keys, 150 | num_outputs=num_units, 151 | activation_fn=None, 152 | scope="att_keys") 153 | att_query = tf.contrib.layers.fully_connected( 154 | inputs=query, 155 | num_outputs=num_units, 156 | activation_fn=None, 157 | scope="att_query") 158 | 159 | scores = tf.reduce_sum(att_keys * tf.expand_dims(att_query, 1), [2]) 160 | 161 | # Replace all scores for padded inputs with tf.float32.min 162 | num_scores = tf.shape(scores)[1] 163 | scores_mask = tf.sequence_mask( 164 | lengths=tf.to_int32(values_length), 165 | maxlen=tf.to_int32(num_scores), 166 | dtype=tf.float32) 167 | scores = scores * scores_mask + ((1.0 - scores_mask) * tf.float32.min) 168 | 169 | # Normalize the scores 170 | scores_normalized = tf.nn.softmax(scores, name="scores_normalized") 171 | 172 | # Calculate the weighted average of the attention inputs 173 | # according to the scores 174 | context = tf.expand_dims(scores_normalized, 2) * values 175 | context = tf.reduce_sum(context, 1, name="context") 176 | context.set_shape([None, values_depth]) 177 | 178 | return scores_normalized, context 179 | 180 | 181 | def dense(x, size, name, activation=None, bias=True, weight_init=None, bias_init=None): 182 | w = tf.get_variable(name + "/w", [x.get_shape()[-1], size], initializer=weight_init) 183 | ret = tf.matmul(x, w) 184 | 185 | if bias: 186 | bias_init = tf.zeros_initializer() if bias_init is None else bias_init 187 | b = tf.get_variable(name + "/b", [size], initializer=bias_init) 188 | ret = ret + b 189 | 190 | if activation is not None: 191 | return activation(ret) 192 | else: 193 | return ret 194 | 195 | 196 | def ffnn2l(inputs, scope, size=128, activation=tf.nn.relu, weight_init=None, bias_init=tf.zeros_initializer()): 197 | with tf.variable_scope(scope): 198 | layer1 = tf.layers.dense(inputs=inputs, units=size, activation=activation, kernel_initializer=weight_init, bias_initializer=bias_init) 199 | layer2 = tf.layers.dense(inputs=layer1, units=size, activation=activation, kernel_initializer=weight_init, bias_initializer=bias_init) 200 | return layer2 201 | 202 | 203 | def ffnn(inputs, scope, n_layers=2, size=128, activation=tf.nn.relu): 204 | with tf.variable_scope(scope): 205 | layers = [None] * n_layers 206 | layers[0] = tf.layers.dense(inputs=inputs, units=size, activation=activation) 207 | for i in range(1, n_layers): 208 | layers[i] = tf.layers.dense(inputs=layers[i-1], units=size, activation=activation) 209 | 210 | return layers[-1] 211 | 212 | 213 | def conv_max_pool(inputs, num_filters, filter_size, scope_name): 214 | with tf.variable_scope("conv-maxpool-{}-{}".format(filter_size, scope_name)): 215 | num_channels = inputs.get_shape()[3] 216 | filter_shape = [1, filter_size, num_channels, num_filters] 217 | filter_ = tf.get_variable("filter", shape=filter_shape, dtype=tf.float32) 218 | bias = tf.get_variable("bias", shape=[num_filters], dtype=tf.float32) 219 | strides = [1, 1, 1, 1] 220 | 221 | conv = tf.nn.conv2d(inputs, filter_, strides=strides, padding="VALID", name="conv") 222 | h = tf.nn.relu(tf.nn.bias_add(conv, bias)) 223 | 224 | rank = len(h.shape) - 2 225 | pooled = tf.reduce_max(h, axis=rank) 226 | 227 | return pooled 228 | -------------------------------------------------------------------------------- 
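The helpers in `blocks.py` above target TensorFlow 1.x. As a rough, self-contained sketch (not part of the repository), this is how `huber_loss` and `minimize_and_clip` could be wired into a toy regression graph; the toy placeholders and hyperparameters are illustrative only.
```python
# Illustrative sketch only (not part of the repository): wiring huber_loss and
# minimize_and_clip from models/blocks.py into a toy TF 1.x regression graph.
import numpy as np
import tensorflow as tf
from models.blocks import huber_loss, minimize_and_clip

x = tf.placeholder(tf.float32, [None, 4], name='x')
y = tf.placeholder(tf.float32, [None], name='y')

pred = tf.squeeze(tf.layers.dense(x, 1), axis=1)   # simple linear predictor
loss = tf.reduce_mean(huber_loss(pred - y))        # Huber loss on the residual

optimizer = tf.train.RMSPropOptimizer(1e-3)
train_op, grad_norm = minimize_and_clip(optimizer, loss,
                                        var_list=tf.trainable_variables(),
                                        clip_val=10)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed = {x: np.random.randn(8, 4), y: np.random.randn(8)}
    _, loss_val, norm_val = sess.run([train_op, loss, grad_norm], feed)
    print('loss:', loss_val, 'grad_norm:', norm_val)
```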
/code/models/estimator_c.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Estimator - Coupled 3 | # 4 | # Main model 5 | # 6 | ###################################################################### 7 | 8 | 9 | ################################### 10 | # Imports 11 | # 12 | 13 | import tensorflow as tf 14 | from utils.analytics import * 15 | from utils.io_utils_c import * 16 | from utils.tree_navigation_c import * 17 | from utils.data_processing import hash_token, PADDING 18 | from models.network_c import * 19 | from models.replay_buffer_c import State, Transition, PrioritizedReplayBuffer 20 | import random 21 | from time import perf_counter 22 | 23 | 24 | ################################### 25 | # Classes 26 | # 27 | 28 | class Model: 29 | def __init__(self, model_id, seed, log_config): 30 | self.model_id = model_id 31 | self.seed = seed 32 | self.init = None 33 | self.sess = None 34 | self.train_writer = None 35 | self.saver = None 36 | self.best_saver = None 37 | 38 | tf.set_random_seed(self.seed) 39 | np.random.seed(self.seed) 40 | random.seed(self.seed) 41 | 42 | model_dir_path = log_config.model_temp.format(self.model_id) 43 | if not os.path.exists(model_dir_path): 44 | os.makedirs(model_dir_path) 45 | if not os.path.exists(log_config.log_dir): 46 | os.makedirs(log_config.log_dir) 47 | 48 | def start_sess(self, num_threads, tfevents=False): 49 | if num_threads != -1: 50 | config = tf.ConfigProto(intra_op_parallelism_threads=num_threads, inter_op_parallelism_threads=num_threads, 51 | allow_soft_placement=True, device_count={'CPU': 1}) 52 | self.sess = tf.Session(config=config) 53 | else: 54 | self.sess = tf.Session() 55 | if tfevents: 56 | self.train_writer = tf.summary.FileWriter('../logs') 57 | self.train_writer.add_graph(self.sess.graph) 58 | assert self.init is not None 59 | self.sess.run(self.init) 60 | 61 | def close_sess(self): 62 | if self.train_writer is not None: 63 | self.train_writer.close() 64 | self.sess.close() 65 | self.sess = None 66 | 67 | def load(self, step, log_config): 68 | assert self.sess is not None 69 | model_step_path = os.path.join(log_config.model_temp.format(self.model_id), self.model_id + '-' + str(step)) 70 | self.saver.restore(self.sess, model_step_path) 71 | 72 | def store(self, step, log_config): 73 | assert self.sess is not None 74 | model_step_path = os.path.join(log_config.model_temp.format(self.model_id), self.model_id) 75 | save_path = self.saver.save(self.sess, model_step_path, global_step=step) 76 | return save_path 77 | 78 | def load_best(self, log_config): 79 | assert self.sess is not None 80 | model_step_path = os.path.join(log_config.model_temp.format(self.model_id), self.model_id + '-best') 81 | self.best_saver.restore(self.sess, model_step_path) 82 | 83 | def store_best(self, log_config): 84 | assert self.sess is not None 85 | model_step_path = os.path.join(log_config.model_temp.format(self.model_id), self.model_id + '-best') 86 | save_path = self.best_saver.save(self.sess, model_step_path) 87 | return save_path 88 | 89 | 90 | class ModelEstimator(Model): 91 | def __init__(self, word_embeddings, char_emb_len, model_id, seed, model_conf, train_conf, log_config): 92 | Model.__init__(self, model_id, seed, log_config) 93 | self.mc = model_conf 94 | self.tc = train_conf 95 | self.lc = log_config 96 | self.step = 0 97 | self.epoch = 0 98 | self.best_acc = 0.0 99 | self.best_acc_dev = 0.0 100 | 101 | # estimators 102 | 
self.q_estimator = RLModel(known_emb=word_embeddings.known, unknown_emb=word_embeddings.unknown, 103 | char_emb_len=char_emb_len, model_conf=self.mc, scope='q_estimator') 104 | self.t_estimator = RLModel(known_emb=word_embeddings.known, unknown_emb=word_embeddings.unknown, 105 | char_emb_len=char_emb_len, model_conf=self.mc, scope='t_estimator') 106 | self.estimator_copy = ModelParametersCopier(self.q_estimator, self.t_estimator) 107 | 108 | # reply memory 109 | self.replay_memory = PrioritizedReplayBuffer(self.tc.replay_memory_size, alpha=self.tc.per_alpha) 110 | self.beta_schedule = np.linspace(self.tc.per_beta_start, self.tc.per_beta_end, self.tc.per_beta_growth_steps) 111 | 112 | # policy 113 | self.epsilon_a_schedule = np.linspace(self.tc.epsilon_a_start, self.tc.epsilon_a_end, self.tc.epsilon_a_decay_steps) 114 | if self.tc.policy_type == 'egp': 115 | self.policy = self.make_epsilon_greedy_policy(self.mc.output_dim) 116 | else: 117 | self.policy = self.make_epsilon_greedy_legal_policy(self.mc.output_dim) 118 | 119 | self.init = tf.global_variables_initializer() 120 | self.saver = tf.train.Saver(max_to_keep=2) 121 | self.best_saver = tf.train.Saver(max_to_keep=1) 122 | 123 | def train(self, train_samples, dev_samples, evidence_dict, encoder, flogstats, flogperf): 124 | print('\npopulating replay memory...', end='') 125 | self.populate_replay_memory(train_samples, evidence_dict, encoder) 126 | print('done, {} transitions\n'.format(len(self.replay_memory))) 127 | print("model {} starts training".format(self.model_id)) 128 | 129 | if self.tc.train_protocol.startswith("combined"): 130 | self.train_combined(train_samples, dev_samples, evidence_dict, encoder, flogstats, flogperf) 131 | 132 | elif self.tc.train_protocol == "random_balanced": 133 | self.train_random(train_samples, dev_samples, evidence_dict, encoder, flogstats, flogperf) 134 | 135 | else: 136 | self.train_sequential(train_samples, dev_samples, evidence_dict, encoder, flogstats, flogperf) 137 | 138 | print("model {} finished training".format(self.model_id)) 139 | 140 | def train_combined(self, train_samples, dev_samples, evidence_dict, encoder, flogstats, flogperf): 141 | epsilon_s_schedule = np.linspace(self.tc.epsilon_s_start, self.tc.epsilon_s_end, self.tc.epsilon_s_decay_steps) 142 | 143 | while True: 144 | np.random.shuffle(train_samples) 145 | rewards, avg_grads, avg_loss = [], [], [] 146 | reward_sums, path_avg_grads, path_avg_loss, path_lengths = [], [], [], [] 147 | 148 | for sample in train_samples: 149 | # check if we're done 150 | if self.step >= self.tc.max_steps: 151 | break 152 | 153 | info = get_sample_info(sample, evidence_dict, encoder, self.mc.token_length) 154 | (question_w, question_c, question_txt, answer_txt, eidx, ans_line_idx, evidence) = info 155 | 156 | epsilon_s = epsilon_s_schedule[min(self.step, self.tc.epsilon_s_decay_steps-1)] 157 | 158 | # random state sampling 159 | if np.random.rand() < epsilon_s: 160 | for _ in range(self.tc.combined_random_samples): 161 | if self.tc.train_protocol == "combined_ans_radius": 162 | node, observ_w, observ_c, props, t = init_step_random_answer_radius( 163 | evidence, encoder, ans_line_idx, self.mc.seq_length, self.mc.observ_length, self.mc.token_length, self.tc.max_episode_steps, 164 | self.tc.ans_radius, self.tc.ans_dist_prob) 165 | else: 166 | node, observ_w, observ_c, props, t = init_step_random_balanced( 167 | evidence, encoder, self.mc.seq_length, self.mc.observ_length, self.mc.token_length, self.tc.max_episode_steps) 168 | 169 | state = 
State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 170 | next_node, next_state, done, reward, grad, loss = self.predict_sample_update(state, node, question_w, question_c, 171 | ans_line_idx, t, evidence, encoder) 172 | if next_node is None and next_state is None: 173 | continue 174 | 175 | rewards.append(reward) 176 | avg_grads.append(grad) 177 | avg_loss.append(loss) 178 | 179 | self.step += 1 180 | self.post_update_checks(train_samples[:400], dev_samples[:400], evidence_dict, encoder, 181 | rewards, avg_loss, avg_grads, [-1], flogstats, flogperf) 182 | 183 | # sequential state sampling 184 | else: 185 | t, done = 0, False 186 | node, observ_w, observ_c, props = init_step(evidence, encoder, self.mc.seq_length, self.mc.observ_length, 187 | self.mc.token_length, t) 188 | state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 189 | path_rewards, path_grads, path_loss = [], [], [] 190 | while True: 191 | next_node, next_state, done, reward, grad, loss = self.predict_sample_update(state, node, question_w, question_c, 192 | ans_line_idx, t, evidence, encoder) 193 | t += 1 194 | if next_node is None and next_state is None: 195 | break 196 | 197 | path_rewards.append(reward) 198 | path_grads.append(grad) 199 | path_loss.append(loss) 200 | 201 | node, state = next_node, next_state 202 | self.step += 1 203 | self.post_update_checks(train_samples[:400], dev_samples[:400], evidence_dict, encoder, 204 | reward_sums, path_avg_loss, path_avg_grads, path_lengths, flogstats, flogperf) 205 | 206 | if done or t == self.tc.max_episode_steps: 207 | break 208 | 209 | reward_sums.append(np.sum(path_rewards)) 210 | path_avg_grads.append(np.mean(path_grads)) 211 | path_avg_loss.append(np.mean(path_loss)) 212 | path_lengths.append(t) 213 | 214 | save_path = self.store(self.step, self.lc) 215 | print("----- step {}\tmodel stored: {}".format(self.step, save_path)) 216 | 217 | # check if we're done 218 | if self.step >= self.tc.max_steps: 219 | break 220 | 221 | print("----- step {}\tfinished epoch: {}".format(self.step, self.epoch + 1)) 222 | self.epoch += 1 223 | 224 | def train_sequential(self, train_samples, dev_samples, evidence_dict, encoder, flogstats, flogperf): 225 | while True: 226 | np.random.shuffle(train_samples) 227 | reward_sums, path_avg_grads, path_avg_loss, path_lengths = [], [], [], [] 228 | 229 | for sample in train_samples: 230 | # check if we're done 231 | if self.step >= self.tc.max_steps: 232 | break 233 | 234 | info = get_sample_info(sample, evidence_dict, encoder, self.mc.token_length) 235 | (question_w, question_c, question_txt, answer_txt, eidx, ans_line_idx, evidence) = info 236 | t, done = 0, False 237 | node, observ_w, observ_c, props = init_step(evidence, encoder, self.mc.seq_length, self.mc.observ_length, 238 | self.mc.token_length, t) 239 | state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 240 | 241 | path_rewards, path_grads, path_loss = [], [], [] 242 | while True: 243 | next_node, next_state, done, reward, grad, loss = self.predict_sample_update(state, node, question_w, question_c, 244 | ans_line_idx, t, evidence, encoder) 245 | t += 1 246 | if next_node is None and next_state is None: 247 | break 248 | 249 | path_rewards.append(reward) 250 | path_grads.append(grad) 251 | path_loss.append(loss) 252 | 253 | node, state = next_node, next_state 254 | self.step += 1 255 | self.post_update_checks(train_samples[:400], dev_samples[:400], evidence_dict, encoder, 256 | reward_sums, path_avg_loss, 
path_avg_grads, path_lengths, flogstats, flogperf) 257 | 258 | if done or t == self.tc.max_episode_steps: 259 | break 260 | 261 | reward_sums.append(np.sum(path_rewards)) 262 | path_avg_grads.append(np.mean(path_grads)) 263 | path_avg_loss.append(np.mean(path_loss)) 264 | path_lengths.append(t) 265 | 266 | save_path = self.store(self.step, self.lc) 267 | print("----- step {}\tmodel stored: {}".format(self.step, save_path)) 268 | 269 | # check if we're done 270 | if self.step >= self.tc.max_steps: 271 | break 272 | 273 | print("----- step {}\tfinished epoch: {}".format(self.step, self.epoch+1)) 274 | self.epoch += 1 275 | 276 | def train_random(self, train_samples, dev_samples, evidence_dict, encoder, flogstats, flogperf): 277 | while True: 278 | np.random.shuffle(train_samples) 279 | rewards, avg_grads, avg_loss = [], [], [] 280 | 281 | for sample in train_samples: 282 | # check if we're done 283 | if self.step >= self.tc.max_steps: 284 | break 285 | 286 | info = get_sample_info(sample, evidence_dict, encoder, self.mc.token_length) 287 | (question_w, question_c, question_txt, answer_txt, eidx, ans_line_idx, evidence) = info 288 | 289 | # TODO: use init_step_batch instead? for more efficient sampling 290 | node, observ_w, observ_c, props, t = init_step_random_balanced(evidence, encoder, self.mc.seq_length, self.mc.observ_length, 291 | self.mc.token_length, self.tc.max_episode_steps) 292 | 293 | state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 294 | next_node, next_state, done, reward, grad, loss = self.predict_sample_update(state, node, question_w, question_c, 295 | ans_line_idx, t, evidence, encoder) 296 | if next_node is None and next_state is None: 297 | continue 298 | 299 | rewards.append(reward) 300 | avg_grads.append(grad) 301 | avg_loss.append(loss) 302 | 303 | self.step += 1 304 | self.post_update_checks(train_samples[:400], dev_samples[:400], evidence_dict, encoder, 305 | rewards, avg_loss, avg_grads, [-1], flogstats, flogperf) 306 | 307 | save_path = self.store(self.step, self.lc) 308 | print("----- step {}\tmodel stored: {}".format(self.step, save_path)) 309 | 310 | # check if we're done 311 | if self.step >= self.tc.max_steps: 312 | break 313 | 314 | print("----- step {}\tfinished epoch: {}".format(self.step, self.epoch + 1)) 315 | self.epoch += 1 316 | 317 | def predict_paths(self, samples, evidence_dict, encoder, flog=None, fout=None): 318 | write_predict_paths_header(flog) 319 | metrics, reward_sums = [], [] 320 | predictions = [] 321 | 322 | for sample in samples: 323 | info = get_sample_info(sample, evidence_dict, encoder, self.mc.token_length) 324 | (question_w, question_c, question_txt, answer_txt, eidx, ans_line_idx, evidence) = info 325 | question_tokens = encoder.idxs_to_ws(question_w) 326 | 327 | t, done = 0, False 328 | node, observ_w, observ_c, props = init_step(evidence, encoder, self.mc.seq_length, self.mc.observ_length, 329 | self.mc.token_length, t) 330 | observ_tokens = encoder.idxs_to_ws(observ_w) 331 | state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 332 | 333 | path_rewards, path_actions, path_num_illegal_moves = [], [], 0 334 | while True: 335 | action_probs, q_values, x_weights, q_weights = self.get_aprobs_qvals_weights(state, node, flog) 336 | action = np.random.choice(np.arange(len(action_probs)), p=action_probs) 337 | path_actions.append(action) 338 | t += 1 339 | 340 | next_node, observ_w, observ_c, props, done = make_step(evidence, encoder, node, action, self.mc.seq_length, 341 | 
self.mc.observ_length, self.mc.token_length, t) 342 | next_state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 343 | reward = get_reward(node, action, ans_line_idx, evidence, self.tc.scores) 344 | path_rewards.append(reward) 345 | path_num_illegal_moves += int(is_illegal_move(node, action)) 346 | 347 | write_step_start_msg(flog, info, t, node, observ_tokens, x_weights, question_tokens, q_weights, 348 | q_values, action, reward) 349 | if done or t == self.tc.max_episode_steps: 350 | break 351 | write_step_end_msg(flog, node, next_node) 352 | 353 | node, state = next_node, next_state 354 | observ_tokens = encoder.idxs_to_ws(observ_w) 355 | 356 | metrics.append(get_sample_metrics(node, ans_line_idx, evidence, t, path_num_illegal_moves, q_values, reward)) 357 | closest_line_idx, closest_line_diff = get_closest_idx_diff(node, ans_line_idx) 358 | write_path_end_msg(flog, node, done, action, closest_line_diff) 359 | if fout is not None: 360 | predictions.append(create_json_record(sample, node.line - int(node.is_root))) 361 | 362 | reward_sums.append(sum(path_rewards)) 363 | 364 | write_predictions_json(predictions, fout) 365 | 366 | return get_metrics(metrics, reward_sums) 367 | 368 | def predict_paths_test(self, samples, evidence_dict, encoder, fout): 369 | predictions = [] 370 | total_time, total_steps = 0, 0 371 | 372 | for sample in samples: 373 | # start timer 374 | t0 = perf_counter() 375 | 376 | info = get_sample_info(sample, evidence_dict, encoder, self.mc.token_length, test=True) 377 | (question_w, question_c, question_txt, answer_txt, eidx, ans_line_idx, evidence) = info 378 | t, done = 0, False 379 | node, observ_w, observ_c, props = init_step(evidence, encoder, self.mc.seq_length, self.mc.observ_length, 380 | self.mc.token_length, t) 381 | state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 382 | 383 | while True: 384 | action_probs, q_values, x_weights, q_weights = self.get_aprobs_qvals_weights(state, node, None) 385 | action = np.random.choice(np.arange(len(action_probs)), p=action_probs) 386 | t += 1 387 | 388 | next_node, observ_w, observ_c, props, done = make_step(evidence, encoder, node, action, self.mc.seq_length, 389 | self.mc.observ_length, self.mc.token_length, t) 390 | next_state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 391 | if done or t == self.tc.max_episode_steps: 392 | break 393 | 394 | node, state = next_node, next_state 395 | 396 | # stop timer 397 | total_time += perf_counter() - t0 398 | total_steps += t 399 | 400 | predictions.append(create_json_record(sample, node.line - int(node.is_root))) 401 | 402 | write_predictions_json(predictions, fout) 403 | 404 | def check_performance(self, train_samples, dev_samples, evidence_dict, encoder, flog): 405 | self.check_performance_samples(dev_samples, evidence_dict, encoder, 'DEV', flog) 406 | self.check_performance_samples(train_samples, evidence_dict, encoder, 'TRN', flog) 407 | 408 | def check_performance_samples(self, samples, evidence_dict, encoder, desc, flog): 409 | metrics_agg, df = self.predict_paths(samples, evidence_dict, encoder) 410 | msg = "{}\tstep {}\t".format(desc, self.step) + metrics_agg_to_str(metrics_agg) 411 | print('----- ' + msg) 412 | write_flog('{}\t{}\t'.format(desc, self.step) + metrics_agg_to_str(metrics_agg, fields=False) + '\n', flog) 413 | 414 | if metrics_agg.avg_acc > self.best_acc: 415 | self.best_acc = metrics_agg.avg_acc 416 | print("----- step {}\tnew best accuracy {:.4f} 
({})".format(self.step, self.best_acc, desc)) 417 | 418 | if desc == 'DEV' and metrics_agg.avg_acc > self.best_acc_dev: 419 | self.best_acc_dev = metrics_agg.avg_acc 420 | save_path = self.store_best(self.lc) 421 | print("----- step {}\tbest model stored: {}".format(self.step, save_path)) 422 | 423 | def evaluate(self, samples, evidence_dict, encoder, perf_dbg_path, output_path): 424 | np.random.shuffle(samples) 425 | if perf_dbg_path is None: 426 | metrics_agg, df = self.predict_paths(samples, evidence_dict, encoder, None, output_path) 427 | else: 428 | with open(perf_dbg_path, 'w', encoding='utf-8') as flog: 429 | metrics_agg, df = self.predict_paths(samples, evidence_dict, encoder, flog, output_path) 430 | 431 | msg = metrics_agg_to_str(metrics_agg) 432 | print("model {} finished evaluation:\n{}".format(self.model_id, msg)) 433 | 434 | def make_epsilon_greedy_policy(self, num_actions): 435 | estimator = self.q_estimator 436 | 437 | def policy_fn(sess, state, epsilon, get_weights=False): 438 | actions = np.ones(num_actions, dtype=float) * epsilon / num_actions 439 | if get_weights: 440 | q_values, x_weights, q_weights_context = estimator.predict(sess, state, get_weights) 441 | else: 442 | q_values = estimator.predict(sess, state, get_weights) 443 | 444 | q_values = q_values[0] 445 | best_action = np.argmax(q_values) 446 | actions[best_action] += (1.0 - epsilon) 447 | if get_weights: 448 | return actions, q_values, x_weights[0], q_weights_context[0] 449 | else: 450 | return actions, q_values 451 | 452 | return policy_fn 453 | 454 | def make_epsilon_greedy_legal_policy(self, num_actions): 455 | estimator = self.q_estimator 456 | 457 | def policy_fn(sess, state, node, epsilon, get_weights=False): 458 | legal_actions = get_legal_actions(node) 459 | num_legal_actions = len(legal_actions) 460 | actions = np.zeros(num_actions, dtype=float) 461 | actions[legal_actions] = epsilon / num_legal_actions 462 | if get_weights: 463 | q_values, x_weights, q_weights_context = estimator.predict(sess, state, get_weights) 464 | else: 465 | q_values = estimator.predict(sess, state, get_weights) 466 | 467 | q_values = q_values[0] 468 | best_action = np.argmax(q_values) 469 | actions[best_action] += (1.0 - epsilon) 470 | if get_weights: 471 | return actions, q_values, x_weights[0], q_weights_context[0] 472 | else: 473 | return actions, q_values 474 | 475 | return policy_fn 476 | 477 | def predict_sample_update(self, state, node, question_w, question_c, ans_line_idx, t, evidence, encoder): 478 | action_probs, q_values = self.get_aprobs_qvals(state, node) 479 | action = np.random.choice(np.arange(len(action_probs)), p=action_probs) 480 | t += 1 481 | 482 | try: 483 | next_node, observ_w, observ_c, props, done = make_step(evidence, encoder, node, action, self.mc.seq_length, 484 | self.mc.observ_length, self.mc.token_length, t) 485 | except Exception as e: 486 | msg = "Error during 'make_step': {}\nevidence: {}\nnode: {}\naction: {}".format( 487 | e, evidence["tree"].name, node, action) 488 | print(msg) 489 | return None, None, None, None, None, None 490 | 491 | next_state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 492 | reward = get_reward(node, action, ans_line_idx, evidence, self.tc.scores) 493 | self.replay_memory.add(Transition(state, action, reward, next_state, 494 | done or t == self.tc.max_episode_steps)) 495 | 496 | # sample a batch from the replay memory, perform gradient descent update 497 | transitions = self.replay_memory.sample(self.tc.batch_size, 498 | 
beta=self.beta_schedule[min(self.step, self.tc.per_beta_growth_steps - 1)]) 499 | (states_batch, action_batch, reward_batch, next_states_batch, done_batch, weights, idx_batch) = transitions 500 | 501 | # DDQN 502 | q_values_next = self.q_estimator.predict(self.sess, next_states_batch) 503 | best_actions = np.argmax(q_values_next, axis=1) 504 | q_values_next_target = self.t_estimator.predict(self.sess, next_states_batch) 505 | targets_batch = reward_batch + np.invert(done_batch).astype(np.float32) * self.tc.gamma * \ 506 | q_values_next_target[np.arange(self.tc.batch_size), best_actions] 507 | 508 | grad, loss, td_err = self.q_estimator.update(self.sess, states_batch, action_batch, targets_batch, 509 | weights, self.train_writer, self.step) 510 | new_priorities = np.abs(td_err) + self.tc.per_eps 511 | self.replay_memory.update_priorities(idx_batch, new_priorities) 512 | 513 | return next_node, next_state, done, reward, grad, loss 514 | 515 | def post_update_checks(self, train_samples, dev_samples, evidence_dict, encoder, 516 | rewards, avg_loss, avg_grads, avg_path_len, flogstats, flogperf): 517 | # target estimator periodic update 518 | if self.step % self.tc.update_estimator_freq == 0: 519 | self.estimator_copy.make(self.sess) 520 | print("----- step {}\testimator was updated".format(self.step)) 521 | 522 | # writing training statistics 523 | if self.step % 2000 == 0 and len(rewards) > 0: 524 | write_train_stats(self.step, rewards, avg_loss, avg_grads, avg_path_len, flogstats) 525 | for lst in [rewards, avg_grads, avg_loss, avg_path_len]: 526 | lst.clear() 527 | 528 | # checking model performance 529 | if self.step % self.tc.check_freq == 0: 530 | self.check_performance(train_samples, dev_samples, evidence_dict, encoder, flogperf) 531 | 532 | # storing current model 533 | if self.step % 50000 == 0: 534 | save_path = self.store(self.step, self.lc) 535 | print("----- step {}\tmodel stored: {}".format(self.step, save_path)) 536 | 537 | def get_aprobs_qvals(self, state, node): 538 | if self.tc.policy_type == 'egp': 539 | action_probs, q_values = self.policy( 540 | self.sess, state, self.epsilon_a_schedule[min(self.step, self.tc.epsilon_a_decay_steps - 1)]) 541 | else: 542 | action_probs, q_values = self.policy( 543 | self.sess, state, node, self.epsilon_a_schedule[min(self.step, self.tc.epsilon_a_decay_steps - 1)]) 544 | 545 | return action_probs, q_values 546 | 547 | def get_aprobs_qvals_weights(self, state, node, flog): 548 | # no weights 549 | if flog is None: 550 | x_weights, q_weights = None, None 551 | action_probs, q_values = self.get_aprobs_qvals(state, node) 552 | 553 | # get weights 554 | else: 555 | if self.tc.policy_type == 'egp': 556 | action_probs, q_values, x_weights, q_weights = self.policy( 557 | self.sess, state, self.epsilon_a_schedule[min(self.step, self.tc.epsilon_a_decay_steps - 1)], get_weights=True) 558 | else: 559 | action_probs, q_values, x_weights, q_weights = self.policy( 560 | self.sess, state, node, self.epsilon_a_schedule[min(self.step, self.tc.epsilon_a_decay_steps - 1)], get_weights=True) 561 | 562 | return action_probs, q_values, x_weights, q_weights 563 | 564 | def populate_replay_memory(self, samples, evidence_dict, encoder): 565 | if self.tc.train_protocol.startswith("combined"): 566 | self.populate_replay_memory_combined(samples, evidence_dict, encoder) 567 | 568 | elif self.tc.train_protocol == "random_balanced": 569 | self.populate_replay_memory_random(samples, evidence_dict, encoder) 570 | 571 | else: 572 | 
self.populate_replay_memory_sequential(samples, evidence_dict, encoder) 573 | 574 | def populate_replay_memory_sequential(self, samples, evidence_dict, encoder): 575 | for sample in samples: 576 | info = get_sample_info(sample, evidence_dict, encoder, self.mc.token_length) 577 | (question_w, question_c, question_txt, answer_txt, eidx, ans_line_idx, evidence) = info 578 | 579 | t, done = 0, False 580 | node, observ_w, observ_c, props = init_step(evidence, encoder, self.mc.seq_length, self.mc.observ_length, 581 | self.mc.token_length, t) 582 | state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 583 | 584 | while len(self.replay_memory) < self.tc.replay_memory_init_size: 585 | if self.tc.policy_type == 'egp': 586 | action_probs, q_values = self.policy(self.sess, state, 587 | self.epsilon_a_schedule[min(self.step, self.tc.epsilon_a_decay_steps - 1)]) 588 | else: 589 | action_probs, q_values = self.policy(self.sess, state, node, 590 | self.epsilon_a_schedule[min(self.step, self.tc.epsilon_a_decay_steps - 1)]) 591 | action = np.random.choice(np.arange(len(action_probs)), p=action_probs) 592 | 593 | t += 1 594 | next_node, observ_w, observ_c, props, done = make_step(evidence, encoder, node, action, self.mc.seq_length, 595 | self.mc.observ_length, self.mc.token_length, t) 596 | next_state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 597 | reward = get_reward(node, action, ans_line_idx, evidence, self.tc.scores) 598 | self.replay_memory.add(Transition(state, action, reward, next_state, done or t == self.tc.max_episode_steps)) 599 | 600 | node, state = next_node, next_state 601 | if done or t == self.tc.max_episode_steps: 602 | break 603 | 604 | if len(self.replay_memory) >= self.tc.replay_memory_init_size: 605 | break 606 | 607 | def populate_replay_memory_random(self, samples, evidence_dict, encoder): 608 | while len(self.replay_memory) < self.tc.replay_memory_init_size: 609 | # choose a random sample 610 | sample = random.choice(samples) 611 | info = get_sample_info(sample, evidence_dict, encoder, self.mc.token_length) 612 | (question_w, question_c, question_txt, answer_txt, eidx, ans_line_idx, evidence) = info 613 | 614 | # choose a random state 615 | node, observ_w, observ_c, props, t = init_step_random_balanced(evidence, encoder, self.mc.seq_length, self.mc.observ_length, 616 | self.mc.token_length, self.tc.max_episode_steps) 617 | 618 | state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 619 | 620 | # choose an action 621 | # sample randomly if the model starts training, otherwise use pre-trained policy 622 | if self.step > 0: 623 | if self.tc.policy_type == 'egp': 624 | action_probs, q_values = self.policy(self.sess, state, 625 | self.epsilon_a_schedule[min(self.step, self.tc.epsilon_a_decay_steps - 1)]) 626 | else: 627 | action_probs, q_values = self.policy(self.sess, state, node, 628 | self.epsilon_a_schedule[min(self.step, self.tc.epsilon_a_decay_steps - 1)]) 629 | action = np.random.choice(np.arange(len(action_probs)), p=action_probs) 630 | else: 631 | action = np.random.choice(np.arange(self.mc.output_dim)) 632 | t += 1 633 | 634 | # add a transition to memory 635 | next_node, observ_w, observ_c, props, done = make_step(evidence, encoder, node, action, self.mc.seq_length, 636 | self.mc.observ_length, self.mc.token_length, t) 637 | next_state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 638 | reward = get_reward(node, action, ans_line_idx, evidence, 
self.tc.scores) 639 | self.replay_memory.add(Transition(state, action, reward, next_state, done or t == self.tc.max_episode_steps)) 640 | 641 | def populate_replay_memory_combined(self, samples, evidence_dict, encoder): 642 | epsilon_s_schedule = np.linspace(self.tc.epsilon_s_start, self.tc.epsilon_s_end, self.tc.epsilon_s_decay_steps) 643 | epsilon_s = epsilon_s_schedule[min(self.step, self.tc.epsilon_s_decay_steps - 1)] 644 | 645 | while len(self.replay_memory) < self.tc.replay_memory_init_size: 646 | # choose a random sample 647 | sample = random.choice(samples) 648 | info = get_sample_info(sample, evidence_dict, encoder, self.mc.token_length) 649 | (question_w, question_c, question_txt, answer_txt, eidx, ans_line_idx, evidence) = info 650 | 651 | # random state sampling 652 | if np.random.rand() < epsilon_s: 653 | for _ in range(self.tc.combined_random_samples): 654 | if self.tc.train_protocol == "combined_ans_radius": 655 | node, observ_w, observ_c, props, t = init_step_random_answer_radius( 656 | evidence, encoder, ans_line_idx, self.mc.seq_length, self.mc.observ_length, self.mc.token_length, self.tc.max_episode_steps, 657 | self.tc.ans_radius, self.tc.ans_dist_prob) 658 | else: 659 | node, observ_w, observ_c, props, t = init_step_random_balanced( 660 | evidence, encoder, self.mc.seq_length, self.mc.observ_length, self.mc.token_length, self.tc.max_episode_steps) 661 | 662 | state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 663 | 664 | # choose an action 665 | # sample randomly if the model starts training, otherwise use pre-trained policy 666 | if self.step > 0: 667 | if self.tc.policy_type == 'egp': 668 | action_probs, q_values = self.policy(self.sess, state, 669 | self.epsilon_a_schedule[min(self.step, self.tc.epsilon_a_decay_steps - 1)]) 670 | else: 671 | action_probs, q_values = self.policy(self.sess, state, node, 672 | self.epsilon_a_schedule[min(self.step, self.tc.epsilon_a_decay_steps - 1)]) 673 | action = np.random.choice(np.arange(len(action_probs)), p=action_probs) 674 | else: 675 | action = np.random.choice(np.arange(self.mc.output_dim)) 676 | t += 1 677 | 678 | # add a transition to memory 679 | next_node, observ_w, observ_c, props, done = make_step(evidence, encoder, node, action, self.mc.seq_length, 680 | self.mc.observ_length, self.mc.token_length, t) 681 | next_state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 682 | reward = get_reward(node, action, ans_line_idx, evidence, self.tc.scores) 683 | self.replay_memory.add(Transition(state, action, reward, next_state, done or t == self.tc.max_episode_steps)) 684 | 685 | # sequential state sampling 686 | else: 687 | t, done = 0, False 688 | node, observ_w, observ_c, props = init_step(evidence, encoder, self.mc.seq_length, self.mc.observ_length, 689 | self.mc.token_length, t) 690 | state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 691 | 692 | while len(self.replay_memory) < self.tc.replay_memory_init_size: 693 | if self.tc.policy_type == 'egp': 694 | action_probs, q_values = self.policy(self.sess, state, 695 | self.epsilon_a_schedule[min(self.step, self.tc.epsilon_a_decay_steps - 1)]) 696 | else: 697 | action_probs, q_values = self.policy(self.sess, state, node, 698 | self.epsilon_a_schedule[min(self.step, self.tc.epsilon_a_decay_steps - 1)]) 699 | action = np.random.choice(np.arange(len(action_probs)), p=action_probs) 700 | 701 | t += 1 702 | next_node, observ_w, observ_c, props, done = make_step(evidence, encoder, node, 
action, self.mc.seq_length, 703 | self.mc.observ_length, self.mc.token_length, t) 704 | next_state = State(q_w=question_w, q_c=question_c, x_w=observ_w, x_c=observ_c, p=props) 705 | reward = get_reward(node, action, ans_line_idx, evidence, self.tc.scores) 706 | self.replay_memory.add(Transition(state, action, reward, next_state, done or t == self.tc.max_episode_steps)) 707 | 708 | node, state = next_node, next_state 709 | if done or t == self.tc.max_episode_steps: 710 | break 711 | 712 | if len(self.replay_memory) >= self.tc.replay_memory_init_size: 713 | break 714 | 715 | 716 | class Encoder: 717 | def __init__(self, vocabulary): 718 | self.vocab = vocabulary 719 | 720 | def get_char_emb_len(self): 721 | return len(self.vocab.char_indices) 722 | 723 | def ws_to_idxs(self, words): 724 | seq_length = len(words) 725 | windices = np.zeros(seq_length, dtype=np.int32) 726 | for i in range(seq_length): 727 | windices[i] = self.w_to_idx(words[i][0]) 728 | 729 | return windices 730 | 731 | def w_to_idx(self, word): 732 | if word in self.vocab.word_indices: 733 | return self.vocab.word_indices[word] 734 | else: 735 | return self.vocab.word_indices[hash_token(word)] 736 | 737 | def ws_to_c_idxs(self, words): 738 | seq_length = len(words) 739 | cindices = [] 740 | for i in range(seq_length): 741 | cindices.append(self.w_to_c_idxs(words[i][0])) 742 | 743 | return cindices 744 | 745 | def w_to_c_idxs(self, word): 746 | return np.asarray([self.vocab.char_indices.get(word[j], 1) for j in range(len(word))], dtype=np.int32) 747 | 748 | def idxs_to_ws(self, indices): 749 | if indices.ndim > 1: 750 | tokens = [] 751 | for indices_line in indices: 752 | tokens.append([self.vocab.index_words[index] for index in indices_line]) 753 | tokens = tokens[0] 754 | else: 755 | tokens = [self.vocab.index_words[index] for index in indices] 756 | 757 | return [x for x in tokens if x != PADDING] 758 | 759 | def encode_seq(self, tokens): 760 | res_w = self.ws_to_idxs(tokens) 761 | res_c = self.ws_to_c_idxs(tokens) 762 | return res_w, res_c 763 | 764 | def pad_idx_seq_1dim(self, seq, seq_len, val): 765 | return np.pad(seq, (0, seq_len-len(seq)), 'constant', constant_values=(val,)) 766 | 767 | def pad_idx_seq_2dim(self, seq, seq_len, val): 768 | return np.asarray([self.pad_idx_seq_1dim(x[:seq_len], seq_len, val) for x in seq], dtype=np.int32) 769 | 770 | def concate_pad_seq(self, seq, seq1_len, seq2_len, val): 771 | # assuming len(seq) <= seq1_len 772 | return np.concatenate([seq, np.ones((seq1_len - len(seq), seq2_len), dtype=np.int32) * val]) 773 | 774 | 775 | ################################### 776 | # Functions 777 | # 778 | 779 | def get_sample_info(sample, evidence_dict, encoder, token_len, test=False): 780 | question_w, question_c = encoder.encode_seq(sample['QuestionTokens']) 781 | question_c = encoder.pad_idx_seq_2dim(question_c, token_len, PADDING_IDX) 782 | question_c = np.reshape(question_c, (1, len(question_c), token_len)) 783 | 784 | question_txt = sample['Question'] 785 | eidx = sample['OrigEvidenceIdx'] 786 | evidence = get_evidence(evidence_dict, sample) 787 | 788 | if test: 789 | answer_txt = None 790 | ans_line_idx = None 791 | else: 792 | answer_txt = sample['NormalizedAliases'] 793 | ans_line_idx = sample['AnswerLineIdx'] 794 | 795 | return question_w, question_c, question_txt, answer_txt, eidx, ans_line_idx, evidence 796 | 797 | 798 | def get_reward_line_diff(node, action, ans_line_idx, evidence, scores): 799 | navigation_reward = scores.r_delta 800 | 801 | if action == ACTIONS['STOP']: 802 | 
closest_idx, line_diff = get_closest_idx_diff(node, ans_line_idx) 803 | ev_len = get_evidence_length(evidence) 804 | navigation_reward = (ev_len - line_diff) / ev_len 805 | navigation_reward += scores.r_win if line_diff == 0 else 0 806 | 807 | return navigation_reward 808 | 809 | 810 | def get_reward(node, action, ans_line_idx, evidence, scores): 811 | return get_reward_line_diff(node, action, ans_line_idx, evidence, scores) 812 | 813 | 814 | -------------------------------------------------------------------------------- /code/models/network.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Neural Model 3 | # 4 | # Dueling Deep Q-Network 5 | # 6 | ###################################################################### 7 | 8 | 9 | ################################### 10 | # Imports 11 | # 12 | 13 | import tensorflow as tf 14 | import numpy as np 15 | from models.replay_buffer import State, StateExt 16 | from models.blocks import * 17 | 18 | 19 | ################################### 20 | # Classes 21 | # 22 | 23 | class RLModel(object): 24 | def __init__(self, known_emb, unknown_emb, char_emb_len, model_conf, scope): 25 | # hyperparameters 26 | self.word_embedding_dim = model_conf.word_embedding_dim 27 | self.char_embedding_dim = model_conf.char_embedding_dim 28 | self.hidden_dim_q = model_conf.hidden_dim_q 29 | self.hidden_dim_x = model_conf.hidden_dim_x 30 | self.hidden_dim_a = model_conf.hidden_dim_a 31 | self.props_dim = model_conf.props_dim 32 | self.ans_props_dim = model_conf.ans_props_dim 33 | self.output_dim = model_conf.output_dim 34 | self.token_length = model_conf.token_length 35 | self.observ_length = model_conf.observ_length 36 | self.learning_rate = model_conf.learning_rate 37 | self.dropout_rate = model_conf.dropout_rate 38 | 39 | self.kernel_initializer = tf.glorot_uniform_initializer() 40 | self.bias_initializer = tf.truncated_normal_initializer(mean=0.011, stddev=0.005) 41 | 42 | self.scope = scope 43 | with tf.variable_scope(scope): 44 | self._build_model(known_emb, unknown_emb, char_emb_len) 45 | 46 | def _build_model(self, known_emb, unknown_emb, char_emb_len): 47 | # placeholders 48 | self.q_w = tf.placeholder(tf.int32, [None, None], name='q_w') 49 | self.q_c = tf.placeholder(tf.int32, [None, None, self.token_length], name='q_c') 50 | self.x_w = tf.placeholder(tf.int32, [None, self.observ_length], name='x_w') 51 | self.x_c = tf.placeholder(tf.int32, [None, self.observ_length, self.token_length], name='x_c') 52 | self.p = tf.placeholder(tf.int32, [None, self.props_dim], name='p') 53 | self.a_w = tf.placeholder(tf.int32, [None, None], name='a_w') 54 | self.a_c = tf.placeholder(tf.int32, [None, None, self.token_length], name='a_c') 55 | self.a_p = tf.placeholder(tf.float32, [None, self.ans_props_dim], name='a_p') 56 | 57 | self.a = tf.placeholder(tf.int32, [None], name='a') 58 | self.y = tf.placeholder(tf.float32, [None], name='y') 59 | self.w = tf.placeholder(tf.float32, [None], name='w') 60 | self.do = tf.placeholder_with_default(1.0, shape=()) 61 | 62 | # get lengths of unpadded sentences 63 | q_seq_length, q_seq_mask = seqlen(self.q_w) 64 | x_seq_lengths, x_seq_mask = seqlen(self.x_w) 65 | a_seq_lengths, a_seq_mask = seqlen(self.a_w) 66 | 67 | # word embedding lookup and dropout at embedding layer 68 | self.EWk = tf.Variable(known_emb, name='EWk', trainable=False) 69 | self.EWu = tf.Variable(unknown_emb, name='EWu', trainable=True) 70 | self.EW = 
tf.concat([self.EWk, self.EWu], axis=0) 71 | 72 | qw_emb = tf.nn.embedding_lookup(self.EW, self.q_w) 73 | qw_emb_drop = tf.nn.dropout(qw_emb, self.do) 74 | xw_emb = tf.nn.embedding_lookup(self.EW, self.x_w) 75 | xw_emb_drop = tf.nn.dropout(xw_emb, self.do) 76 | aw_emb = tf.nn.embedding_lookup(self.EW, self.a_w) 77 | aw_emb_drop = tf.nn.dropout(aw_emb, self.do) 78 | 79 | # character embedding lookup and dropout at embedding layer 80 | # +1 for uncommon characters 81 | self.ECku = tf.get_variable("ECku", (char_emb_len+1, self.char_embedding_dim), tf.float32, initializer=self.kernel_initializer) 82 | self.EC = tf.concat([tf.zeros((1, self.char_embedding_dim), dtype=tf.float32), 83 | self.ECku], axis=0, name='EC') 84 | 85 | qc_emb = tf.nn.embedding_lookup(self.EC, self.q_c) 86 | qc_emb_drop = tf.nn.dropout(qc_emb, self.do) 87 | qc_pooled = conv_max_pool(qc_emb_drop, 100, 5, "qc_pooled") 88 | xc_emb = tf.nn.embedding_lookup(self.EC, self.x_c) 89 | xc_emb_drop = tf.nn.dropout(xc_emb, self.do) 90 | xc_pooled = conv_max_pool(xc_emb_drop, 100, 5, "xc_pooled") 91 | ac_emb = tf.nn.embedding_lookup(self.EC, self.a_c) 92 | ac_emb_drop = tf.nn.dropout(ac_emb, self.do) 93 | ac_pooled = conv_max_pool(ac_emb_drop, 100, 5, "ac_pooled") 94 | 95 | # BiLSTM layer - q 96 | qw_qc_concat = tf.concat([qw_emb_drop, qc_pooled], 2) 97 | q_outputs, q_states = birnn(qw_qc_concat, dim=self.hidden_dim_q, keep_prob=self.do, 98 | seq_len=q_seq_length, name='q_bilstm') 99 | # q_states_concat = tf.concat([q_states[0][1], q_states[1][1]], 1, name='q_states_concat') 100 | q_outputs_concat = tf.concat(q_outputs, 2, name='q_outputs_concat') 101 | self.h2q_outputs = ffnn2l(q_outputs_concat, "h2q_outputs", size=256, activation=tf.nn.relu) 102 | q_weights = tf.layers.dense(inputs=self.h2q_outputs, units=1, activation=None, name='q_weights') 103 | self.q_weights_norm = tf.nn.softmax(tf.squeeze(q_weights, [2]), name="q_weights_norm") 104 | self.q_weighted = tf.reduce_sum(tf.multiply(tf.expand_dims(self.q_weights_norm, -1), q_outputs_concat), axis=1, name="q_weighted") 105 | 106 | # LSTM layer - x 107 | xw_xc_concat = tf.concat([xw_emb_drop, xc_pooled], 2) 108 | x_outputs, x_state = rnn(xw_xc_concat, dim=self.hidden_dim_x, keep_prob=self.do, 109 | seq_lens=x_seq_lengths, name='x_lstm') 110 | self.h2x_outputs = ffnn2l(x_outputs, "h2x_outputs", size=256, activation=tf.nn.relu) 111 | x_weights = tf.layers.dense(inputs=self.h2x_outputs, units=1, activation=None, name='x_weights') 112 | self.x_weights_norm = tf.nn.softmax(tf.squeeze(x_weights, [2]), name="x_weights_norm") 113 | self.x_weighted = tf.reduce_sum(tf.multiply(tf.expand_dims(self.x_weights_norm, -1), x_outputs), axis=1, name="x_weighted") 114 | 115 | # LSTM layer - a 116 | aw_ac_concat = tf.concat([aw_emb_drop, ac_pooled], axis=2) 117 | a_outputs, a_state = rnn(aw_ac_concat, dim=self.hidden_dim_a, keep_prob=self.do, 118 | seq_lens=a_seq_lengths, name='a_lstm') 119 | # c:= a_state[0], h:= a_state[1] 120 | a_repr = tf.concat([a_state[1], self.a_p], axis=1) 121 | 122 | # concatenate encoded observations 123 | self.h0 = tf.concat([self.q_weighted, self.x_weighted, a_repr], axis=1, name='h0') 124 | self.h1 = tf.layers.dense(inputs=self.h0, units=512, activation=tf.nn.relu, name='h1', 125 | kernel_initializer=self.kernel_initializer, bias_initializer=self.bias_initializer) 126 | 127 | # dueling 128 | self.h2v = tf.layers.dense(inputs=self.h1, units=256, activation=tf.nn.relu, name='h2v', 129 | kernel_initializer=self.kernel_initializer, bias_initializer=self.bias_initializer) 130 | 
self.h2a = tf.layers.dense(inputs=self.h1, units=256, activation=tf.nn.relu, name='h2a', 131 | kernel_initializer=self.kernel_initializer, bias_initializer=self.bias_initializer) 132 | 133 | # concatenate encoded observations with node props 134 | props = tf.cast(self.p, tf.float32) 135 | self.h2v_props = tf.concat([self.h2v, props], 1, name='h2v_props') 136 | self.h2a_props = tf.concat([self.h2a, props], 1, name='h2a_props') 137 | 138 | self.values = tf.layers.dense(inputs=self.h2v_props, units=1, activation=None, name='values') 139 | self.advantages = tf.layers.dense(inputs=self.h2a_props, units=self.output_dim, activation=None, name='advantages') 140 | 141 | self.preds = self.values + (self.advantages - 142 | tf.reduce_mean(self.advantages, reduction_indices=1, keep_dims=True)) 143 | 144 | # loss calculation 145 | flat_preds = tf.reshape(self.preds, [-1]) 146 | batch_size = tf.shape(self.x_w)[0] 147 | gather_indices = tf.range(batch_size) * self.output_dim + self.a 148 | partitions = tf.reduce_sum(tf.one_hot(gather_indices, tf.shape(flat_preds)[0], dtype='int32'), 0) 149 | self.a_preds = tf.dynamic_partition(flat_preds, partitions, 2)[1] 150 | 151 | self.td_err = self.a_preds - self.y 152 | losses = huber_loss(self.td_err) 153 | self.weighted_loss = tf.reduce_mean(self.w * losses) 154 | 155 | # GD with Adam 156 | self.optimizer = tf.train.RMSPropOptimizer(self.learning_rate, decay=0.99, momentum=0.0, epsilon=1e-06) 157 | # self.optimizer = tf.train.AdamOptimizer(self.learning_rate) 158 | scope_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope) 159 | self.train_op, self.grad_norm = minimize_and_clip(self.optimizer, self.weighted_loss, scope_vars) 160 | 161 | # visualization 162 | self.summaries = self.get_summaries() 163 | 164 | def predict(self, sess, states, get_weights=False): 165 | feed = self.get_feed_dict(states) 166 | if get_weights: 167 | return sess.run([self.preds, self.x_weights_norm, self.q_weights_norm], feed) 168 | else: 169 | return sess.run(self.preds, feed) 170 | 171 | def update(self, sess, states, actions, targets, weights, writer, step): 172 | feed = self.get_feed_dict_batch(states, actions, targets, weights) 173 | 174 | if writer is None: 175 | _, grads, loss, td_err = sess.run([self.train_op, self.grad_norm, self.weighted_loss, self.td_err], feed) 176 | else: 177 | _, grads, loss, td_err, summaries = sess.run([self.train_op, self.grad_norm, self.weighted_loss, 178 | self.td_err, self.summaries], feed) 179 | writer.add_summary(summaries, step) 180 | 181 | return grads, loss, td_err 182 | 183 | def get_feed_dict(self, states): 184 | if isinstance(states, State) or isinstance(states, StateExt): 185 | feed_dict = {self.q_w: np.vstack([states.q_w]), 186 | self.q_c: states.q_c, 187 | self.x_w: states.x_w, 188 | self.x_c: states.x_c, 189 | self.p: states.p, 190 | self.a_w: np.zeros(shape=(1, 1), dtype=np.float32), 191 | self.a_c: np.zeros(shape=(1, 1, self.token_length), dtype=np.float32), 192 | self.a_p: np.zeros(shape=(1, self.ans_props_dim), dtype=np.float32), 193 | self.do: self.dropout_rate} 194 | if isinstance(states, StateExt): 195 | feed_dict.update({self.a_w: np.vstack([states.a_w]), 196 | self.a_c: states.a_c, 197 | self.a_p: states.a_p}) 198 | 199 | else: 200 | qs_w = [state.q_w for state in states] 201 | max_q_len = max([len(q) for q in qs_w]) 202 | qs_w = [get_padded(q, max_q_len, 0) for q in qs_w] 203 | qs_c = [np.reshape(np.concatenate( 204 | [state.q_c[0], np.zeros((max_q_len-len(state.q_c[0]), self.token_length), 
dtype=np.int32)]), 205 | (1, max_q_len, self.token_length)) 206 | for state in states] 207 | 208 | as_w = [state.a_w if isinstance(state, StateExt) else np.zeros(shape=1, dtype=np.float32) 209 | for state in states] 210 | as_c = [state.a_c if isinstance(state, StateExt) else np.zeros(shape=(1, 1, self.token_length), dtype=np.float32) 211 | for state in states] 212 | as_p = [state.a_p[0] if isinstance(state, StateExt) else np.zeros(shape=self.ans_props_dim, dtype=np.float32) 213 | for state in states] 214 | max_a_len = max([len(a) for a in as_w]) 215 | as_w = [get_padded(a, max_a_len, 0) for a in as_w] 216 | as_c = [np.reshape(np.concatenate( 217 | [a[0], np.zeros((max_a_len - len(a[0]), self.token_length), dtype=np.int32)]), 218 | (1, max_a_len, self.token_length)) 219 | for a in as_c] 220 | 221 | feed_dict = {self.q_w: np.vstack(qs_w), 222 | self.q_c: np.vstack(qs_c), 223 | self.x_w: np.vstack([state.x_w[0] for state in states]), 224 | self.x_c: np.vstack([state.x_c for state in states]), 225 | self.p: np.vstack([state.p[0] for state in states]), 226 | self.a_w: np.vstack(as_w), 227 | self.a_c: np.vstack(as_c), 228 | self.a_p: np.vstack(as_p), 229 | self.do: self.dropout_rate} 230 | 231 | return feed_dict 232 | 233 | def get_feed_dict_batch(self, states_batch, actions_batch, targets_batch, weights): 234 | feed_dict = self.get_feed_dict(states_batch) 235 | feed_dict.update({self.a: np.array(actions_batch), 236 | self.y: np.array(targets_batch), 237 | self.w: weights}) 238 | 239 | return feed_dict 240 | 241 | def get_num_model_params(self): 242 | scope_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope) 243 | total_params = 0 244 | for variable in scope_vars: 245 | shape = variable.get_shape() 246 | variable_params = 1 247 | for dim in shape: 248 | variable_params *= dim.value 249 | total_params += variable_params 250 | 251 | return total_params 252 | 253 | def get_summaries(self): 254 | scope_hists = self.get_var_weights() 255 | summaries = tf.summary.merge(scope_hists + [ 256 | tf.summary.histogram("h1_act", self.h1), 257 | tf.summary.histogram("h2v_act", self.h2v), 258 | tf.summary.histogram("h2a_act", self.h2a), 259 | tf.summary.histogram("values_act", self.values), 260 | tf.summary.histogram("advantages_act", self.advantages), 261 | tf.summary.histogram('preds', self.preds), 262 | tf.summary.scalar("avg.loss", self.weighted_loss) 263 | ]) 264 | 265 | return summaries 266 | 267 | def get_var_weights(self): 268 | scope_hists = [] 269 | scope_names = ['h1', 'h2v', 'h2a', 'values', 'advantages'] 270 | for scope_name in scope_names: 271 | scope_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope + '/' + scope_name) 272 | scope_hists.extend([tf.summary.histogram(scope_name + '_kernel', scope_vars[0]), 273 | tf.summary.histogram(scope_name + '_bias', scope_vars[1])]) 274 | 275 | return scope_hists 276 | 277 | def write_weights_log(self, sess, step, writer): 278 | summaries = tf.summary.merge(self.get_var_weights()) 279 | writer.add_summary(sess.run(summaries, {}), step) 280 | 281 | 282 | class ModelParametersCopier: 283 | def __init__(self, estimator1, estimator2): 284 | e1_params = [t for t in tf.trainable_variables() if t.name.startswith(estimator1.scope)] 285 | e1_params = sorted(e1_params, key=lambda v: v.name) 286 | e2_params = [t for t in tf.trainable_variables() if t.name.startswith(estimator2.scope)] 287 | e2_params = sorted(e2_params, key=lambda v: v.name) 288 | 289 | self.update_ops = [] 290 | for e1_v, e2_v in zip(e1_params, e2_params): 
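            # Each iteration builds an assign op that copies one trainable variable of the
            # first estimator into the name-matched variable of the second (in DQN training
            # this is typically the online network being copied into the target network).
            # make(sess) below runs all of these ops in a single session call.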
291 | op = e2_v.assign(e1_v) 292 | self.update_ops.append(op) 293 | 294 | def make(self, sess): 295 | sess.run(self.update_ops) 296 | -------------------------------------------------------------------------------- /code/models/network_c.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Neural Model - Coupled 3 | # 4 | # Dueling Deep Q-Network 5 | # 6 | ###################################################################### 7 | 8 | 9 | ################################### 10 | # Imports 11 | # 12 | 13 | import tensorflow as tf 14 | import numpy as np 15 | from models.replay_buffer_c import State 16 | from models.blocks import * 17 | 18 | 19 | ################################### 20 | # Classes 21 | # 22 | 23 | class RLModel(object): 24 | def __init__(self, known_emb, unknown_emb, char_emb_len, model_conf, scope): 25 | # hyperparameters 26 | self.word_embedding_dim = model_conf.word_embedding_dim 27 | self.char_embedding_dim = model_conf.char_embedding_dim 28 | self.hidden_dim_q = model_conf.hidden_dim_q 29 | self.hidden_dim_x = model_conf.hidden_dim_x 30 | self.props_dim = model_conf.props_dim 31 | self.output_dim = model_conf.output_dim 32 | self.token_length = model_conf.token_length 33 | self.observ_length = model_conf.observ_length 34 | self.learning_rate = model_conf.learning_rate 35 | self.dropout_rate = model_conf.dropout_rate 36 | 37 | self.kernel_initializer = tf.glorot_uniform_initializer() 38 | self.bias_initializer = tf.truncated_normal_initializer(mean=0.011, stddev=0.005) 39 | 40 | self.scope = scope 41 | with tf.variable_scope(scope): 42 | self._build_model(known_emb, unknown_emb, char_emb_len) 43 | 44 | def _build_model(self, known_emb, unknown_emb, char_emb_len): 45 | # placeholders 46 | self.q_w = tf.placeholder(tf.int32, [None, None], name='q_w') 47 | self.q_c = tf.placeholder(tf.int32, [None, None, self.token_length], name='q_c') 48 | self.x_w = tf.placeholder(tf.int32, [None, self.observ_length], name='x_w') 49 | self.x_c = tf.placeholder(tf.int32, [None, self.observ_length, self.token_length], name='x_c') 50 | self.p = tf.placeholder(tf.int32, [None, self.props_dim], name='p') 51 | self.a = tf.placeholder(tf.int32, [None], name='a') 52 | self.y = tf.placeholder(tf.float32, [None], name='y') 53 | self.w = tf.placeholder(tf.float32, [None], name='w') 54 | self.do = tf.placeholder_with_default(1.0, shape=()) 55 | 56 | # get lengths of unpadded sentences 57 | q_seq_length, q_seq_mask = seqlen(self.q_w) 58 | x_seq_lengths, x_seq_mask = seqlen(self.x_w) 59 | 60 | # word embedding lookup and dropout at embedding layer 61 | self.EWk = tf.Variable(known_emb, name='EWk', trainable=False) 62 | self.EWu = tf.Variable(unknown_emb, name='EWu', trainable=True) 63 | self.EW = tf.concat([self.EWk, self.EWu], axis=0) 64 | 65 | qw_emb = tf.nn.embedding_lookup(self.EW, self.q_w) 66 | qw_emb_drop = tf.nn.dropout(qw_emb, self.do) 67 | xw_emb = tf.nn.embedding_lookup(self.EW, self.x_w) 68 | xw_emb_drop = tf.nn.dropout(xw_emb, self.do) 69 | 70 | # character embedding lookup and dropout at embedding layer 71 | # +1 for uncommon characters 72 | self.ECku = tf.get_variable("ECku", (char_emb_len+1, self.char_embedding_dim), tf.float32, initializer=self.kernel_initializer) 73 | self.EC = tf.concat([tf.zeros((1, self.char_embedding_dim), dtype=tf.float32), 74 | self.ECku], axis=0, name='EC') 75 | 76 | qc_emb = tf.nn.embedding_lookup(self.EC, self.q_c) 77 | qc_emb_drop = 
tf.nn.dropout(qc_emb, self.do) 78 | qc_pooled = conv_max_pool(qc_emb_drop, 100, 5, "qc_pooled") 79 | xc_emb = tf.nn.embedding_lookup(self.EC, self.x_c) 80 | xc_emb_drop = tf.nn.dropout(xc_emb, self.do) 81 | xc_pooled = conv_max_pool(xc_emb_drop, 100, 5, "xc_pooled") 82 | 83 | # BiLSTM layer - q 84 | qw_qc_concat = tf.concat([qw_emb_drop, qc_pooled], 2) 85 | q_outputs, q_states = birnn(qw_qc_concat, dim=self.hidden_dim_q, keep_prob=self.do, 86 | seq_len=q_seq_length, name='q_bilstm') 87 | # q_states_concat = tf.concat([q_states[0][1], q_states[1][1]], 1, name='q_states_concat') 88 | q_outputs_concat = tf.concat(q_outputs, 2, name='q_outputs_concat') 89 | self.h2q_outputs = ffnn2l(q_outputs_concat, "h2q_outputs", size=256, activation=tf.nn.relu) 90 | q_weights = tf.layers.dense(inputs=self.h2q_outputs, units=1, activation=None, name='q_weights') 91 | self.q_weights_norm = tf.nn.softmax(tf.squeeze(q_weights, [2]), name="q_weights_norm") 92 | self.q_weighted = tf.reduce_sum(tf.multiply(tf.expand_dims(self.q_weights_norm, -1), q_outputs_concat), axis=1, name="q_weighted") 93 | 94 | # LSTM layer - x 95 | xw_xc_concat = tf.concat([xw_emb_drop, xc_pooled], 2) 96 | x_outputs, x_states = rnn(xw_xc_concat, dim=self.hidden_dim_x, keep_prob=self.do, 97 | seq_lens=x_seq_lengths, name='x_lstm') 98 | self.h2x_outputs = ffnn2l(x_outputs, "h2x_outputs", size=256, activation=tf.nn.relu) 99 | x_weights = tf.layers.dense(inputs=self.h2x_outputs, units=1, activation=None, name='x_weights') 100 | self.x_weights_norm = tf.nn.softmax(tf.squeeze(x_weights, [2]), name="x_weights_norm") 101 | self.x_weighted = tf.reduce_sum(tf.multiply(tf.expand_dims(self.x_weights_norm, -1), x_outputs), axis=1, name="x_weighted") 102 | 103 | # concatenate encoded observations 104 | self.h0 = tf.concat([self.q_weighted, self.x_weighted], 1, name='h0') 105 | self.h1 = tf.layers.dense(inputs=self.h0, units=512, activation=tf.nn.relu, name='h1', 106 | kernel_initializer=self.kernel_initializer, bias_initializer=self.bias_initializer) 107 | 108 | # dueling 109 | self.h2v = tf.layers.dense(inputs=self.h1, units=256, activation=tf.nn.relu, name='h2v', 110 | kernel_initializer=self.kernel_initializer, bias_initializer=self.bias_initializer) 111 | self.h2a = tf.layers.dense(inputs=self.h1, units=256, activation=tf.nn.relu, name='h2a', 112 | kernel_initializer=self.kernel_initializer, bias_initializer=self.bias_initializer) 113 | 114 | # concatenate encoded observations with node props 115 | props = tf.cast(self.p, tf.float32) 116 | self.h2v_props = tf.concat([self.h2v, props], 1, name='h2v_props') 117 | self.h2a_props = tf.concat([self.h2a, props], 1, name='h2a_props') 118 | 119 | self.values = tf.layers.dense(inputs=self.h2v_props, units=1, activation=None, name='values') 120 | self.advantages = tf.layers.dense(inputs=self.h2a_props, units=self.output_dim, activation=None, name='advantages') 121 | 122 | self.preds = self.values + (self.advantages - 123 | tf.reduce_mean(self.advantages, reduction_indices=1, keep_dims=True)) 124 | 125 | # loss calculation 126 | flat_preds = tf.reshape(self.preds, [-1]) 127 | batch_size = tf.shape(self.x_w)[0] 128 | gather_indices = tf.range(batch_size) * self.output_dim + self.a 129 | partitions = tf.reduce_sum(tf.one_hot(gather_indices, tf.shape(flat_preds)[0], dtype='int32'), 0) 130 | self.a_preds = tf.dynamic_partition(flat_preds, partitions, 2)[1] 131 | 132 | self.td_err = self.a_preds - self.y 133 | losses = huber_loss(self.td_err) 134 | self.weighted_loss = tf.reduce_mean(self.w * losses) 135 | 136 | # 
GD with Adam 137 | self.optimizer = tf.train.RMSPropOptimizer(self.learning_rate, decay=0.99, momentum=0.0, epsilon=1e-06) 138 | # self.optimizer = tf.train.AdamOptimizer(self.learning_rate) 139 | scope_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope) 140 | self.train_op, self.grad_norm = minimize_and_clip(self.optimizer, self.weighted_loss, scope_vars) 141 | 142 | # visualization 143 | self.summaries = self.get_summaries() 144 | 145 | def predict(self, sess, states, get_weights=False): 146 | feed = self.get_feed_dict(states) 147 | if get_weights: 148 | return sess.run([self.preds, self.x_weights_norm, self.q_weights_norm], feed) 149 | else: 150 | return sess.run(self.preds, feed) 151 | 152 | def update(self, sess, states, actions, targets, weights, writer, step): 153 | feed = self.get_feed_dict_batch(states, actions, targets, weights) 154 | 155 | if writer is None: 156 | _, grads, loss, td_err = sess.run([self.train_op, self.grad_norm, self.weighted_loss, self.td_err], feed) 157 | else: 158 | _, grads, loss, td_err, summaries = sess.run([self.train_op, self.grad_norm, self.weighted_loss, 159 | self.td_err, self.summaries], feed) 160 | writer.add_summary(summaries, step) 161 | 162 | return grads, loss, td_err 163 | 164 | def get_feed_dict(self, states): 165 | if isinstance(states, State): 166 | feed_dict = {self.q_w: np.vstack([states.q_w]), 167 | self.q_c: states.q_c, 168 | self.x_w: states.x_w, 169 | self.x_c: states.x_c, 170 | self.p: states.p, 171 | self.do: self.dropout_rate} 172 | else: 173 | qs_w = [state.q_w for state in states] 174 | max_q_len = max([len(q) for q in qs_w]) 175 | qs_w = [get_padded(q, max_q_len, 0) for q in qs_w] 176 | qs_c = [np.reshape(np.concatenate( 177 | [state.q_c[0], np.zeros((max_q_len-len(state.q_c[0]), self.token_length), dtype=np.int32)]), 178 | (1, max_q_len, self.token_length)) 179 | for state in states] 180 | 181 | feed_dict = {self.q_w: np.vstack(qs_w), 182 | self.q_c: np.vstack(qs_c), 183 | self.x_w: np.vstack([state.x_w[0] for state in states]), 184 | self.x_c: np.vstack([state.x_c for state in states]), 185 | self.p: np.vstack([state.p[0] for state in states]), 186 | self.do: self.dropout_rate} 187 | 188 | return feed_dict 189 | 190 | def get_feed_dict_batch(self, states_batch, actions_batch, targets_batch, weights): 191 | feed_dict = self.get_feed_dict(states_batch) 192 | feed_dict.update({self.a: np.array(actions_batch), 193 | self.y: np.array(targets_batch), 194 | self.w: weights}) 195 | 196 | return feed_dict 197 | 198 | def get_num_model_params(self): 199 | scope_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=self.scope) 200 | total_params = 0 201 | for variable in scope_vars: 202 | shape = variable.get_shape() 203 | variable_params = 1 204 | for dim in shape: 205 | variable_params *= dim.value 206 | total_params += variable_params 207 | 208 | return total_params 209 | 210 | def get_summaries(self): 211 | scope_hists = self.get_var_weights() 212 | summaries = tf.summary.merge(scope_hists + [ 213 | tf.summary.histogram("h1_act", self.h1), 214 | tf.summary.histogram("h2v_act", self.h2v), 215 | tf.summary.histogram("h2a_act", self.h2a), 216 | tf.summary.histogram("values_act", self.values), 217 | tf.summary.histogram("advantages_act", self.advantages), 218 | tf.summary.histogram('preds', self.preds), 219 | tf.summary.scalar("avg.loss", self.weighted_loss) 220 | ]) 221 | 222 | return summaries 223 | 224 | def get_var_weights(self): 225 | scope_hists = [] 226 | scope_names = ['h1', 'h2v', 'h2a', 
'values', 'advantages'] 227 | for scope_name in scope_names: 228 | scope_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, self.scope + '/' + scope_name) 229 | scope_hists.extend([tf.summary.histogram(scope_name + '_kernel', scope_vars[0]), 230 | tf.summary.histogram(scope_name + '_bias', scope_vars[1])]) 231 | 232 | return scope_hists 233 | 234 | def write_weights_log(self, sess, step, writer): 235 | summaries = tf.summary.merge(self.get_var_weights()) 236 | writer.add_summary(sess.run(summaries, {}), step) 237 | 238 | 239 | class ModelParametersCopier: 240 | def __init__(self, estimator1, estimator2): 241 | e1_params = [t for t in tf.trainable_variables() if t.name.startswith(estimator1.scope)] 242 | e1_params = sorted(e1_params, key=lambda v: v.name) 243 | e2_params = [t for t in tf.trainable_variables() if t.name.startswith(estimator2.scope)] 244 | e2_params = sorted(e2_params, key=lambda v: v.name) 245 | 246 | self.update_ops = [] 247 | for e1_v, e2_v in zip(e1_params, e2_params): 248 | op = e2_v.assign(e1_v) 249 | self.update_ops.append(op) 250 | 251 | def make(self, sess): 252 | sess.run(self.update_ops) 253 | -------------------------------------------------------------------------------- /code/models/replay_buffer.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Replay Buffer 3 | # 4 | # The OpenAI implementation of Segment Tree based replay buffer. 5 | # Taken from: https://github.com/openai/baselines 6 | # 7 | ###################################################################### 8 | 9 | 10 | ################################### 11 | # Imports 12 | # 13 | 14 | import numpy as np 15 | import random 16 | from collections import namedtuple 17 | 18 | from models.segment_tree import SumSegmentTree, MinSegmentTree 19 | 20 | 21 | ################################### 22 | # Globals 23 | # 24 | 25 | State = namedtuple("State", [ 26 | "q_w", "q_c", "x_w", "x_c", "p" 27 | ]) 28 | 29 | StateExt = namedtuple("StateExt", [ 30 | "q_w", "q_c", "x_w", "x_c", "p", "a_w", "a_c", "a_p" 31 | ]) 32 | 33 | Transition = namedtuple("Transition", [ 34 | "state", "action", "reward", "next_state", "done" 35 | ]) 36 | 37 | 38 | ################################### 39 | # Classes 40 | # 41 | 42 | class ReplayBuffer(object): 43 | def __init__(self, size): 44 | """Create Prioritized Replay buffer. 45 | 46 | Parameters 47 | ---------- 48 | size: int 49 | Max number of transitions to store in the buffer. When the buffer 50 | overflows the old memories are dropped. 51 | """ 52 | self._storage = [] 53 | self._maxsize = size 54 | self._next_idx = 0 55 | 56 | def __len__(self): 57 | return len(self._storage) 58 | 59 | def add(self, transition): 60 | if self._next_idx >= len(self._storage): 61 | self._storage.append(transition) 62 | else: 63 | self._storage[self._next_idx] = transition 64 | self._next_idx = (self._next_idx + 1) % self._maxsize 65 | 66 | def _encode_sample(self, idxes): 67 | obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] 68 | for i in idxes: 69 | data = self._storage[i] 70 | obs_t, action, reward, obs_tp1, done = data 71 | obses_t.append(obs_t) 72 | actions.append(action) 73 | rewards.append(reward) 74 | obses_tp1.append(obs_tp1) 75 | dones.append(done) 76 | return obses_t, actions, rewards, obses_tp1, dones 77 | 78 | def sample(self, batch_size): 79 | """Sample a batch of experiences. 
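
        Sampling here is uniform with replacement: each index is drawn
        independently with random.randint over the transitions currently
        stored in the buffer.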
80 | 81 | Parameters 82 | ---------- 83 | batch_size: int 84 | How many transitions to sample. 85 | 86 | Returns 87 | ------- 88 | obs_batch: np.array 89 | batch of observations 90 | act_batch: np.array 91 | batch of actions executed given obs_batch 92 | rew_batch: np.array 93 | rewards received as results of executing act_batch 94 | next_obs_batch: np.array 95 | next set of observations seen after executing act_batch 96 | done_mask: np.array 97 | done_mask[i] = 1 if executing act_batch[i] resulted in 98 | the end of an episode and 0 otherwise. 99 | """ 100 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 101 | return self._encode_sample(idxes) 102 | 103 | 104 | class PrioritizedReplayBuffer(ReplayBuffer): 105 | def __init__(self, size, alpha): 106 | """Create Prioritized Replay buffer. 107 | 108 | Parameters 109 | ---------- 110 | size: int 111 | Max number of transitions to store in the buffer. When the buffer 112 | overflows the old memories are dropped. 113 | alpha: float 114 | how much prioritization is used 115 | (0 - no prioritization, 1 - full prioritization) 116 | 117 | See Also 118 | -------- 119 | ReplayBuffer.__init__ 120 | """ 121 | super(PrioritizedReplayBuffer, self).__init__(size) 122 | assert alpha > 0 123 | self._alpha = alpha 124 | 125 | it_capacity = 1 126 | while it_capacity < size: 127 | it_capacity *= 2 128 | 129 | self._it_sum = SumSegmentTree(it_capacity) 130 | self._it_min = MinSegmentTree(it_capacity) 131 | self._max_priority = 1.0 132 | 133 | def add(self, *args, **kwargs): 134 | """See ReplayBuffer.store_effect""" 135 | idx = self._next_idx 136 | super().add(*args, **kwargs) 137 | self._it_sum[idx] = self._max_priority ** self._alpha 138 | self._it_min[idx] = self._max_priority ** self._alpha 139 | 140 | def _sample_proportional(self, batch_size): 141 | res = [] 142 | for _ in range(batch_size): 143 | # TODO(szymon): should we ensure no repeats? 144 | mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1) 145 | idx = self._it_sum.find_prefixsum_idx(mass) 146 | res.append(idx) 147 | return res 148 | 149 | def sample(self, batch_size, beta): 150 | """Sample a batch of experiences. 151 | 152 | compared to ReplayBuffer.sample 153 | it also returns importance weights and idxes 154 | of sampled experiences. 155 | 156 | 157 | Parameters 158 | ---------- 159 | batch_size: int 160 | How many transitions to sample. 161 | beta: float 162 | To what degree to use importance weights 163 | (0 - no corrections, 1 - full correction) 164 | 165 | Returns 166 | ------- 167 | obs_batch: np.array 168 | batch of observations 169 | act_batch: np.array 170 | batch of actions executed given obs_batch 171 | rew_batch: np.array 172 | rewards received as results of executing act_batch 173 | next_obs_batch: np.array 174 | next set of observations seen after executing act_batch 175 | done_mask: np.array 176 | done_mask[i] = 1 if executing act_batch[i] resulted in 177 | the end of an episode and 0 otherwise. 
178 | weights: np.array 179 | Array of shape (batch_size,) and dtype np.float32 180 | denoting importance weight of each sampled transition 181 | idxes: np.array 182 | Array of shape (batch_size,) and dtype np.int32 183 | idexes in buffer of sampled experiences 184 | """ 185 | assert beta > 0 186 | 187 | idxes = self._sample_proportional(batch_size) 188 | 189 | weights = [] 190 | p_min = self._it_min.min() / self._it_sum.sum() 191 | max_weight = (p_min * len(self._storage)) ** (-beta) 192 | 193 | for idx in idxes: 194 | p_sample = self._it_sum[idx] / self._it_sum.sum() 195 | weight = (p_sample * len(self._storage)) ** (-beta) 196 | weights.append(weight / max_weight) 197 | weights = np.array(weights) 198 | encoded_sample = self._encode_sample(idxes) 199 | return tuple(list(encoded_sample) + [weights, idxes]) 200 | 201 | def update_priorities(self, idxes, priorities): 202 | """Update priorities of sampled transitions. 203 | 204 | sets priority of transition at index idxes[i] in buffer 205 | to priorities[i]. 206 | 207 | Parameters 208 | ---------- 209 | idxes: [int] 210 | List of idxes of sampled transitions 211 | priorities: [float] 212 | List of updated priorities corresponding to 213 | transitions at the sampled idxes denoted by 214 | variable `idxes`. 215 | """ 216 | assert len(idxes) == len(priorities) 217 | for idx, priority in zip(idxes, priorities): 218 | assert priority > 0 219 | assert 0 <= idx < len(self._storage) 220 | self._it_sum[idx] = priority ** self._alpha 221 | self._it_min[idx] = priority ** self._alpha 222 | 223 | self._max_priority = max(self._max_priority, priority) 224 | -------------------------------------------------------------------------------- /code/models/replay_buffer_c.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Replay Buffer - Coupled 3 | # 4 | # The OpenAI implementation of Segment Tree based replay buffer. 5 | # Taken from: https://github.com/openai/baselines 6 | # 7 | ###################################################################### 8 | 9 | 10 | ################################### 11 | # Imports 12 | # 13 | 14 | import numpy as np 15 | import random 16 | from collections import namedtuple 17 | 18 | from models.segment_tree import SumSegmentTree, MinSegmentTree 19 | 20 | 21 | ################################### 22 | # Globals 23 | # 24 | 25 | State = namedtuple("State", [ 26 | "q_w", "q_c", "x_w", "x_c", "p" 27 | ]) 28 | 29 | Transition = namedtuple("Transition", [ 30 | "state", "action", "reward", "next_state", "done" 31 | ]) 32 | 33 | 34 | ################################### 35 | # Classes 36 | # 37 | 38 | class ReplayBuffer(object): 39 | def __init__(self, size): 40 | """Create Prioritized Replay buffer. 41 | 42 | Parameters 43 | ---------- 44 | size: int 45 | Max number of transitions to store in the buffer. When the buffer 46 | overflows the old memories are dropped. 
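
        Note: this is the base buffer with uniform sampling; the prioritized
        variant is implemented by PrioritizedReplayBuffer below.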
47 | """ 48 | self._storage = [] 49 | self._maxsize = size 50 | self._next_idx = 0 51 | 52 | def __len__(self): 53 | return len(self._storage) 54 | 55 | def add(self, transition): 56 | if self._next_idx >= len(self._storage): 57 | self._storage.append(transition) 58 | else: 59 | self._storage[self._next_idx] = transition 60 | self._next_idx = (self._next_idx + 1) % self._maxsize 61 | 62 | def _encode_sample(self, idxes): 63 | obses_t, actions, rewards, obses_tp1, dones = [], [], [], [], [] 64 | for i in idxes: 65 | data = self._storage[i] 66 | obs_t, action, reward, obs_tp1, done = data 67 | obses_t.append(obs_t) 68 | actions.append(action) 69 | rewards.append(reward) 70 | obses_tp1.append(obs_tp1) 71 | dones.append(done) 72 | return obses_t, actions, rewards, obses_tp1, dones 73 | 74 | def sample(self, batch_size): 75 | """Sample a batch of experiences. 76 | 77 | Parameters 78 | ---------- 79 | batch_size: int 80 | How many transitions to sample. 81 | 82 | Returns 83 | ------- 84 | obs_batch: np.array 85 | batch of observations 86 | act_batch: np.array 87 | batch of actions executed given obs_batch 88 | rew_batch: np.array 89 | rewards received as results of executing act_batch 90 | next_obs_batch: np.array 91 | next set of observations seen after executing act_batch 92 | done_mask: np.array 93 | done_mask[i] = 1 if executing act_batch[i] resulted in 94 | the end of an episode and 0 otherwise. 95 | """ 96 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 97 | return self._encode_sample(idxes) 98 | 99 | 100 | class PrioritizedReplayBuffer(ReplayBuffer): 101 | def __init__(self, size, alpha): 102 | """Create Prioritized Replay buffer. 103 | 104 | Parameters 105 | ---------- 106 | size: int 107 | Max number of transitions to store in the buffer. When the buffer 108 | overflows the old memories are dropped. 109 | alpha: float 110 | how much prioritization is used 111 | (0 - no prioritization, 1 - full prioritization) 112 | 113 | See Also 114 | -------- 115 | ReplayBuffer.__init__ 116 | """ 117 | super(PrioritizedReplayBuffer, self).__init__(size) 118 | assert alpha > 0 119 | self._alpha = alpha 120 | 121 | it_capacity = 1 122 | while it_capacity < size: 123 | it_capacity *= 2 124 | 125 | self._it_sum = SumSegmentTree(it_capacity) 126 | self._it_min = MinSegmentTree(it_capacity) 127 | self._max_priority = 1.0 128 | 129 | def add(self, *args, **kwargs): 130 | """See ReplayBuffer.store_effect""" 131 | idx = self._next_idx 132 | super().add(*args, **kwargs) 133 | self._it_sum[idx] = self._max_priority ** self._alpha 134 | self._it_min[idx] = self._max_priority ** self._alpha 135 | 136 | def _sample_proportional(self, batch_size): 137 | res = [] 138 | for _ in range(batch_size): 139 | # TODO(szymon): should we ensure no repeats? 140 | mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1) 141 | idx = self._it_sum.find_prefixsum_idx(mass) 142 | res.append(idx) 143 | return res 144 | 145 | def sample(self, batch_size, beta): 146 | """Sample a batch of experiences. 147 | 148 | compared to ReplayBuffer.sample 149 | it also returns importance weights and idxes 150 | of sampled experiences. 151 | 152 | 153 | Parameters 154 | ---------- 155 | batch_size: int 156 | How many transitions to sample. 
157 | beta: float 158 | To what degree to use importance weights 159 | (0 - no corrections, 1 - full correction) 160 | 161 | Returns 162 | ------- 163 | obs_batch: np.array 164 | batch of observations 165 | act_batch: np.array 166 | batch of actions executed given obs_batch 167 | rew_batch: np.array 168 | rewards received as results of executing act_batch 169 | next_obs_batch: np.array 170 | next set of observations seen after executing act_batch 171 | done_mask: np.array 172 | done_mask[i] = 1 if executing act_batch[i] resulted in 173 | the end of an episode and 0 otherwise. 174 | weights: np.array 175 | Array of shape (batch_size,) and dtype np.float32 176 | denoting importance weight of each sampled transition 177 | idxes: np.array 178 | Array of shape (batch_size,) and dtype np.int32 179 | idexes in buffer of sampled experiences 180 | """ 181 | assert beta > 0 182 | 183 | idxes = self._sample_proportional(batch_size) 184 | 185 | weights = [] 186 | p_min = self._it_min.min() / self._it_sum.sum() 187 | max_weight = (p_min * len(self._storage)) ** (-beta) 188 | 189 | for idx in idxes: 190 | p_sample = self._it_sum[idx] / self._it_sum.sum() 191 | weight = (p_sample * len(self._storage)) ** (-beta) 192 | weights.append(weight / max_weight) 193 | weights = np.array(weights) 194 | encoded_sample = self._encode_sample(idxes) 195 | return tuple(list(encoded_sample) + [weights, idxes]) 196 | 197 | def update_priorities(self, idxes, priorities): 198 | """Update priorities of sampled transitions. 199 | 200 | sets priority of transition at index idxes[i] in buffer 201 | to priorities[i]. 202 | 203 | Parameters 204 | ---------- 205 | idxes: [int] 206 | List of idxes of sampled transitions 207 | priorities: [float] 208 | List of updated priorities corresponding to 209 | transitions at the sampled idxes denoted by 210 | variable `idxes`. 211 | """ 212 | assert len(idxes) == len(priorities) 213 | for idx, priority in zip(idxes, priorities): 214 | assert priority > 0 215 | assert 0 <= idx < len(self._storage) 216 | self._it_sum[idx] = priority ** self._alpha 217 | self._it_min[idx] = priority ** self._alpha 218 | 219 | self._max_priority = max(self._max_priority, priority) 220 | -------------------------------------------------------------------------------- /code/models/segment_tree.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Segment Tree 3 | # 4 | # Segment Tree data structure implementation by OpenAI. 5 | # Taken from: https://github.com/openai/baselines 6 | # 7 | ###################################################################### 8 | 9 | 10 | ################################### 11 | # Imports 12 | # 13 | 14 | import operator 15 | 16 | 17 | ################################### 18 | # Classes 19 | # 20 | 21 | class SegmentTree(object): 22 | def __init__(self, capacity, operation, neutral_element): 23 | """Build a Segment Tree data structure. 24 | 25 | https://en.wikipedia.org/wiki/Segment_tree 26 | 27 | Can be used as regular array, but with two 28 | important differences: 29 | 30 | a) setting item's value is slightly slower. 31 | It is O(lg capacity) instead of O(1). 32 | b) user has access to an efficient `reduce` 33 | operation which reduces `operation` over 34 | a contiguous subsequence of items in the 35 | array. 36 | 37 | Paramters 38 | --------- 39 | capacity: int 40 | Total size of the array - must be a power of two. 
41 | operation: lambda obj, obj -> obj 42 | and operation for combining elements (eg. sum, max) 43 | must for a mathematical group together with the set of 44 | possible values for array elements. 45 | neutral_element: obj 46 | neutral element for the operation above. eg. float('-inf') 47 | for max and 0 for sum. 48 | """ 49 | assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2." 50 | self._capacity = capacity 51 | self._value = [neutral_element for _ in range(2 * capacity)] 52 | self._operation = operation 53 | 54 | def _reduce_helper(self, start, end, node, node_start, node_end): 55 | if start == node_start and end == node_end: 56 | return self._value[node] 57 | mid = (node_start + node_end) // 2 58 | if end <= mid: 59 | return self._reduce_helper(start, end, 2 * node, node_start, mid) 60 | else: 61 | if mid + 1 <= start: 62 | return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end) 63 | else: 64 | return self._operation( 65 | self._reduce_helper(start, mid, 2 * node, node_start, mid), 66 | self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end) 67 | ) 68 | 69 | def reduce(self, start=0, end=None): 70 | """Returns result of applying `self.operation` 71 | to a contiguous subsequence of the array. 72 | 73 | self.operation(arr[start], operation(arr[start+1], operation(... arr[end]))) 74 | 75 | Parameters 76 | ---------- 77 | start: int 78 | beginning of the subsequence 79 | end: int 80 | end of the subsequences 81 | 82 | Returns 83 | ------- 84 | reduced: obj 85 | result of reducing self.operation over the specified range of array elements. 86 | """ 87 | if end is None: 88 | end = self._capacity 89 | if end < 0: 90 | end += self._capacity 91 | end -= 1 92 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1) 93 | 94 | def __setitem__(self, idx, val): 95 | # index of the leaf 96 | idx += self._capacity 97 | self._value[idx] = val 98 | idx //= 2 99 | while idx >= 1: 100 | self._value[idx] = self._operation( 101 | self._value[2 * idx], 102 | self._value[2 * idx + 1] 103 | ) 104 | idx //= 2 105 | 106 | def __getitem__(self, idx): 107 | assert 0 <= idx < self._capacity 108 | return self._value[self._capacity + idx] 109 | 110 | 111 | class SumSegmentTree(SegmentTree): 112 | def __init__(self, capacity): 113 | super(SumSegmentTree, self).__init__( 114 | capacity=capacity, 115 | operation=operator.add, 116 | neutral_element=0.0 117 | ) 118 | 119 | def sum(self, start=0, end=None): 120 | """Returns arr[start] + ... + arr[end]""" 121 | return super(SumSegmentTree, self).reduce(start, end) 122 | 123 | def find_prefixsum_idx(self, prefixsum): 124 | """Find the highest index `i` in the array such that 125 | sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum 126 | 127 | if array values are probabilities, this function 128 | allows to sample indexes according to the discrete 129 | probability efficiently. 
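
        Illustrative example: for leaf values [1, 2, 3, 4] (capacity 4),
        find_prefixsum_idx(2.5) returns 1, since arr[0] = 1 <= 2.5 while
        arr[0] + arr[1] = 3 > 2.5.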
130 | 131 | Parameters 132 | ---------- 133 | perfixsum: float 134 | upperbound on the sum of array prefix 135 | 136 | Returns 137 | ------- 138 | idx: int 139 | highest index satisfying the prefixsum constraint 140 | """ 141 | assert 0 <= prefixsum <= self.sum() + 1e-5 142 | idx = 1 143 | while idx < self._capacity: # while non-leaf 144 | if self._value[2 * idx] > prefixsum: 145 | idx = 2 * idx 146 | else: 147 | prefixsum -= self._value[2 * idx] 148 | idx = 2 * idx + 1 149 | return idx - self._capacity 150 | 151 | 152 | class MinSegmentTree(SegmentTree): 153 | def __init__(self, capacity): 154 | super(MinSegmentTree, self).__init__( 155 | capacity=capacity, 156 | operation=min, 157 | neutral_element=float('inf') 158 | ) 159 | 160 | def min(self, start=0, end=None): 161 | """Returns min(arr[start], ..., arr[end])""" 162 | 163 | return super(MinSegmentTree, self).reduce(start, end) 164 | -------------------------------------------------------------------------------- /code/run_model.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Run Model 3 | # 4 | # Main routine for execution of model training and evaluation 5 | # 6 | ###################################################################### 7 | 8 | 9 | ################################### 10 | # Imports 11 | # 12 | 13 | from models.estimator import * 14 | from utils.rpc_client import EvidenceRpcClient 15 | from utils.data_processing import WORD_EMBEDDING_DIM, DATA_CONFIG_NOP 16 | from utils.analytics import MetricsAgg 17 | from time import time 18 | import pickle 19 | 20 | 21 | ################################### 22 | # Globals 23 | # 24 | 25 | LOG_CONFIG, MODEL_CONFIG, TRAIN_CONFIG, SEED = None, None, None, None 26 | DATA_CONFIG_ = None 27 | 28 | 29 | ################################### 30 | # Helper functions 31 | # 32 | 33 | def load_data(): 34 | # load word embeddings 35 | print("Loading word embeddings...") 36 | with open(DATA_CONFIG_.glove_embeddings_path, 'rb') as fd: 37 | word_embeddings = pickle.load(fd) 38 | 39 | print("Loading vocabulary...") 40 | with open(DATA_CONFIG_.glove_vocab_path, 'rb') as fd: 41 | vocabulary = pickle.load(fd) 42 | encoder = Encoder(vocabulary) 43 | 44 | # get evidence RPC client 45 | print("Connect to evidence RPC server...") 46 | evidence_dict = EvidenceRpcClient() 47 | 48 | # load preprocessed datasets 49 | print("Loading datasets...") 50 | with open(DATA_CONFIG_.train_dataset.replace('.json', '.exp.pkl'), 'rb') as fd: 51 | train_samples = pickle.load(fd) 52 | with open(DATA_CONFIG_.dev_dataset.replace('.json', '.exp.pkl'), 'rb') as fd: 53 | dev_samples = pickle.load(fd) 54 | print("{} train samples, {} dev samples".format(len(train_samples), len(dev_samples))) 55 | 56 | return word_embeddings, encoder, evidence_dict, train_samples, dev_samples 57 | 58 | 59 | def set_configuration(args): 60 | global LOG_CONFIG, MODEL_CONFIG, TRAIN_CONFIG, SEED, DATA_CONFIG_ 61 | LOG_CONFIG, MODEL_CONFIG, TRAIN_CONFIG, SEED = get_configuration(WORD_EMBEDDING_DIM, PROPS_DIM, args) 62 | DATA_CONFIG_ = DATA_CONFIG_NOP 63 | 64 | 65 | def create_model(word_embeddings, char_emb_len, model_id, train, step=0): 66 | model = ModelEstimator(word_embeddings, char_emb_len, model_id, SEED, MODEL_CONFIG, TRAIN_CONFIG, LOG_CONFIG) 67 | model.step = step 68 | log_trn_perf_navigator = LOG_CONFIG.log_trn_perf_navigator.format(model_id, step) if train else None 69 | log_trn_stats_path = 
LOG_CONFIG.log_trn_stats_navigator.format(model_id, step) if train else None 70 | 71 | return model, log_trn_perf_navigator, log_trn_stats_path 72 | 73 | 74 | def evaluate_model(model, step, dev_samples, evidence_dict, encoder): 75 | log_dev_perf_dbg_path = LOG_CONFIG.dbg_log_perf_navigator.format(model.model_id, step, "dev") 76 | 77 | # zero exploration for evaluation 78 | model.step = model.tc.epsilon_a_decay_steps 79 | model.epsilon_a_schedule[-1] = 0.0 80 | 81 | print("Evaluating {} (step {}) on dev set...".format(model.model_id, step)) 82 | output_path = LOG_CONFIG.navigator_output_path.format(model.model_id, step, "dev") 83 | model.evaluate(dev_samples, evidence_dict, encoder, log_dev_perf_dbg_path, output_path) 84 | 85 | 86 | def test_model(model, step, evidence_dict, encoder): 87 | with open(DATA_CONFIG_.test_dataset.replace('.json', '.exp.pkl'), 'rb') as fd: 88 | samples = pickle.load(fd) 89 | 90 | # zero exploration for evaluation 91 | model.step = model.tc.epsilon_a_decay_steps 92 | model.epsilon_a_schedule[-1] = 0.0 93 | 94 | print("Test model {} (step {}) on test set...".format(model.model_id, step)) 95 | output_path = LOG_CONFIG.navigator_output_path.format(model.model_id, step, "test") 96 | model.predict_paths_test(samples, evidence_dict, encoder, output_path) 97 | 98 | 99 | ################################### 100 | # Main 101 | # 102 | 103 | def main(): 104 | args = parse_args() 105 | if not valid_args(args): 106 | exit() 107 | set_configuration(args) 108 | word_embeddings, encoder, evidence_dict, train_samples, dev_samples = load_data() 109 | char_emb_len = encoder.get_char_emb_len() 110 | 111 | if args.train: 112 | print("\nCreating a model...") 113 | timestamp = str(time()) 114 | model, log_trn_perf_path, log_trn_stats_path = create_model(word_embeddings, char_emb_len, timestamp, train=True) 115 | store_execution_config(model, LOG_CONFIG) 116 | print_config(model, LOG_CONFIG) 117 | print("Model ID: {}\ntotal params: {}".format(timestamp, model.q_estimator.get_num_model_params())) 118 | print("\nTraining...") 119 | model.start_sess(args.num_threads, args.tfevents) 120 | 121 | with open(log_trn_perf_path, 'w', LOG_FILE_BUFF_SIZE) as flogperf, \ 122 | open(log_trn_stats_path, 'w', LOG_FILE_BUFF_SIZE) as flogstats: 123 | write_flog("dataset\tstep\t" + "\t".join(key for key in MetricsAgg._fields) + "\n", flogperf) 124 | write_flog("step\tloss\tgrads\tpath_len\tavg.reward\tmin.reward\tmax.reward\n", flogstats) 125 | model.train(train_samples, dev_samples, evidence_dict, encoder, flogstats, flogperf) 126 | print("\nFinished training.") 127 | 128 | evaluate_model(model, model.step, dev_samples, evidence_dict, encoder) 129 | model.close_sess() 130 | 131 | if args.resume: 132 | model, log_trn_perf_path, log_trn_stats_path = create_model(word_embeddings, char_emb_len, args.model_id, train=True, step=args.model_step) 133 | store_execution_config(model, LOG_CONFIG) 134 | print_config(model, LOG_CONFIG) 135 | print("Loading model to resume training: {} step {}".format(args.model_id, args.model_step)) 136 | model.start_sess(args.num_threads) 137 | model.load(args.model_step, LOG_CONFIG) 138 | 139 | with open(log_trn_perf_path, 'w', LOG_FILE_BUFF_SIZE) as flogperf, \ 140 | open(log_trn_stats_path, 'w', LOG_FILE_BUFF_SIZE) as flogstats: 141 | write_flog("dataset\tstep\t" + "\t".join(key for key in MetricsAgg._fields) + "\n", flogperf) 142 | write_flog("step\tloss\tgrads\tpath_len\tavg.reward\tmin.reward\tmax.reward\n", flogstats) 143 | model.train(train_samples, dev_samples, 
evidence_dict, encoder, flogstats, flogperf) 144 | print("\nFinished training.") 145 | 146 | evaluate_model(model, model.step, dev_samples, evidence_dict, encoder) 147 | model.close_sess() 148 | 149 | if args.evaluate: 150 | model, log_trn_perf_path, log_trn_stats_path = create_model(word_embeddings, char_emb_len, args.model_id, train=False) 151 | 152 | step = "best" if args.model_best else args.model_step 153 | print("Loading model for evaluation: {} step {}".format(args.model_id, step)) 154 | model.start_sess(args.num_threads) 155 | if args.model_best: 156 | model.load_best(LOG_CONFIG) 157 | else: 158 | model.load(args.model_step, LOG_CONFIG) 159 | 160 | evaluate_model(model, step, dev_samples, evidence_dict, encoder) 161 | 162 | model.close_sess() 163 | 164 | if args.test: 165 | model, log_trn_perf_path, log_trn_stats_path = create_model(word_embeddings, char_emb_len, args.model_id, train=False) 166 | step = "best" if args.model_best else args.model_step 167 | print("Loading model for test: {} step {}".format(args.model_id, step)) 168 | model.start_sess(args.num_threads) 169 | if args.model_best: 170 | model.load_best(LOG_CONFIG) 171 | else: 172 | model.load(args.model_step, LOG_CONFIG) 173 | 174 | test_model(model, step, evidence_dict, encoder) 175 | 176 | 177 | if __name__ == '__main__': 178 | main() 179 | -------------------------------------------------------------------------------- /code/run_model_c.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Run Model - Coupled 3 | # 4 | # Main routine for execution of model training and evaluation 5 | # 6 | ###################################################################### 7 | 8 | 9 | ################################### 10 | # Imports 11 | # 12 | 13 | from models.estimator_c import * 14 | from utils.rpc_client import EvidenceRpcClient 15 | from utils.data_processing import WORD_EMBEDDING_DIM, DATA_CONFIG_NOP 16 | from time import time 17 | import pickle 18 | 19 | 20 | ################################### 21 | # Globals 22 | # 23 | 24 | LOG_CONFIG, MODEL_CONFIG, TRAIN_CONFIG, SEED = None, None, None, None 25 | DATA_CONFIG_ = None 26 | 27 | 28 | ################################### 29 | # Helper functions 30 | # 31 | 32 | def load_data(use_rpc): 33 | # load word embeddings 34 | print("Loading word embeddings...") 35 | with open(DATA_CONFIG_.glove_embeddings_path, 'rb') as fd: 36 | word_embeddings = pickle.load(fd) 37 | 38 | print("Loading vocabulary...") 39 | with open(DATA_CONFIG_.glove_vocab_path, 'rb') as fd: 40 | vocabulary = pickle.load(fd) 41 | encoder = Encoder(vocabulary) 42 | 43 | if use_rpc: 44 | # get evidence RPC client 45 | print("Connect to evidence RPC server...") 46 | evidence_dict = EvidenceRpcClient(coupled=True) 47 | else: 48 | # load evidence dict 49 | print("Loading evidence dict...", end='') 50 | with open(DATA_CONFIG_.evidence_dict_path, 'rb') as fd: 51 | evidence_dict = pickle.load(fd) 52 | print("{} evidences".format(len(evidence_dict))) 53 | 54 | # load preprocessed datasets 55 | print("Loading datasets...") 56 | with open(DATA_CONFIG_.train_dataset.replace('.json', '.exp.pkl'), 'rb') as fd: 57 | train_samples = pickle.load(fd) 58 | with open(DATA_CONFIG_.dev_dataset.replace('.json', '.exp.pkl'), 'rb') as fd: 59 | dev_samples = pickle.load(fd) 60 | print("{} train samples, {} dev samples".format(len(train_samples), len(dev_samples))) 61 | 62 | return word_embeddings, encoder, evidence_dict, 
train_samples, dev_samples 63 | 64 | 65 | def set_configuration(args): 66 | global LOG_CONFIG, MODEL_CONFIG, TRAIN_CONFIG, SEED, DATA_CONFIG_ 67 | LOG_CONFIG, MODEL_CONFIG, TRAIN_CONFIG, SEED = get_configuration(WORD_EMBEDDING_DIM, PROPS_DIM, args) 68 | DATA_CONFIG_ = DATA_CONFIG_NOP 69 | 70 | 71 | def create_model(word_embeddings, char_emb_len, model_id, train, step=0): 72 | model = ModelEstimator(word_embeddings, char_emb_len, model_id, SEED, MODEL_CONFIG, TRAIN_CONFIG, LOG_CONFIG) 73 | model.step = step 74 | log_trn_perf_navigator = LOG_CONFIG.log_trn_perf_navigator.format(model_id, step) if train else None 75 | log_trn_stats_path = LOG_CONFIG.log_trn_stats_navigator.format(model_id, step) if train else None 76 | 77 | return model, log_trn_perf_navigator, log_trn_stats_path 78 | 79 | 80 | def evaluate_model(model, step, dev_samples, evidence_dict, encoder): 81 | log_dev_perf_dbg_path = LOG_CONFIG.dbg_log_perf_navigator.format(model.model_id, step, "dev") 82 | 83 | # zero exploration for evaluation 84 | model.step = model.tc.epsilon_a_decay_steps 85 | model.epsilon_a_schedule[-1] = 0.0 86 | 87 | print("Evaluating {} (step {}) on dev set...".format(model.model_id, step)) 88 | output_path = LOG_CONFIG.navigator_output_path.format(model.model_id, step, "dev") 89 | model.evaluate(dev_samples, evidence_dict, encoder, log_dev_perf_dbg_path, output_path) 90 | 91 | 92 | def test_model(model, step, evidence_dict, encoder): 93 | with open(DATA_CONFIG_.test_dataset.replace('.json', '.exp.pkl'), 'rb') as fd: 94 | samples = pickle.load(fd) 95 | 96 | # zero exploration for evaluation 97 | model.step = model.tc.epsilon_a_decay_steps 98 | model.epsilon_a_schedule[-1] = 0.0 99 | 100 | print("Test model {} (step {}) on test set...".format(model.model_id, step)) 101 | output_path = LOG_CONFIG.navigator_output_path.format(model.model_id, step, "test") 102 | model.predict_paths_test(samples, evidence_dict, encoder, output_path) 103 | 104 | 105 | ################################### 106 | # Main 107 | # 108 | 109 | def main(): 110 | args = parse_args() 111 | if not valid_args(args): 112 | exit() 113 | set_configuration(args) 114 | word_embeddings, encoder, evidence_dict, train_samples, dev_samples = load_data(args.use_rpc) 115 | char_emb_len = encoder.get_char_emb_len() 116 | 117 | if args.train: 118 | print("\nCreating a model...") 119 | timestamp = str(time()) 120 | model, log_trn_perf_path, log_trn_stats_path = create_model(word_embeddings, char_emb_len, timestamp, train=True) 121 | store_execution_config(model, LOG_CONFIG) 122 | print_config(model, LOG_CONFIG) 123 | print("Model ID: {}\ntotal params: {}".format(timestamp, model.q_estimator.get_num_model_params())) 124 | print("\nTraining...") 125 | model.start_sess(args.num_threads, args.tfevents) 126 | 127 | with open(log_trn_perf_path, 'w', LOG_FILE_BUFF_SIZE) as flogperf, \ 128 | open(log_trn_stats_path, 'w', LOG_FILE_BUFF_SIZE) as flogstats: 129 | write_flog("dataset\tstep\t" + "\t".join(key for key in MetricsAgg._fields) + "\n", flogperf) 130 | write_flog("step\tloss\tgrads\tpath_len\tavg.reward\tmin.reward\tmax.reward\n", flogstats) 131 | model.train(train_samples, dev_samples, evidence_dict, encoder, flogstats, flogperf) 132 | print("\nFinished training.") 133 | 134 | evaluate_model(model, model.step, dev_samples, evidence_dict, encoder) 135 | model.close_sess() 136 | 137 | if args.resume: 138 | model, log_trn_perf_path, log_trn_stats_path = create_model(word_embeddings, char_emb_len, args.model_id, train=True, step=args.model_step) 139 | 
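        # Resume path: the estimator is created with step=args.model_step, the checkpoint
        # weights are restored below with model.load(args.model_step, LOG_CONFIG), and
        # training continues from there; the per-run log files are reopened in 'w' mode,
        # so logging for the resumed execution starts fresh.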
store_execution_config(model, LOG_CONFIG) 140 | print_config(model, LOG_CONFIG) 141 | print("Loading model to resume training: {} step {}".format(args.model_id, args.model_step)) 142 | model.start_sess(args.num_threads) 143 | model.load(args.model_step, LOG_CONFIG) 144 | 145 | with open(log_trn_perf_path, 'w', LOG_FILE_BUFF_SIZE) as flogperf, \ 146 | open(log_trn_stats_path, 'w', LOG_FILE_BUFF_SIZE) as flogstats: 147 | write_flog("dataset\tstep\t" + "\t".join(key for key in MetricsAgg._fields) + "\n", flogperf) 148 | write_flog("step\tloss\tgrads\tpath_len\tavg.reward\tmin.reward\tmax.reward\n", flogstats) 149 | model.train(train_samples, dev_samples, evidence_dict, encoder, flogstats, flogperf) 150 | print("\nFinished training.") 151 | 152 | evaluate_model(model, model.step, dev_samples, evidence_dict, encoder) 153 | model.close_sess() 154 | 155 | if args.evaluate: 156 | model, log_trn_perf_path, log_trn_stats_path = create_model(word_embeddings, char_emb_len, args.model_id, train=False) 157 | 158 | step = "best" if args.model_best else args.model_step 159 | print("Loading model for evaluation: {} step {}".format(args.model_id, step)) 160 | model.start_sess(args.num_threads) 161 | if args.model_best: 162 | model.load_best(LOG_CONFIG) 163 | else: 164 | model.load(args.model_step, LOG_CONFIG) 165 | 166 | evaluate_model(model, step, dev_samples, evidence_dict, encoder) 167 | 168 | model.close_sess() 169 | 170 | if args.test: 171 | model, log_trn_perf_path, log_trn_stats_path = create_model(word_embeddings, char_emb_len, args.model_id, train=False) 172 | step = "best" if args.model_best else args.model_step 173 | print("Loading model for test: {} step {}".format(args.model_id, step)) 174 | model.start_sess(args.num_threads) 175 | if args.model_best: 176 | model.load_best(LOG_CONFIG) 177 | else: 178 | model.load(args.model_step, LOG_CONFIG) 179 | 180 | test_model(model, step, evidence_dict, encoder) 181 | 182 | 183 | if __name__ == '__main__': 184 | main() 185 | -------------------------------------------------------------------------------- /code/run_rpc_server.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Evidence RPC Server 3 | # 4 | # Running on top of RabbitMQ 5 | # 6 | ###################################################################### 7 | 8 | ################################### 9 | # Imports 10 | # 11 | 12 | import pika 13 | import pickle 14 | from tqdm import tqdm 15 | import simplejson as json 16 | from anytree.exporter import JsonExporter 17 | 18 | 19 | ################################### 20 | # Globals 21 | # 22 | 23 | evidence_dict_path = '../data/evidence_dict_cased.nop.pkl' 24 | predictions_path = '../data/all_qa_para_preds.nop.json' 25 | 26 | server_evidence_dict = None 27 | server_prediction_dict = None 28 | 29 | 30 | ################################### 31 | # Functions 32 | # 33 | 34 | def load_serialize_evidence_dict(): 35 | global server_evidence_dict 36 | print("Loading evidence dict...", end='') 37 | with open(evidence_dict_path, 'rb') as fd: 38 | evidence_dict = pickle.load(fd) 39 | print("{} evidences".format(len(evidence_dict))) 40 | 41 | print("Serializing evidence dict...") 42 | exporter = JsonExporter() 43 | for key in tqdm(evidence_dict): 44 | evidence = evidence_dict[key] 45 | evidence['tree'] = exporter.export(evidence['tree']) 46 | evidence_dict[key] = json.dumps(evidence) 47 | 48 | server_evidence_dict = evidence_dict 49 | 50 | 51 | def 
load_prediction_dict(): 52 | global server_prediction_dict 53 | 54 | print("Loading prediction dict...", end='') 55 | with open(predictions_path, 'r') as fd: 56 | prediction_dict = json.load(fd) 57 | print("{} qid-eidx predictions".format(len(prediction_dict))) 58 | 59 | server_prediction_dict = prediction_dict 60 | 61 | 62 | def get_evidence(evidence_title): 63 | return server_evidence_dict[evidence_title] 64 | 65 | 66 | def get_response(request): 67 | qid, evidence_title = request.split('--', 1) 68 | evidence = get_evidence(evidence_title) 69 | if qid == '': 70 | return evidence 71 | predictions = server_prediction_dict[request] 72 | 73 | return json.dumps([evidence, predictions]) 74 | 75 | 76 | def on_request(ch, method, props, body): 77 | body_str = body.decode("utf-8") 78 | response = get_response(body_str) 79 | 80 | ch.basic_publish(exchange='', 81 | routing_key=props.reply_to, 82 | properties=pika.BasicProperties(correlation_id=props.correlation_id), 83 | body=response) 84 | ch.basic_ack(delivery_tag=method.delivery_tag) 85 | 86 | 87 | ################################### 88 | # Main 89 | # 90 | 91 | def main(): 92 | print("\nInitializing evidence RPC server...") 93 | load_serialize_evidence_dict() 94 | load_prediction_dict() 95 | 96 | connection = pika.BlockingConnection(pika.ConnectionParameters(host='localhost')) 97 | 98 | channel = connection.channel() 99 | channel.queue_declare(queue='rpc_queue') 100 | channel.basic_qos(prefetch_count=1) 101 | channel.basic_consume(on_request, queue='rpc_queue') 102 | 103 | print("\nReady for RPC requests...") 104 | channel.start_consuming() 105 | 106 | 107 | if __name__ == "__main__": 108 | main() 109 | 110 | -------------------------------------------------------------------------------- /code/setup.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Setup 3 | # 4 | # Download and extract TriviaQA-NoP raw and preprocessed data 5 | # 6 | ###################################################################### 7 | 8 | 9 | ################################### 10 | # Imports 11 | # 12 | 13 | import os 14 | import sys 15 | 16 | ################################### 17 | # Globals 18 | # 19 | 20 | DATA_BASE_URL = 'https://www.cs.tau.ac.il/~taunlp/triviaqa-nop/{}' 21 | DATA_BASE_DIR = '../data/{}' 22 | FNAMES = ['triviaqa-nop.gz', 'triviaqa-nop-preprocessed.gz'] 23 | FDESCS = ['TriviaQA-NoP dataset', 'TriviaQA-NoP preprocessed data'] 24 | PREDS = 'all_qa_para_preds.gz' 25 | 26 | 27 | ################################### 28 | # Functions 29 | # 30 | 31 | def download_data(): 32 | for i, fname in enumerate(FNAMES): 33 | if os.path.isfile(DATA_BASE_DIR.format(fname)): 34 | continue 35 | print('Downloading {}'.format(FDESCS[i])) 36 | wget_cmd = 'wget {} -O {}'.format(DATA_BASE_URL.format(fname), DATA_BASE_DIR.format(fname)) 37 | if os.system(wget_cmd) != 0: 38 | print('Failure executing "{}"'.format(wget_cmd)) 39 | sys.exit(1) 40 | 41 | 42 | def extract_data(): 43 | for i, fname in enumerate(FNAMES): 44 | print('Extracting {}'.format(FDESCS[i])) 45 | tar_cmd = 'tar -xzf {} -C {}'.format(DATA_BASE_DIR.format(FNAMES[i]), DATA_BASE_DIR.format('')) 46 | if os.system(tar_cmd) != 0: 47 | print('Failure executing "{}"'.format(tar_cmd)) 48 | sys.exit(1) 49 | 50 | assert os.path.exists(DATA_BASE_DIR.format(PREDS)) 51 | print('Extracting RaSoR predictions') 52 | tar_cmd = 'tar -xzf {} -C {}'.format(DATA_BASE_DIR.format(PREDS), DATA_BASE_DIR.format('')) 53 | 
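    # Same pattern as the loop above: extract the RaSoR predictions archive (PREDS) into
    # the data directory, and abort setup if the tar command fails.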
if os.system(tar_cmd) != 0: 54 | print('Failure executing "{}"'.format(tar_cmd)) 55 | sys.exit(1) 56 | 57 | 58 | def delete_gz_files(): 59 | for i, fname in enumerate(FNAMES + [PREDS]): 60 | if not os.path.isfile(DATA_BASE_DIR.format(fname)): 61 | continue 62 | print('Deleting {}'.format(DATA_BASE_DIR.format(fname))) 63 | del_cmd = 'rm {}'.format(DATA_BASE_DIR.format(fname)) 64 | if os.system(del_cmd) != 0: 65 | print('Failure executing "{}"'.format(del_cmd)) 66 | continue 67 | 68 | 69 | ################################### 70 | # Main 71 | # 72 | 73 | def main(): 74 | if not os.path.exists(DATA_BASE_DIR.format('')): 75 | os.makedirs(DATA_BASE_DIR.format('')) 76 | download_data() 77 | extract_data() 78 | delete_gz_files() 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /code/utils/analytics.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Analytics utils 3 | # 4 | # Helper functions to analyze navigation performance 5 | # 6 | ###################################################################### 7 | 8 | 9 | ################################### 10 | # Imports 11 | # 12 | 13 | from collections import namedtuple 14 | import numpy as np 15 | import pandas as pd 16 | from utils.tree_navigation import is_sentence, is_ans_in_subtree, get_node_idx_path, \ 17 | get_closest_idx_diff, find_closest_ans_node, navigation_dist_idx_path, get_evidence_length 18 | 19 | 20 | ################################### 21 | # Globals 22 | # 23 | 24 | MetricsAgg = namedtuple('MetricsAgg', [ 25 | 'avg_acc', 26 | 'avg_reward', 27 | 'max_reward', 28 | 'min_reward', 29 | 'avg_q_diff', 30 | 'avg_q_diff_win', 31 | 'avg_path_len', 32 | 'avg_illegal' 33 | ]) 34 | 35 | 36 | ################################### 37 | # Functions 38 | # 39 | 40 | def get_node_metrics(node, ans_line_idx): 41 | closest_ans_node, section_min_dist = find_closest_ans_node(node, ans_line_idx) 42 | if is_sentence(node): 43 | node_idx_path = get_node_idx_path(node.parent) 44 | else: 45 | node_idx_path = get_node_idx_path(node) 46 | ans_idx_path = get_node_idx_path(closest_ans_node) 47 | line_min_dist = navigation_dist_idx_path(node_idx_path, ans_idx_path) 48 | 49 | closest_line_idx, closest_line_diff = get_closest_idx_diff(node, [closest_ans_node.line]) 50 | last_line = node.line - int(node.is_root) 51 | fsubtree = is_ans_in_subtree(node, ans_line_idx) 52 | 53 | node_metrics = [node_idx_path, ans_idx_path, section_min_dist, line_min_dist, 54 | closest_line_diff, int(fsubtree), node.height, last_line] 55 | return node_metrics 56 | 57 | 58 | def get_sample_metrics(node, ans_line_idx, evidence, t, path_num_illegal_moves, q_values, reward): 59 | last_node_metrics = get_node_metrics(node, ans_line_idx) 60 | evidence_metrics = [get_evidence_length(evidence), len(ans_line_idx), ans_line_idx[0]] 61 | navigation_metrics = [t, path_num_illegal_moves, max(q_values) - reward] 62 | 63 | return last_node_metrics + evidence_metrics + navigation_metrics 64 | 65 | 66 | def get_metrics(metrics, reward_sums): 67 | df = pd.DataFrame(metrics, columns=('node_idx_path', 'ans_idx_path', 'section_min_dist', 'line_min_dist', 68 | 'closest_line_diff', 'fsubtree', 'last_height', 'last_line', 69 | 'evidence_len', 'num_line_idx', 'first_line_idx', 'num_steps', 'num_illegal', 'q_diff')) 70 | # accuracy stats 71 | acc_all = len(df[df.closest_line_diff == 0]) / len(df) 72 | 73 | # 
reward and q-value stats 74 | avg_reward = np.mean(reward_sums) 75 | max_reward = np.max(reward_sums) 76 | min_reward = np.min(reward_sums) 77 | avg_q_diff = df.q_diff.mean() 78 | avg_q_diff_win = df[df.closest_line_diff == 0].q_diff.mean() 79 | 80 | # navigation general stats 81 | df['illegal_frac'] = df.num_illegal / df.num_steps 82 | avg_illegal = df.illegal_frac.mean() 83 | avg_path_len = df.num_steps.mean() 84 | 85 | metrics_agg = MetricsAgg( 86 | avg_acc=acc_all, 87 | avg_reward=avg_reward, 88 | max_reward=max_reward, 89 | min_reward=min_reward, 90 | avg_q_diff=avg_q_diff, 91 | avg_q_diff_win=avg_q_diff_win, 92 | avg_path_len=avg_path_len, 93 | avg_illegal=avg_illegal) 94 | 95 | return metrics_agg, df 96 | 97 | 98 | def metrics_agg_to_str(metrics_agg, fields=True): 99 | if fields: 100 | return "\t".join(["{} {:.4f}".format(key, getattr(metrics_agg, key)) for key in metrics_agg._fields[:-2]]) 101 | else: 102 | return "\t".join(["{:.4f}".format(getattr(metrics_agg, key)) for key in metrics_agg._fields[:-2]]) 103 | 104 | -------------------------------------------------------------------------------- /code/utils/data_processing.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Data processing 3 | # 4 | # Data structures and constants related to the preprocessed data 5 | # 6 | ###################################################################### 7 | 8 | 9 | ################################### 10 | # Imports 11 | # 12 | 13 | from collections import namedtuple 14 | 15 | 16 | ################################### 17 | # Globals 18 | # 19 | 20 | MAX_PROCESSES = 6 # maximum processes in pool at given time 21 | 22 | WORD_VOCAB_SIZE = 100000 23 | CHAR_THR = 49 24 | NUM_UNK = 256 25 | PADDING = "<PAD>"  # padding token, mapped to PADDING_IDX 26 | PADDING_IDX = 0 27 | UNK = ["<UNK{}>".format(i) for i in range(NUM_UNK)]  # hashed unknown-word placeholder tokens 28 | 29 | WORD_EMBEDDING_DIM = 300 30 | 31 | MAX_ANS_OCC = 500 32 | 33 | WordEmbeddings = namedtuple('WordEmbeddings', [ 34 | 'known', # known words, including PADDING 35 | 'unknown' # unknown words 36 | ]) 37 | 38 | Vocabulary = namedtuple('Vocabulary', [ 39 | 'char_indices', 40 | 'word_indices', 41 | 'index_words' 42 | ]) 43 | 44 | DataConfig = namedtuple('DataConfig', [ 45 | 'glove_embeddings_path', 46 | 'train_dataset', 47 | 'dev_dataset', 48 | 'test_dataset', 49 | 'evidence_dir', 50 | 'glove_vocab_path', 51 | 'evidence_dict_path' 52 | ]) 53 | 54 | DATA_CONFIG_NOP = DataConfig(glove_embeddings_path="../data/glove.840B.300d.nop.pkl", 55 | train_dataset="../data/qa/wikipedia-train.nop.json", 56 | dev_dataset="../data/qa/wikipedia-dev.nop.json", 57 | test_dataset="../data/qa/wikipedia-test-without-answers.nop.json", 58 | evidence_dir="../data/evidence/wikipedia", 59 | glove_vocab_path="../data/vocabulary_cased_glove.nop.pkl", 60 | evidence_dict_path="../data/evidence_dict_cased.nop.pkl" 61 | ) 62 | 63 | 64 | ################################### 65 | # Helper functions 66 | # 67 | 68 | def hash_token(token): 69 | return UNK[hash(token) % NUM_UNK] 70 | -------------------------------------------------------------------------------- /code/utils/io_utils.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # IO Utils 3 | # 4 | # Handles training configuration and logging 5 | # 6 | ###################################################################### 7 | 8 | 9 | ################################### 10 | # Imports 11
| # 12 | 13 | from collections import namedtuple 14 | import os 15 | import argparse 16 | import numpy as np 17 | import pickle 18 | import json 19 | import random 20 | 21 | from utils.tree_navigation import ACTIONS, get_node_dist, get_evidence_title 22 | 23 | ################################### 24 | # Globals 25 | # 26 | 27 | LogConfig = namedtuple('LogConfig', [ 28 | 'log_trn_perf_navigator', 29 | 'log_trn_stats_navigator', 30 | 'dbg_log_perf_navigator', 31 | 'navigator_output_path', 32 | 'conf_path', 33 | 'model_temp', 34 | 'log_dir' 35 | ]) 36 | 37 | ModelConfig = namedtuple('ModelConfig', [ 38 | 'word_embedding_dim', 39 | 'char_embedding_dim', 40 | 'hidden_dim_q', 41 | 'hidden_dim_x', 42 | 'hidden_dim_a', 43 | 'props_dim', 44 | 'ans_props_dim', 45 | 'output_dim', 46 | 'dropout_rate', 47 | 'learning_rate', 48 | 'seq_length', 49 | 'observ_length', 50 | 'token_length' 51 | ]) 52 | 53 | TrainConfig = namedtuple('TrainConfig', [ 54 | 'batch_size', 55 | 'max_steps', 56 | 'replay_memory_init_size', 57 | 'replay_memory_size', 58 | 'per_alpha', 59 | 'per_beta_start', 60 | 'per_beta_end', 61 | 'per_beta_growth_steps', 62 | 'per_eps', 63 | 'update_estimator_freq', 64 | 'check_freq', 65 | 'max_episode_steps', 66 | 'epsilon_a_start', 67 | 'epsilon_a_end', 68 | 'epsilon_a_decay_steps', 69 | 'epsilon_s_start', 70 | 'epsilon_s_end', 71 | 'epsilon_s_decay_steps', 72 | 'gamma', 73 | 'scores', 74 | 'policy_type', 75 | 'train_protocol', 76 | 'combined_random_samples', 77 | 'ans_radius', 78 | 'ans_dist_prob' 79 | ]) 80 | 81 | RewardScores = namedtuple('RewardScores', [ 82 | 'r_delta', 83 | 'r_win', 84 | 'r_lose', 85 | 'r_illegal' 86 | ]) 87 | 88 | CONF_PATH = '../logs/{}_rl.conf' 89 | MODEL_TEMP = '../models/{}' 90 | EVIDENCE_TEMP = '../data/evidence/wikipedia/{}' 91 | VDEV_JSON = '../data/qa/verified-wikipedia-dev.json' 92 | LOG_FILE_BUFF_SIZE = 1 93 | TEST_SEED = 1618033988 94 | 95 | 96 | ################################### 97 | # Functions 98 | # 99 | 100 | def valid_args(args): 101 | if args.resume or args.evaluate or args.test: 102 | if args.model_id is None or (args.model_step is None and not args.model_best): 103 | print("Both model id and step must be specified for resuming / evaluating a model") 104 | return False 105 | else: 106 | if args.model_best: 107 | model_step_path = os.path.join(MODEL_TEMP.format(args.model_id), args.model_id + '-best' + '.meta') 108 | else: 109 | model_step_path = os.path.join(MODEL_TEMP.format(args.model_id), args.model_id + '-' + str(args.model_step) + '.meta') 110 | 111 | if not os.path.isfile(model_step_path): 112 | print("There is no model with the given id: {}, step {}".format(args.model_id, args.model_step)) 113 | return False 114 | 115 | return True 116 | 117 | 118 | def get_configuration(word_embedding_dim, props_dim, args): 119 | if args.resume: 120 | model_conf, train_conf, log_conf, seed = load_execution_config(CONF_PATH.format(args.model_id)) 121 | train_conf = TrainConfig( 122 | batch_size=train_conf.batch_size, 123 | max_steps=train_conf.max_steps + args.max_steps, 124 | replay_memory_init_size=train_conf.replay_memory_init_size, 125 | replay_memory_size=train_conf.replay_memory_size, 126 | per_alpha=train_conf.per_alpha, 127 | per_beta_start=train_conf.per_beta_start, 128 | per_beta_end=train_conf.per_beta_end, 129 | per_beta_growth_steps=train_conf.per_beta_growth_steps, 130 | per_eps=train_conf.per_eps, 131 | update_estimator_freq=train_conf.update_estimator_freq, 132 | check_freq=train_conf.check_freq, 133 | 
max_episode_steps=train_conf.max_episode_steps, 134 | epsilon_a_start=train_conf.epsilon_a_start, 135 | epsilon_a_end=train_conf.epsilon_a_end, 136 | epsilon_a_decay_steps=train_conf.epsilon_a_decay_steps, 137 | epsilon_s_start=train_conf.epsilon_s_start, 138 | epsilon_s_end=train_conf.epsilon_s_end, 139 | epsilon_s_decay_steps=train_conf.epsilon_s_decay_steps, 140 | gamma=train_conf.gamma, 141 | scores=train_conf.scores, 142 | policy_type=train_conf.policy_type, 143 | train_protocol=train_conf.train_protocol, 144 | combined_random_samples=train_conf.combined_random_samples, 145 | ans_radius=train_conf.ans_radius, 146 | ans_dist_prob=train_conf.ans_dist_prob 147 | ) 148 | return log_conf, model_conf, train_conf, seed 149 | 150 | log_conf = LogConfig( 151 | log_trn_perf_navigator='../logs/{}_{}_trn_perf.log', # model_id, init step (non zero when resuming training) 152 | log_trn_stats_navigator='../logs/{}_{}_trn_stats.log', # model_id, init step (non zero when resuming training) 153 | dbg_log_perf_navigator='../logs/{}_{}_{}.dbg.log', # model_id, evaluation step 154 | navigator_output_path='../logs/{}_{}_{}_output.json', # model_id, init step (non zero when resuming training), name 155 | conf_path=CONF_PATH, 156 | model_temp=MODEL_TEMP, 157 | log_dir='../logs/' 158 | ) 159 | 160 | model_conf = ModelConfig( 161 | word_embedding_dim=word_embedding_dim, 162 | char_embedding_dim=20, 163 | hidden_dim_q=300, 164 | hidden_dim_x=300, 165 | hidden_dim_a=300, 166 | props_dim=props_dim, 167 | ans_props_dim=3, 168 | output_dim=7, 169 | dropout_rate=1.0 if args.evaluate or args.test else args.keep_rate, 170 | learning_rate=args.learning_rate, 171 | seq_length=args.max_seq_len, 172 | observ_length=args.max_seq_len * 6, # 6 levels 173 | token_length=args.max_token_len 174 | ) 175 | 176 | rscores = RewardScores( 177 | r_delta=-0.02, 178 | r_win=1.0, 179 | r_lose=-1.0, 180 | r_illegal=-0.02 181 | ) 182 | 183 | per_beta_growth_steps = args.per_beta_growth_steps if args.per_beta_growth_steps is not None else args.max_steps 184 | epsilon_a_decay_steps = args.epsilon_a_decay_steps if args.epsilon_a_decay_steps is not None else args.max_steps 185 | epsilon_s_decay_steps = args.epsilon_s_decay_steps if args.epsilon_s_decay_steps is not None else args.max_steps 186 | 187 | max_episode_steps = 100 if args.evaluate or args.test else 30 188 | train_conf = TrainConfig( 189 | batch_size=args.batch_size, 190 | max_steps=args.max_steps, 191 | replay_memory_init_size=args.rm_init_size, 192 | replay_memory_size=args.rm_size, 193 | per_alpha=0.6, 194 | per_beta_start=0.4, 195 | per_beta_end=1.0, 196 | per_beta_growth_steps=per_beta_growth_steps, 197 | per_eps=1e-6, 198 | update_estimator_freq=args.estimator_freq, 199 | check_freq=args.check_freq, 200 | max_episode_steps=max_episode_steps, 201 | epsilon_a_start=1.0, 202 | epsilon_a_end=0.1, 203 | epsilon_a_decay_steps=epsilon_a_decay_steps, 204 | epsilon_s_start=args.epsilon_s_start, 205 | epsilon_s_end=args.epsilon_s_end, 206 | epsilon_s_decay_steps=epsilon_s_decay_steps, 207 | gamma=0.996, 208 | scores=rscores, 209 | policy_type=args.policy_type, 210 | train_protocol=args.train_protocol, 211 | combined_random_samples=args.combined_random_samples, 212 | ans_radius=args.ans_radius, 213 | ans_dist_prob=args.ans_dist_prob 214 | ) 215 | 216 | seed = args.seed if args.seed is not None else random.randrange(2 ** 32) 217 | if args.evaluate or args.test: 218 | seed = TEST_SEED 219 | 220 | return log_conf, model_conf, train_conf, seed 221 | 222 | 223 | def parse_args(): 224 | parser 
= argparse.ArgumentParser() 225 | parser.add_argument('--train', action='store_true', default=False, help='Train a new model') 226 | parser.add_argument('--resume', action='store_true', default=False, help='Resume training of an existing model') 227 | parser.add_argument('--test', action='store_true', default=False, help='Test a trained model') 228 | parser.add_argument('--evaluate', action='store_true', default=False, help='Evaluate a trained model') 229 | parser.add_argument('--model_id', default=None, help='Model id to resume training / evaluate', type=str) 230 | parser.add_argument('--model_step', default=None, help='Step of the model to resume training from / evaluate', type=int) 231 | parser.add_argument('--model_best', action='store_true', default=False, help='Evaluate the best trained model') 232 | parser.add_argument("--num_threads", type=int, default=-1, help="Number of CPU cores to use (maximum that is available if not set)") 233 | parser.add_argument("--seed", type=int, help="Random seed, default is random") 234 | parser.add_argument("--tfevents", action='store_true', default=False, help="Generate TF events for tensorboard") 235 | parser.add_argument("--rm_init_size", type=int, default=50000, help="Replay memory initial size, default is 50K") 236 | parser.add_argument("--rm_size", type=int, default=300000, help="Replay memory size, default is 300K") 237 | parser.add_argument("--per_beta_growth_steps", type=int, default=4000000, 238 | help="Number of steps to increase beta, default is 4M") 239 | parser.add_argument("--epsilon_a_decay_steps", type=int, default=1000000, 240 | help="Number of steps to decay epsilon_a, default is 1M") 241 | parser.add_argument("--epsilon_s_decay_steps", type=int, default=2000000, 242 | help="Number of steps to decay epsilon_s, default is 2M") 243 | parser.add_argument("--epsilon_s_start", type=float, default=1.0, help="epsilon_s initial value, default is 1.0") 244 | parser.add_argument("--epsilon_s_end", type=float, default=0.5, help="epsilon_s final value, default is 0.5") 245 | parser.add_argument("--batch_size", type=int, default=64, help="Batch size, default is 64") 246 | parser.add_argument("--max_steps", type=int, default=4000000, help="Number of steps for training, default is 4M. 
\ 247 | When resuming training, it specifies how many steps will be added to the original number of steps.") 248 | parser.add_argument("--keep_rate", type=float, default=0.8, help="Keep rate for dropout, default is 0.8") 249 | parser.add_argument("--learning_rate", type=float, default=0.0001, 250 | help="Initial learning rate, default is 0.0001") 251 | parser.add_argument("--max_seq_len", type=int, default=20, help="Maximum number of tokens in node observation") 252 | parser.add_argument("--max_token_len", type=int, default=50, help="Maximum number of characters in token") 253 | parser.add_argument("--estimator_freq", type=int, default=10000, 254 | help="Frequency to update estimator parameters, default is 10K") 255 | parser.add_argument("--check_freq", type=int, default=20000, 256 | help="Frequency to evaluate model performance, default is 20K") 257 | parser.add_argument('--policy_type', dest='policy_type', default='eglp', choices=['egp', 'eglp'], help='Default is eglp', type=str) 258 | parser.add_argument('--train_protocol', dest='train_protocol', default='sequential', 259 | choices=['sequential', 'random_balanced', 'combined', 'combined_ans_radius'], help='Default is sequential', type=str) 260 | parser.add_argument("--combined_random_samples", type=int, default=5, 261 | help="Number of random samples per iteration for combined training protocols, default is 5") 262 | parser.add_argument("--ans_radius", type=int, default=3, 263 | help="Answer radius for training protocol combined_ans_radius, default is 3") 264 | parser.add_argument("--ans_dist_prob", type=float, default=0.5, 265 | help="Random sampling probability for the training protocol combined_ans_radius, default is 0.5") 266 | return parser.parse_args() 267 | 268 | 269 | def write_flog(text, flog): 270 | if flog is not None: 271 | flog.write(text) 272 | 273 | 274 | def print_config(model, log_config): 275 | print("\nexecution config:\n----------------------------") 276 | print("model_id:\t{}\nseed:\t{}\n".format(model.model_id, model.seed)) 277 | for config in [model.mc, model.tc, log_config]: 278 | for key in config._fields: 279 | print("{}:\t{}".format(key, getattr(config, key))) 280 | print("\n") 281 | print("----------------------------\n") 282 | 283 | 284 | def store_execution_config(model, log_config): 285 | confs = [model.mc, model.tc, log_config] 286 | 287 | with open(log_config.conf_path.format(model.model_id) + '.txt', 'w') as fout: 288 | fout.write("model_id:\t{}\nseed:\t{}\n\n".format(model.model_id, model.seed)) 289 | for config in confs: 290 | for key in config._fields: 291 | fout.write("{}:\t{}\n".format(key, getattr(config, key))) 292 | fout.write("\n") 293 | 294 | with open(log_config.conf_path.format(model.model_id), 'wb') as fout: 295 | pickle.dump(confs + [model.seed], fout) 296 | 297 | 298 | def load_execution_config(conf_path): 299 | with open(conf_path, 'rb') as fconf: 300 | confs = pickle.load(fconf) 301 | return confs 302 | 303 | 304 | def write_train_stats(step, reward_sums, path_avg_loss, path_avg_grads, path_lengths, flogstats): 305 | mean_rwrd, min_rwrd, max_rwrd = np.mean(reward_sums), np.min(reward_sums), np.max(reward_sums) 306 | avg_loss, avg_grd, avg_pl = np.mean(path_avg_loss), np.mean(path_avg_grads), np.mean(path_lengths) 307 | write_flog('{}\t{:.6f}\t{:.4e}\t{}\t{:.4f}\t{:.4f}\t{:.4f}\n'.format( 308 | step, avg_loss, avg_grd, avg_pl, mean_rwrd, min_rwrd, max_rwrd), flogstats) 309 | print("step {}\tloss {:.6f}\tgrads {:.4e}\tavg_path_len {:.4f}\tavg.reward {:.4f}\tmin.reward {:.4f}\tmax.reward 
{:.4f}".format( 310 | step, avg_loss, avg_grd, avg_pl, mean_rwrd, min_rwrd, max_rwrd)) 311 | 312 | 313 | def write_predict_paths_header(flog): 314 | write_flog('evidence_title\tevidence_idx\tstep\tquestion_txt\tanswer\tobserv_line\tobserv_height\tobserv_depth\tobserv_level_idx\tobserv' 315 | + '\tobserv_wts\tquestion\tq_wts\tq_values\taction\treward\tans_line_idx\tdescription\n', flog) 316 | 317 | 318 | def write_step_start_msg(flog, info, t, node, observ_tokens, x_weights, q_tokens, q_weights, q_values, action, reward): 319 | if flog is None: 320 | return 321 | 322 | (question_w, question_c, question_txt, answer_txt, eidx, ans_line_idx, evidence, predictions) = info 323 | q_values_str = ["{:.4f}".format(x) for x in q_values] 324 | x_weights_str = ["{:.4f}".format(x_weights[i]) for i in range(len(observ_tokens))] # ignoring padding weights 325 | q_weights_str = ["{:.4f}".format(x) for x in q_weights.tolist()] 326 | evidence_title = "_".join([x[0] for x in evidence["title_tokens"]]) 327 | msg = "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}".format( 328 | evidence_title, eidx, t - 1, question_txt, answer_txt, node.line - node.is_root, node.height, node.depth, get_node_dist(node)[0], 329 | observ_tokens, x_weights_str, q_tokens, q_weights_str, q_values_str, action, reward, ans_line_idx) 330 | write_flog(msg, flog) 331 | 332 | 333 | def write_step_end_msg(flog, node, next_node, action, prediction_text): 334 | if flog is None: 335 | return 336 | 337 | if (node == next_node and action != ACTIONS["ANS"]) or \ 338 | (node.is_root and action == ACTIONS["ANS"]): 339 | write_flog("\tIllegal move\n", flog) 340 | 341 | elif not node.is_root and action == ACTIONS["ANS"]: 342 | assert prediction_text is not None 343 | write_flog("\tRaSoR output: {}\n".format(prediction_text), flog) 344 | 345 | else: 346 | write_flog("\t\n", flog) 347 | 348 | 349 | def write_path_end_msg(flog, node, done, action, line_diff): 350 | if flog is None: 351 | return 352 | 353 | if done: 354 | if action == ACTIONS['STOP']: 355 | if node.is_root: 356 | write_flog("\tStopped at the title\n", flog) 357 | elif line_diff == 0: 358 | write_flog("\tWin!\n", flog) 359 | else: 360 | write_flog("\tStopped at wrong place\n", flog) 361 | else: 362 | write_flog("\tReached maximum steps\n", flog) 363 | 364 | 365 | def create_json_record(sample, line_idx, from_raw=False, line_txt=None): 366 | if from_raw: 367 | context = line_txt 368 | else: 369 | context = sample['OrigEvidenceFilename'] + '_' + str(line_idx) 370 | 371 | record = {"title": get_evidence_title(sample) + '_' + str(line_idx), 372 | "paragraphs": [ 373 | { 374 | "context": context, 375 | "qas": [ 376 | { 377 | "id": sample['QuestionId'] + '--' + sample['OrigEvidenceFilename'] + '-->' + str(line_idx), 378 | "question": sample['Question'] 379 | } 380 | ] 381 | } 382 | ] 383 | } 384 | 385 | return record 386 | 387 | 388 | def get_evidence_line_idx_context(record): 389 | filename, line_idx = record["paragraphs"][0]["context"].rsplit('_', 1) 390 | line_idx = int(line_idx) 391 | 392 | if line_idx < 0: 393 | context = filename.rsplit('.', 1)[0].replace('_', ' ') 394 | 395 | else: 396 | evidence_filename = filename.replace('.txt', '.nop.txt') 397 | raw_evidence_path = EVIDENCE_TEMP.format(evidence_filename) 398 | 399 | with open(raw_evidence_path, 'r', encoding='utf-8') as fd: 400 | lines = fd.readlines() 401 | context = lines[line_idx].strip().strip('\n') 402 | 403 | return context 404 | 405 | 406 | def write_predictions_json(records, output_path, from_raw=False): 
407 | if output_path is None: 408 | return 409 | 410 | if not from_raw: 411 | for record in records: 412 | context = get_evidence_line_idx_context(record) 413 | record["paragraphs"][0]["context"] = context 414 | 415 | with open(output_path, 'w') as fd: 416 | json.dump({"data": records}, fd) 417 | 418 | -------------------------------------------------------------------------------- /code/utils/io_utils_c.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # IO Utils 3 | # 4 | # Handles training configuration and logging 5 | # 6 | ###################################################################### 7 | 8 | 9 | ################################### 10 | # Imports 11 | # 12 | 13 | from collections import namedtuple 14 | import os 15 | import argparse 16 | import numpy as np 17 | import pickle 18 | import json 19 | import random 20 | 21 | from utils.tree_navigation_c import ACTIONS, get_node_dist, get_evidence_title 22 | 23 | ################################### 24 | # Globals 25 | # 26 | 27 | LogConfig = namedtuple('LogConfig', [ 28 | 'log_trn_perf_navigator', 29 | 'log_trn_stats_navigator', 30 | 'dbg_log_perf_navigator', 31 | 'navigator_output_path', 32 | 'conf_path', 33 | 'model_temp', 34 | 'log_dir' 35 | ]) 36 | 37 | ModelConfig = namedtuple('ModelConfig', [ 38 | 'word_embedding_dim', 39 | 'char_embedding_dim', 40 | 'hidden_dim_q', 41 | 'hidden_dim_x', 42 | 'props_dim', 43 | 'output_dim', 44 | 'dropout_rate', 45 | 'learning_rate', 46 | 'seq_length', 47 | 'observ_length', 48 | 'token_length' 49 | ]) 50 | 51 | TrainConfig = namedtuple('TrainConfig', [ 52 | 'batch_size', 53 | 'max_steps', 54 | 'replay_memory_init_size', 55 | 'replay_memory_size', 56 | 'per_alpha', 57 | 'per_beta_start', 58 | 'per_beta_end', 59 | 'per_beta_growth_steps', 60 | 'per_eps', 61 | 'update_estimator_freq', 62 | 'check_freq', 63 | 'max_episode_steps', 64 | 'epsilon_a_start', 65 | 'epsilon_a_end', 66 | 'epsilon_a_decay_steps', 67 | 'epsilon_s_start', 68 | 'epsilon_s_end', 69 | 'epsilon_s_decay_steps', 70 | 'gamma', 71 | 'scores', 72 | 'policy_type', 73 | 'train_protocol', 74 | 'combined_random_samples', 75 | 'ans_radius', 76 | 'ans_dist_prob' 77 | ]) 78 | 79 | RewardScores = namedtuple('RewardScores', [ 80 | 'r_delta', 81 | 'r_win', 82 | 'r_lose', 83 | 'r_illegal' 84 | ]) 85 | 86 | CONF_PATH = '../logs/{}_rl.conf' 87 | MODEL_TEMP = '../models/{}' 88 | EVIDENCE_TEMP = '../data/evidence/wikipedia/{}' 89 | VDEV_JSON = '../data/qa/verified-wikipedia-dev.json' 90 | LOG_FILE_BUFF_SIZE = 1 91 | TEST_SEED = 1618033988 92 | 93 | 94 | ################################### 95 | # Functions 96 | # 97 | 98 | def valid_args(args): 99 | if args.resume or args.evaluate or args.test: 100 | if args.model_id is None or (args.model_step is None and not args.model_best): 101 | print("Both model id and step must be specified for resuming / evaluating a model") 102 | return False 103 | else: 104 | if args.model_best: 105 | model_step_path = os.path.join(MODEL_TEMP.format(args.model_id), args.model_id + '-best' + '.meta') 106 | else: 107 | model_step_path = os.path.join(MODEL_TEMP.format(args.model_id), args.model_id + '-' + str(args.model_step) + '.meta') 108 | 109 | if not os.path.isfile(model_step_path): 110 | print("There is no model with the given id: {}, step {}".format(args.model_id, args.model_step)) 111 | return False 112 | 113 | return True 114 | 115 | 116 | def get_configuration(word_embedding_dim, props_dim, args): 117 | if 
args.resume: 118 | model_conf, train_conf, log_conf, seed = load_execution_config(CONF_PATH.format(args.model_id)) 119 | train_conf = TrainConfig( 120 | batch_size=train_conf.batch_size, 121 | max_steps=train_conf.max_steps + args.max_steps, 122 | replay_memory_init_size=train_conf.replay_memory_init_size, 123 | replay_memory_size=train_conf.replay_memory_size, 124 | per_alpha=train_conf.per_alpha, 125 | per_beta_start=train_conf.per_beta_start, 126 | per_beta_end=train_conf.per_beta_end, 127 | per_beta_growth_steps=train_conf.per_beta_growth_steps, 128 | per_eps=train_conf.per_eps, 129 | update_estimator_freq=train_conf.update_estimator_freq, 130 | check_freq=train_conf.check_freq, 131 | max_episode_steps=train_conf.max_episode_steps, 132 | epsilon_a_start=train_conf.epsilon_a_start, 133 | epsilon_a_end=train_conf.epsilon_a_end, 134 | epsilon_a_decay_steps=train_conf.epsilon_a_decay_steps, 135 | epsilon_s_start=train_conf.epsilon_s_start, 136 | epsilon_s_end=train_conf.epsilon_s_end, 137 | epsilon_s_decay_steps=train_conf.epsilon_s_decay_steps, 138 | gamma=train_conf.gamma, 139 | scores=train_conf.scores, 140 | policy_type=train_conf.policy_type, 141 | train_protocol=train_conf.train_protocol, 142 | combined_random_samples=train_conf.combined_random_samples, 143 | ans_radius=train_conf.ans_radius, 144 | ans_dist_prob=train_conf.ans_dist_prob 145 | ) 146 | return log_conf, model_conf, train_conf, seed 147 | 148 | log_conf = LogConfig( 149 | log_trn_perf_navigator='../logs/{}_{}_trn_perf.log', # model_id, init step (non zero when resuming training) 150 | log_trn_stats_navigator='../logs/{}_{}_trn_stats.log', # model_id, init step (non zero when resuming training) 151 | dbg_log_perf_navigator='../logs/{}_{}_{}.dbg.log', # model_id, evaluation step 152 | navigator_output_path='../logs/{}_{}_{}_output.json', # model_id, init step (non zero when resuming training), name 153 | conf_path=CONF_PATH, 154 | model_temp=MODEL_TEMP, 155 | log_dir='../logs/' 156 | ) 157 | 158 | model_conf = ModelConfig( 159 | word_embedding_dim=word_embedding_dim, 160 | char_embedding_dim=20, 161 | hidden_dim_q=300, 162 | hidden_dim_x=300, 163 | props_dim=props_dim, 164 | output_dim=6, 165 | dropout_rate=1.0 if args.evaluate or args.test else args.keep_rate, 166 | learning_rate=args.learning_rate, 167 | seq_length=args.max_seq_len, 168 | observ_length=args.max_seq_len * 6, # 6 levels 169 | token_length=args.max_token_len 170 | ) 171 | 172 | rscores = RewardScores( 173 | r_delta=-0.02, 174 | r_win=1.0, 175 | r_lose=-1.0, 176 | r_illegal=-0.02 177 | ) 178 | 179 | per_beta_growth_steps = args.per_beta_growth_steps if args.per_beta_growth_steps is not None else args.max_steps 180 | epsilon_a_decay_steps = args.epsilon_a_decay_steps if args.epsilon_a_decay_steps is not None else args.max_steps 181 | epsilon_s_decay_steps = args.epsilon_s_decay_steps if args.epsilon_s_decay_steps is not None else args.max_steps 182 | 183 | max_episode_steps = 100 if args.evaluate or args.test else 30 184 | train_conf = TrainConfig( 185 | batch_size=args.batch_size, 186 | max_steps=args.max_steps, 187 | replay_memory_init_size=args.rm_init_size, 188 | replay_memory_size=args.rm_size, 189 | per_alpha=0.6, 190 | per_beta_start=0.4, 191 | per_beta_end=1.0, 192 | per_beta_growth_steps=per_beta_growth_steps, 193 | per_eps=1e-6, 194 | update_estimator_freq=args.estimator_freq, 195 | check_freq=args.check_freq, 196 | max_episode_steps=max_episode_steps, 197 | epsilon_a_start=1.0, 198 | epsilon_a_end=0.1, 199 | 
epsilon_a_decay_steps=epsilon_a_decay_steps, 200 | epsilon_s_start=args.epsilon_s_start, 201 | epsilon_s_end=args.epsilon_s_end, 202 | epsilon_s_decay_steps=epsilon_s_decay_steps, 203 | gamma=0.996, 204 | scores=rscores, 205 | policy_type=args.policy_type, 206 | train_protocol=args.train_protocol, 207 | combined_random_samples=args.combined_random_samples, 208 | ans_radius=args.ans_radius, 209 | ans_dist_prob=args.ans_dist_prob 210 | ) 211 | 212 | seed = args.seed if args.seed is not None else random.randrange(2 ** 32) 213 | if args.evaluate or args.test: 214 | seed = TEST_SEED 215 | 216 | return log_conf, model_conf, train_conf, seed 217 | 218 | 219 | def parse_args(): 220 | parser = argparse.ArgumentParser() 221 | parser.add_argument('--train', action='store_true', default=False, help='Train a new model') 222 | parser.add_argument('--resume', action='store_true', default=False, help='Resume training of an existing model') 223 | parser.add_argument('--test', action='store_true', default=False, help='Test a trained model') 224 | parser.add_argument('--evaluate', action='store_true', default=False, help='Evaluate a trained model') 225 | parser.add_argument('--model_id', default=None, help='Model id to resume training / evaluate', type=str) 226 | parser.add_argument('--model_step', default=None, help='Step of the model to resume training from / evaluate', type=int) 227 | parser.add_argument('--model_best', action='store_true', default=False, help='Evaluate the best trained model') 228 | parser.add_argument("--num_threads", type=int, default=-1, help="Number of CPU cores to use (maximum that is available if not set)") 229 | parser.add_argument("--seed", type=int, help="Random seed, default is random") 230 | parser.add_argument("--tfevents", action='store_true', default=False, help="Generate TF events for tensorboard") 231 | parser.add_argument('--use_rpc', action='store_true', default=False, help='Use evidence RPC server') 232 | parser.add_argument("--rm_init_size", type=int, default=50000, help="Replay memory initial size, default is 50K") 233 | parser.add_argument("--rm_size", type=int, default=300000, help="Replay memory size, default is 300K") 234 | parser.add_argument("--per_beta_growth_steps", type=int, default=4000000, 235 | help="Number of steps to increase beta, default is 4M") 236 | parser.add_argument("--epsilon_a_decay_steps", type=int, default=1000000, 237 | help="Number of steps to decay epsilon_a, default is 1M") 238 | parser.add_argument("--epsilon_s_decay_steps", type=int, default=2000000, 239 | help="Number of steps to decay epsilon_s, default is 2M") 240 | parser.add_argument("--epsilon_s_start", type=float, default=1.0, help="epsilon_s initial value, default is 1.0") 241 | parser.add_argument("--epsilon_s_end", type=float, default=0.5, help="epsilon_s final value, default is 0.5") 242 | parser.add_argument("--batch_size", type=int, default=64, help="Batch size, default is 64") 243 | parser.add_argument("--max_steps", type=int, default=4000000, help="Number of steps for training, default is 4M. 
\ 244 | When resuming training, it specifies how many steps will be added to the original number of steps.") 245 | parser.add_argument("--keep_rate", type=float, default=0.8, help="Keep rate for dropout, default is 0.8") 246 | parser.add_argument("--learning_rate", type=float, default=0.0001, 247 | help="Initial learning rate, default is 0.0001") 248 | parser.add_argument("--max_seq_len", type=int, default=20, help="Maximum number of tokens in node observation") 249 | parser.add_argument("--max_token_len", type=int, default=50, help="Maximum number of characters in token") 250 | parser.add_argument("--estimator_freq", type=int, default=10000, 251 | help="Frequency to update estimator parameters, default is 10K") 252 | parser.add_argument("--check_freq", type=int, default=20000, 253 | help="Frequency to evaluate model performance, default is 20K") 254 | parser.add_argument('--policy_type', dest='policy_type', default='eglp', choices=['egp', 'eglp'], help='Default is eglp', type=str) 255 | parser.add_argument('--train_protocol', dest='train_protocol', default='sequential', 256 | choices=['sequential', 'random_balanced', 'combined', 'combined_ans_radius'], help='Default is sequential', type=str) 257 | parser.add_argument("--combined_random_samples", type=int, default=5, 258 | help="Number of random samples per iteration for combined training protocols, default is 5") 259 | parser.add_argument("--ans_radius", type=int, default=3, 260 | help="Answer radius for training protocol combined_ans_radius, default is 3") 261 | parser.add_argument("--ans_dist_prob", type=float, default=0.5, 262 | help="Random sampling probability for the training protocol combined_ans_radius, default is 0.5") 263 | return parser.parse_args() 264 | 265 | 266 | def write_flog(text, flog): 267 | if flog is not None: 268 | flog.write(text) 269 | 270 | 271 | def print_config(model, log_config): 272 | print("\nexecution config:\n----------------------------") 273 | print("model_id:\t{}\nseed:\t{}\n".format(model.model_id, model.seed)) 274 | for config in [model.mc, model.tc, log_config]: 275 | for key in config._fields: 276 | print("{}:\t{}".format(key, getattr(config, key))) 277 | print("\n") 278 | print("----------------------------\n") 279 | 280 | 281 | def store_execution_config(model, log_config): 282 | confs = [model.mc, model.tc, log_config] 283 | 284 | with open(log_config.conf_path.format(model.model_id) + '.txt', 'w') as fout: 285 | fout.write("model_id:\t{}\nseed:\t{}\n\n".format(model.model_id, model.seed)) 286 | for config in confs: 287 | for key in config._fields: 288 | fout.write("{}:\t{}\n".format(key, getattr(config, key))) 289 | fout.write("\n") 290 | 291 | with open(log_config.conf_path.format(model.model_id), 'wb') as fout: 292 | pickle.dump(confs + [model.seed], fout) 293 | 294 | 295 | def load_execution_config(conf_path): 296 | with open(conf_path, 'rb') as fconf: 297 | confs = pickle.load(fconf) 298 | return confs 299 | 300 | 301 | def write_train_stats(step, reward_sums, path_avg_loss, path_avg_grads, path_lengths, flogstats): 302 | mean_rwrd, min_rwrd, max_rwrd = np.mean(reward_sums), np.min(reward_sums), np.max(reward_sums) 303 | avg_loss, avg_grd, avg_pl = np.mean(path_avg_loss), np.mean(path_avg_grads), np.mean(path_lengths) 304 | write_flog('{}\t{:.6f}\t{:.4e}\t{}\t{:.4f}\t{:.4f}\t{:.4f}\n'.format( 305 | step, avg_loss, avg_grd, avg_pl, mean_rwrd, min_rwrd, max_rwrd), flogstats) 306 | print("step {}\tloss {:.6f}\tgrads {:.4e}\tavg_path_len {:.4f}\tavg.reward {:.4f}\tmin.reward {:.4f}\tmax.reward 
{:.4f}".format( 307 | step, avg_loss, avg_grd, avg_pl, mean_rwrd, min_rwrd, max_rwrd)) 308 | 309 | 310 | def write_predict_paths_header(flog): 311 | write_flog('evidence_title\tevidence_idx\tstep\tquestion_txt\tanswer\tobserv_line\tobserv_height\tobserv_depth\tobserv_level_idx\tobserv' 312 | + '\tobserv_wts\tquestion\tq_wts\tq_values\taction\treward\tans_line_idx\tdescription\n', flog) 313 | 314 | 315 | def write_step_start_msg(flog, info, t, node, observ_tokens, x_weights, q_tokens, q_weights, q_values, action, reward): 316 | if flog is None: 317 | return 318 | 319 | (question_w, question_c, question_txt, answer_txt, eidx, ans_line_idx, evidence) = info 320 | q_values_str = ["{:.4f}".format(x) for x in q_values] 321 | x_weights_str = ["{:.4f}".format(x_weights[i]) for i in range(len(observ_tokens))] # ignoring padding weights 322 | q_weights_str = ["{:.4f}".format(x) for x in q_weights.tolist()] 323 | evidence_title = "_".join([x[0] for x in evidence["title_tokens"]]) 324 | msg = "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}".format( 325 | evidence_title, eidx, t - 1, question_txt, answer_txt, node.line - node.is_root, node.height, node.depth, get_node_dist(node)[0], 326 | observ_tokens, x_weights_str, q_tokens, q_weights_str, q_values_str, action, reward, ans_line_idx) 327 | write_flog(msg, flog) 328 | 329 | 330 | def write_step_end_msg(flog, node, next_node): 331 | if flog is None: 332 | return 333 | 334 | if node == next_node: 335 | write_flog("\tIllegal move\n", flog) 336 | else: 337 | write_flog("\t\n", flog) 338 | 339 | 340 | def write_path_end_msg(flog, node, done, action, line_diff): 341 | if flog is None: 342 | return 343 | 344 | if done: 345 | if action == ACTIONS['STOP']: 346 | if node.is_root: 347 | write_flog("\tStopped at the title\n", flog) 348 | elif line_diff == 0: 349 | write_flog("\tWin!\n", flog) 350 | else: 351 | write_flog("\tStopped at wrong place\n", flog) 352 | else: 353 | write_flog("\tReached maximum steps\n", flog) 354 | 355 | 356 | def create_json_record(sample, line_idx, from_raw=False, line_txt=None): 357 | if from_raw: 358 | context = line_txt 359 | else: 360 | context = sample['OrigEvidenceFilename'] + '_' + str(line_idx) 361 | 362 | record = {"title": get_evidence_title(sample) + '_' + str(line_idx), 363 | "paragraphs": [ 364 | { 365 | "context": context, 366 | "qas": [ 367 | { 368 | "id": sample['QuestionId'] + '--' + sample['OrigEvidenceFilename'] + '-->' + str(line_idx), 369 | "question": sample['Question'] 370 | } 371 | ] 372 | } 373 | ] 374 | } 375 | 376 | return record 377 | 378 | 379 | def get_evidence_line_idx_context(record): 380 | filename, line_idx = record["paragraphs"][0]["context"].rsplit('_', 1) 381 | line_idx = int(line_idx) 382 | 383 | if line_idx < 0: 384 | context = filename.rsplit('.', 1)[0].replace('_', ' ') 385 | 386 | else: 387 | evidence_filename = filename.replace('.txt', '.nop.txt') 388 | raw_evidence_path = EVIDENCE_TEMP.format(evidence_filename) 389 | 390 | with open(raw_evidence_path, 'r', encoding='utf-8') as fd: 391 | lines = fd.readlines() 392 | context = lines[line_idx].strip().strip('\n') 393 | 394 | return context 395 | 396 | 397 | def write_predictions_json(records, output_path, from_raw=False): 398 | if output_path is None: 399 | return 400 | 401 | if not from_raw: 402 | for record in records: 403 | context = get_evidence_line_idx_context(record) 404 | record["paragraphs"][0]["context"] = context 405 | 406 | with open(output_path, 'w') as fd: 407 | json.dump({"data": records}, fd) 408 | 409 | 
-------------------------------------------------------------------------------- /code/utils/rpc_client.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Evidence RPC client 3 | # 4 | # Running on top of RabbitMQ 5 | # 6 | ###################################################################### 7 | 8 | ################################### 9 | # Imports 10 | # 11 | 12 | import pika 13 | import uuid 14 | import simplejson as json 15 | from anytree.importer import JsonImporter 16 | import time 17 | 18 | 19 | ################################### 20 | # Classes 21 | # 22 | 23 | class EvidenceRpcClient(object): 24 | def __init__(self, coupled=False): 25 | self.tree_importer = JsonImporter() 26 | self.connection = pika.BlockingConnection(pika.ConnectionParameters(host='localhost')) 27 | self.channel = self.connection.channel() 28 | self.coupled = coupled 29 | 30 | result = self.channel.queue_declare(exclusive=True) 31 | self.callback_queue = result.method.queue 32 | 33 | self.channel.basic_consume(self.on_response, no_ack=True, queue=self.callback_queue) 34 | 35 | def deserialize(self, response): 36 | res = json.loads(response) 37 | if self.coupled: 38 | res['tree'] = self.tree_importer.import_(res['tree']) 39 | else: 40 | res[0] = json.loads(res[0]) 41 | res[0]['tree'] = self.tree_importer.import_(res[0]['tree']) 42 | 43 | return res 44 | 45 | def on_response(self, ch, method, props, body): 46 | if self.corr_id == props.correlation_id: 47 | self.response = body 48 | 49 | def call(self, request): 50 | self.response = None 51 | self.corr_id = str(uuid.uuid4()) 52 | try: 53 | self.channel.basic_publish(exchange='', 54 | routing_key='rpc_queue', 55 | properties=pika.BasicProperties( 56 | reply_to=self.callback_queue, 57 | correlation_id=self.corr_id, 58 | ), 59 | body=str(request)) 60 | except: 61 | time.sleep(5) 62 | self.channel.basic_publish(exchange='', 63 | routing_key='rpc_queue', 64 | properties=pika.BasicProperties( 65 | reply_to=self.callback_queue, 66 | correlation_id=self.corr_id, 67 | ), 68 | body=str(request)) 69 | 70 | while self.response is None: 71 | self.connection.process_data_events() 72 | 73 | return self.deserialize(self.response) 74 | 75 | -------------------------------------------------------------------------------- /code/utils/tree_navigation.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Tree Navigation 3 | # 4 | # Helper functions to navigate through evidence trees 5 | # 6 | ###################################################################### 7 | 8 | 9 | ################################### 10 | # Imports 11 | # 12 | 13 | from utils.data_processing import PADDING_IDX 14 | from anytree import Resolver, Walker, PreOrderIter 15 | import numpy as np 16 | import random 17 | 18 | ################################### 19 | # Globals 20 | # 21 | 22 | ACTIONS = {'UPL': 0, 23 | 'UPR': 1, 24 | 'DOWN': 2, 25 | 'LEFT': 3, 26 | 'RIGHT': 4, 27 | 'STOP': 5, 28 | 'ANS': 6} 29 | 30 | MAX_HEIGHT = 8 31 | PROPS_DIM = (MAX_HEIGHT+1) * 2 + 5 32 | 33 | 34 | ################################### 35 | # Helper Functions 36 | # 37 | 38 | def is_paragraph(node): 39 | return node.name[0] == 'p' and node.name[1].isdigit() 40 | 41 | 42 | def is_sentence(node): 43 | return hasattr(node, 'sidx') 44 | 45 | 46 | def is_leftmost(node): 47 | if node.is_root: 48 | return True 49 | 50 
| children = node.parent.children 51 | node_idx = children.index(node) 52 | if node_idx == 0: 53 | return True 54 | 55 | return False 56 | 57 | 58 | def is_rightmost(node): 59 | if node.is_root: 60 | return True 61 | 62 | children = node.parent.children 63 | node_idx = children.index(node) 64 | if node_idx == len(children)-1: 65 | return True 66 | 67 | return False 68 | 69 | 70 | def get_tokens_txt(tokens): 71 | return [token[0] for token in tokens] 72 | 73 | 74 | def to_add_prediction(node): 75 | return not node.is_root and np.random.rand() < 1.0 / len(ACTIONS) 76 | 77 | 78 | ################################### 79 | # Functions 80 | # 81 | 82 | def get_closest_idx_diff(node, line_idx): 83 | diffs = np.array([abs(node.line - idx) + int(node.is_root) for idx in line_idx], dtype=np.int32) 84 | return line_idx[np.argmin(diffs)], np.min(diffs) 85 | 86 | 87 | def find_closest_ans_node(node, ans_line_idx, section_level=True): 88 | unique_ans_line_idx = np.asarray(list(set(ans_line_idx)), dtype=np.int32) 89 | ans_nodes = [_locate_paragraph_node_by_line_idx(node.root, x) for x in unique_ans_line_idx] 90 | if section_level: 91 | navigation_dists = np.asarray([navigation_dist(node, x.parent) for x in ans_nodes], dtype=np.int32) 92 | else: 93 | navigation_dists = np.asarray([navigation_dist(node, x) for x in ans_nodes], dtype=np.int32) 94 | # taking the node at minimum distance with minimum line index 95 | # (unless its at 0 distance, then we take the node itself) 96 | min_dist = min(navigation_dists) 97 | if node.line in unique_ans_line_idx: 98 | closest_node = [x for x in ans_nodes if x.line == node.line][0] 99 | else: 100 | min_dist_line_idx = min(unique_ans_line_idx[navigation_dists == min_dist]) 101 | min_idx = np.where(unique_ans_line_idx == min_dist_line_idx)[0][0] 102 | closest_node = ans_nodes[min_idx] 103 | 104 | return closest_node, min_dist 105 | 106 | 107 | def get_node_line_span(node): 108 | # get start line 109 | span_s = node.line 110 | 111 | # get end line by walking to the bottom most right child 112 | tmp = node 113 | while tmp.height > 0: 114 | tmp = tmp.children[-1] 115 | span_e = tmp.line 116 | 117 | return span_s, span_e 118 | 119 | 120 | def is_ans_in_subtree(node, line_idx): 121 | span_s, span_e = get_node_line_span(node) 122 | for idx in line_idx: 123 | if span_s <= idx <= span_e: 124 | return True 125 | return False 126 | 127 | 128 | def get_evidence_title(sample): 129 | return sample['OrigEvidenceFilename'].rsplit('.', 1)[0] 130 | 131 | 132 | def get_evidence_idx_title(sample, evidence_idx): 133 | return sample['EntityPages'][evidence_idx]['Filename'].rsplit('.', 1)[0] 134 | 135 | 136 | def get_evidence_tree_height(evidence): 137 | return evidence['tree'].height 138 | 139 | 140 | def get_evidence_length(evidence): 141 | # + 1 for evidence title 142 | return len(evidence["tokens"]) + 1 143 | 144 | 145 | def get_evidence(evidence_dict, sample): 146 | evidence_title = get_evidence_title(sample) 147 | key = sample["QuestionId"] + '--' + evidence_title 148 | 149 | return evidence_dict.call(key) 150 | 151 | 152 | def init_step(evidence, encoder, seq_len, observ_len, token_len, step): 153 | root = evidence["tree"] 154 | info = get_node_observ_props(root, step, evidence, encoder, seq_len, observ_len, token_len) 155 | (observ_w, observ_c, props, ans_w, ans_c, ans_p) = info 156 | 157 | return root, observ_w, observ_c, props 158 | 159 | 160 | def init_step_random_balanced(evidence, encoder, predictions, seq_len, observ_len, token_len, max_episode_steps): 161 | if np.random.rand() < 
0.2: 162 | nodes = [x for x in PreOrderIter(evidence['tree']) if not x.dummy] 163 | else: 164 | nodes = [x for x in PreOrderIter(evidence['tree']) if not x.dummy and not hasattr(x, 'sidx')] 165 | node = random.choice(nodes) 166 | 167 | if node.is_root: 168 | step = 0 169 | else: 170 | min_step = node.depth + get_node_dist(node.parent)[0] 171 | max_step = max_episode_steps-1 # minus 1 because counting from 0 172 | if min_step >= max_step: 173 | step = max_step 174 | else: 175 | step = np.random.randint(min_step, max_step) 176 | 177 | preds = predictions if to_add_prediction(node) else None 178 | info = get_node_observ_props(node, step, evidence, encoder, seq_len, observ_len, token_len, preds) 179 | (observ_w, observ_c, props, ans_w, ans_c, ans_p) = info 180 | 181 | return node, observ_w, observ_c, props, ans_w, ans_c, ans_p, step 182 | 183 | 184 | def init_step_random_answer_radius(evidence, encoder, predictions, ans_line_idx, seq_len, observ_len, token_len, max_episode_steps, 185 | ans_radius, ans_dist_prob): 186 | if np.random.rand() > ans_dist_prob: 187 | return init_step_random_balanced(evidence, encoder, predictions, seq_len, observ_len, token_len, max_episode_steps) 188 | 189 | unique_ans_line_idx = list(set(ans_line_idx)) 190 | random_ans_line = random.choice(unique_ans_line_idx) 191 | ans_node = _locate_paragraph_node_by_line_idx(evidence['tree'], random_ans_line) 192 | node = get_random_nearby_node(ans_node, max_actions=ans_radius, evidence=evidence) 193 | 194 | if node.is_root: 195 | step = 0 196 | else: 197 | min_step = node.depth + get_node_dist(node.parent)[0] 198 | max_step = max_episode_steps-1 # minus 1 because counting from 0 199 | if min_step >= max_step: 200 | step = max_step 201 | else: 202 | step = np.random.randint(min_step, max_step) 203 | 204 | preds = predictions if to_add_prediction(node) else None 205 | info = get_node_observ_props(node, step, evidence, encoder, seq_len, observ_len, token_len, preds) 206 | (observ_w, observ_c, props, ans_w, ans_c, ans_p) = info 207 | 208 | return node, observ_w, observ_c, props, ans_w, ans_c, ans_p, step 209 | 210 | 211 | def get_random_nearby_node(start_node, max_actions, evidence): 212 | num_actions = np.random.randint(max_actions) 213 | 214 | node = start_node 215 | for _ in range(num_actions): 216 | legal_actions = get_legal_actions(node) 217 | legal_actions.remove(ACTIONS["STOP"]) 218 | if ACTIONS["ANS"] in legal_actions: 219 | legal_actions.remove(ACTIONS["ANS"]) 220 | if len(legal_actions) == 0: 221 | print("unexpected behavior - only STOP and or ANS are legal actions:\n\nstart_node: {}\t\nnode: {}".format( 222 | start_node, node)) 223 | return node 224 | action = random.choice(legal_actions) 225 | node = _make_step(evidence, node, action) 226 | 227 | return node 228 | 229 | 230 | def get_non_dummy_parent(node, evidence): 231 | if not node.dummy: 232 | return node 233 | 234 | else: 235 | if node.parent is None: 236 | print("Unexpected error, root is a dummy node: {}".format(evidence["title_tokens"])) 237 | return None 238 | 239 | elif not node.parent.dummy: 240 | return node.parent 241 | 242 | else: 243 | if node.parent.parent is None: 244 | print("Unexpected error, root is a dummy node: {}".format(evidence["title_tokens"])) 245 | return None 246 | 247 | elif not node.parent.parent.dummy: 248 | return node.parent.parent 249 | 250 | else: 251 | if node.parent.parent.parent is None: 252 | print("Unexpected error, root is a dummy node: {}".format(evidence["title_tokens"])) 253 | return None 254 | 255 | elif not 
node.parent.parent.parent.dummy: 256 | return node.parent.parent.parent 257 | 258 | else: 259 | print("Unexpected error, 4 dummy nodes: {}".format(evidence["title_tokens"])) 260 | return None 261 | 262 | 263 | def get_non_dummy_child(node, evidence): 264 | if not node.dummy: 265 | return node 266 | 267 | else: 268 | if node.children == (): 269 | print("Unexpected error, dummy node without children: {}".format(evidence["title_tokens"])) 270 | return None 271 | 272 | elif not node.children[0].dummy: 273 | return node.children[0] 274 | 275 | else: 276 | if node.children[0].children == (): 277 | print("Unexpected error, dummy node without children: {}".format(evidence["title_tokens"])) 278 | return None 279 | 280 | elif not node.children[0].children[0].dummy: 281 | return node.children[0].children[0] 282 | 283 | else: 284 | if node.children[0].children[0].children == (): 285 | print("Unexpected error, dummy node without children: {}".format(evidence["title_tokens"])) 286 | return None 287 | 288 | elif not node.children[0].children[0].children[0].dummy: 289 | return node.children[0].children[0].children[0] 290 | 291 | else: 292 | print("Unexpected error, 4 dummy nodes: {}".format(evidence["title_tokens"])) 293 | return None 294 | 295 | 296 | def make_step_upl(evidence, node): 297 | new_node = node 298 | 299 | if not node.is_root: 300 | new_node = make_step_left(evidence, node.parent) 301 | if new_node == node.parent: 302 | new_node = get_non_dummy_parent(new_node, evidence) 303 | 304 | return new_node 305 | 306 | 307 | def make_step_upr(evidence, node): 308 | new_node = node 309 | 310 | if not node.is_root: 311 | new_node = make_step_right(evidence, node.parent) 312 | if new_node == node.parent: 313 | new_node = get_non_dummy_parent(new_node, evidence) 314 | 315 | return new_node 316 | 317 | 318 | def make_step_down(evidence, node): 319 | new_node = node 320 | 321 | if node.children != (): 322 | new_node = get_non_dummy_child(node.children[0], evidence) 323 | if new_node is None: 324 | new_node = node 325 | 326 | return new_node 327 | 328 | 329 | def make_step_left(evidence, node): 330 | new_node = node 331 | 332 | if not node.is_root: 333 | children = node.parent.children 334 | node_idx = children.index(node) 335 | if node_idx > 0: 336 | new_node = get_non_dummy_child(children[node_idx-1], evidence) 337 | if new_node is None: 338 | new_node = node 339 | 340 | return new_node 341 | 342 | 343 | def make_step_right(evidence, node): 344 | new_node = node 345 | 346 | if not node.is_root: 347 | children = node.parent.children 348 | node_idx = children.index(node) 349 | if node_idx < len(children)-1: 350 | new_node = get_non_dummy_child(children[node_idx+1], evidence) 351 | if new_node is None: 352 | new_node = node 353 | 354 | return new_node 355 | 356 | 357 | def make_step(evidence, encoder, predictions, node, action, seq_len, observ_len, token_len, step): 358 | if action == ACTIONS['UPL']: 359 | new_node = make_step_upl(evidence, node) 360 | 361 | elif action == ACTIONS['UPR']: 362 | new_node = make_step_upr(evidence, node) 363 | 364 | elif action == ACTIONS['DOWN']: 365 | new_node = make_step_down(evidence, node) 366 | 367 | elif action == ACTIONS['LEFT']: 368 | new_node = make_step_left(evidence, node) 369 | 370 | elif action == ACTIONS['RIGHT']: 371 | new_node = make_step_right(evidence, node) 372 | 373 | elif action == ACTIONS['ANS']: 374 | # can answer only from evidence content (not from the title) 375 | if not node.is_root: 376 | info = get_node_observ_props(node, step, evidence, encoder, 
seq_len, observ_len, token_len, predictions) 377 | (observ_w, observ_c, props, ans_w, ans_c, ans_p) = info 378 | done = False 379 | return node, observ_w, observ_c, props, ans_w, ans_c, ans_p, done 380 | new_node = node 381 | 382 | # STOP 383 | else: 384 | info = get_node_observ_props(node, step, evidence, encoder, seq_len, observ_len, token_len, predictions) 385 | (observ_w, observ_c, props, ans_w, ans_c, ans_p) = info 386 | done = True 387 | return node, observ_w, observ_c, props, ans_w, ans_c, ans_p, done 388 | 389 | info = get_node_observ_props(new_node, step, evidence, encoder, seq_len, observ_len, token_len) 390 | (observ_w, observ_c, props, ans_w, ans_c, ans_p) = info 391 | done = False 392 | 393 | return new_node, observ_w, observ_c, props, ans_w, ans_c, ans_p, done 394 | 395 | 396 | def _make_step(evidence, node, action): 397 | if action == ACTIONS['UPL']: 398 | new_node = make_step_upl(evidence, node) 399 | 400 | elif action == ACTIONS['UPR']: 401 | new_node = make_step_upr(evidence, node) 402 | 403 | elif action == ACTIONS['DOWN']: 404 | new_node = make_step_down(evidence, node) 405 | 406 | elif action == ACTIONS['LEFT']: 407 | new_node = make_step_left(evidence, node) 408 | 409 | elif action == ACTIONS['RIGHT']: 410 | new_node = make_step_right(evidence, node) 411 | 412 | # STOP / ANS 413 | else: 414 | return node 415 | 416 | return new_node 417 | 418 | 419 | def get_node_observ_props(node, step, evidence, encoder, seq_len, observ_len, token_len, predictions=None): 420 | observ_w, observ_c = get_node_observ(node, evidence, encoder, seq_len, observ_len, token_len) 421 | observ_w = np.reshape(observ_w, (1, observ_len)) 422 | observ_c = np.reshape(observ_c, (1, observ_len, token_len)) 423 | 424 | props = get_node_props(node, step) 425 | props = np.reshape(props, (1, len(props))) 426 | 427 | ans_w, ans_c, ans_p = None, None, None 428 | # getting predictions only for paragraph and sentence levels 429 | if predictions is not None and node.height <= 1: 430 | ans_w, ans_c, ans_p = get_node_prediction(node, evidence, encoder, predictions, token_len) 431 | ans_c = np.reshape(ans_c, (1, len(ans_c), token_len)) 432 | ans_p = np.reshape(ans_p, (1, len(ans_p))) 433 | 434 | return observ_w, observ_c, props, ans_w, ans_c, ans_p 435 | 436 | 437 | def get_node_observ(node, evidence, encoder, seq_len, observ_len, token_len): 438 | try: 439 | # add root observation 440 | if node.is_root: 441 | res_w, res_c = encoder.encode_seq(evidence['title_tokens'][:seq_len]) 442 | res_c = encoder.pad_idx_seq_2dim(res_c, token_len, PADDING_IDX) 443 | 444 | # add paragraph observation 445 | elif is_paragraph(node): 446 | res_w_anc, res_c_anc = add_ancestors_observs(node, evidence, encoder, seq_len, token_len) 447 | res_w, res_c = encoder.encode_seq(evidence['tokens'][node.line][0][:seq_len]) 448 | res_w = np.concatenate([res_w_anc, res_w]) 449 | res_c = np.concatenate([res_c_anc, encoder.pad_idx_seq_2dim(res_c, token_len, PADDING_IDX)]) 450 | 451 | # add sentence observation 452 | elif is_sentence(node): 453 | res_w_anc, res_c_anc = add_ancestors_observs(node.parent, evidence, encoder, seq_len, token_len) 454 | res_w, res_c = encoder.encode_seq(evidence['tokens'][node.line][node.sidx][:seq_len]) 455 | res_w = np.concatenate([res_w_anc, res_w]) 456 | res_c = np.concatenate([res_c_anc, encoder.pad_idx_seq_2dim(res_c, token_len, PADDING_IDX)]) 457 | 458 | # add section observations 459 | else: 460 | # TODO: handle multi-sentence sections 461 | res_w_anc, res_c_anc = add_ancestors_observs(node, evidence, encoder, 
seq_len, token_len) 462 | res_w, res_c = encoder.encode_seq(evidence['tokens'][node.line][0][:seq_len]) 463 | res_w = np.concatenate([res_w_anc, res_w]) 464 | res_c = np.concatenate([res_c_anc, encoder.pad_idx_seq_2dim(res_c, token_len, PADDING_IDX)]) 465 | 466 | observ_w = encoder.pad_idx_seq_1dim(res_w, observ_len, PADDING_IDX) 467 | observ_c = encoder.concate_pad_seq(res_c, observ_len, token_len, PADDING_IDX) 468 | 469 | return observ_w, observ_c 470 | 471 | except Exception as e: 472 | msg = "Error during 'get_node_observ': {}\nevidence: {}\nnode: {}\nnode_line: {}\nevidence_tokens_length: {}".format( 473 | e, evidence["tree"].name, node, node.line, len(evidence["tokens"])) 474 | print(msg) 475 | exit() 476 | 477 | 478 | def add_ancestors_observs(node, evidence, encoder, seq_len, token_len): 479 | # assumes node is a title/headline node, namely not a paragraph nor a sentence 480 | observ_w, observ_c = [], [] 481 | for anc_node in node.ancestors: 482 | if anc_node.dummy or is_paragraph(anc_node): 483 | continue 484 | if anc_node.is_root: 485 | res_w, res_c = encoder.encode_seq(evidence['title_tokens'][:seq_len]) 486 | res_c = encoder.pad_idx_seq_2dim(res_c, token_len, PADDING_IDX) 487 | else: 488 | res_w, res_c = encoder.encode_seq(evidence['tokens'][anc_node.line][0][:seq_len]) 489 | res_c = encoder.pad_idx_seq_2dim(res_c, token_len, PADDING_IDX) 490 | 491 | observ_w.extend(res_w) 492 | observ_c.extend(res_c) 493 | 494 | return observ_w, observ_c 495 | 496 | 497 | def get_node_dist(node): 498 | dist_start, dist_end = 0, 0 499 | 500 | if not node.is_root: 501 | dist_start = node.parent.children.index(node) 502 | dist_end = len(node.parent.children) - 1 - node.parent.children.index(node) 503 | 504 | return dist_start, dist_end 505 | 506 | 507 | def get_node_props(node, step): 508 | dist_start, dist_end = get_node_dist(node) 509 | if not node.is_root: 510 | dist_up_start, dist_up_end = get_node_dist(node.parent) 511 | else: 512 | dist_up_start, dist_up_end = 0, 0 513 | 514 | height = [0] * (MAX_HEIGHT+1) 515 | height[node.height] = 1 516 | depth = [0] * (MAX_HEIGHT+1) 517 | depth[node.depth] = 1 518 | 519 | return [step] + depth + height + [dist_start, dist_end, dist_up_start, dist_up_end] 520 | 521 | 522 | def get_node_prediction(node, evidence, encoder, predictions, token_len): 523 | prediction = predictions[str(node.line)] 524 | ans_w, ans_c = encoder.encode_seq(prediction['tokens']) 525 | ans_c = encoder.pad_idx_seq_2dim(ans_c, token_len, PADDING_IDX) 526 | 527 | num_line_tokens = sum([len(sent) for sent in evidence['tokens'][node.line]]) 528 | ans_p = [prediction['ent'], prediction['logits'], num_line_tokens] 529 | 530 | return ans_w, ans_c, ans_p 531 | 532 | 533 | def get_node_prediction_text(node, predictions): 534 | if node.is_root or str(node.line) not in predictions: 535 | return None 536 | else: 537 | return predictions[str(node.line)]['texts'] 538 | 539 | 540 | def _locate_paragraph_node_by_line_idx(root, idx): 541 | nodes = [x for x in PreOrderIter(root) if not hasattr(x, 'sidx')] 542 | node_lines = [x.line if hasattr(x, 'line') else -1 for x in nodes] 543 | 544 | node_idx = len(node_lines) - 1 - node_lines[::-1].index(idx) 545 | ans_node_idx = node_idx if is_paragraph(nodes[node_idx]) else nodes.index(nodes[node_idx].parent) 546 | 547 | return nodes[ans_node_idx] 548 | 549 | 550 | def get_node_section_idx(node): 551 | if node.is_root: 552 | return -1 553 | 554 | if node.parent.is_root: 555 | return node.parent.children.index(node) 556 | 557 | return 
node.anchestors[0].children.index(node.anchestors[1]) 558 | 559 | 560 | def is_illegal_move(node, action): 561 | if node.is_root: 562 | if action not in [ACTIONS['DOWN'], ACTIONS['STOP']]: 563 | return True 564 | return False 565 | 566 | if is_sentence(node) and action == ACTIONS['DOWN']: 567 | return True 568 | 569 | if is_leftmost(node) and action == ACTIONS['LEFT']: 570 | return True 571 | 572 | if is_rightmost(node) and action == ACTIONS['RIGHT']: 573 | return True 574 | 575 | return False 576 | 577 | 578 | def get_legal_actions(node): 579 | return [ACTIONS[x] for x in ACTIONS if not is_illegal_move(node, ACTIONS[x])] 580 | 581 | 582 | def navigation_dist(node1, node2): 583 | # navigation distance in steps from node1 to node2, including dummy nodes 584 | # it is valid to include them here because we look at the paths and not navigating 585 | # assuming the two node object are from the same tree instance 586 | path1 = np.asarray(get_node_idx_path(node1), dtype=np.int32) 587 | path2 = np.asarray(get_node_idx_path(node2), dtype=np.int32) 588 | return navigation_dist_idx_path(path1, path2) 589 | 590 | 591 | def navigation_dist_idx_path(path1, path2): 592 | # navigation distance in steps from node at path1 to node at path2, including dummy nodes 593 | path1 = np.asarray(path1, dtype=np.int32) 594 | path2 = np.asarray(path2, dtype=np.int32) 595 | distance = 0 596 | n1 = len(path1) 597 | n2 = len(path2) 598 | 599 | if n1 <= n2: 600 | common_diff = np.abs(path1 - path2[:n1]) 601 | else: 602 | common_diff = np.abs(path1[:n2] - path2) 603 | 604 | diff = np.argwhere(common_diff > 0) 605 | if len(diff) == 0: 606 | first_uncommon_idx = len(common_diff) 607 | else: 608 | first_diff_idx = diff[0][0] 609 | distance += common_diff[first_diff_idx] 610 | first_uncommon_idx = first_diff_idx + 1 611 | 612 | uncommon_path1 = path1[first_uncommon_idx:] 613 | uncommon_path2 = path2[first_uncommon_idx:] 614 | 615 | if len(uncommon_path1) > 0: 616 | distance += len(uncommon_path1) 617 | if len(uncommon_path2) > 0: 618 | distance += sum(uncommon_path2) + len(uncommon_path2) 619 | 620 | return distance 621 | 622 | 623 | def get_node_idx_path(node): 624 | if node.is_root: 625 | return [0] 626 | 627 | return [get_node_dist(x)[0] for x in node.anchestors] + [get_node_dist(node)[0]] 628 | -------------------------------------------------------------------------------- /code/utils/tree_navigation_c.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # Tree Navigation - Coupled 3 | # 4 | # Helper functions to navigate through evidence trees 5 | # 6 | ###################################################################### 7 | 8 | 9 | ################################### 10 | # Imports 11 | # 12 | 13 | from utils.data_processing import PADDING_IDX 14 | from utils.rpc_client import EvidenceRpcClient 15 | from anytree import Resolver, Walker, PreOrderIter 16 | import numpy as np 17 | import random 18 | 19 | ################################### 20 | # Globals 21 | # 22 | 23 | ACTIONS = {'UPL': 0, 24 | 'UPR': 1, 25 | 'DOWN': 2, 26 | 'LEFT': 3, 27 | 'RIGHT': 4, 28 | 'STOP': 5} 29 | 30 | MAX_HEIGHT = 8 31 | PROPS_DIM = (MAX_HEIGHT+1) * 2 + 5 32 | 33 | 34 | ################################### 35 | # Helper Functions 36 | # 37 | 38 | def is_paragraph(node): 39 | return node.name[0] == 'p' and node.name[1].isdigit() 40 | 41 | 42 | def is_sentence(node): 43 | return hasattr(node, 'sidx') 44 | 45 | 46 | def 
is_leftmost(node): 47 | if node.is_root: 48 | return True 49 | 50 | children = node.parent.children 51 | node_idx = children.index(node) 52 | if node_idx == 0: 53 | return True 54 | 55 | return False 56 | 57 | 58 | def is_rightmost(node): 59 | if node.is_root: 60 | return True 61 | 62 | children = node.parent.children 63 | node_idx = children.index(node) 64 | if node_idx == len(children)-1: 65 | return True 66 | 67 | return False 68 | 69 | 70 | def get_tokens_txt(tokens): 71 | return [token[0] for token in tokens] 72 | 73 | 74 | ################################### 75 | # Functions 76 | # 77 | 78 | 79 | def find_closest_ans_node(node, ans_line_idx, section_level=True): 80 | unique_ans_line_idx = np.asarray(list(set(ans_line_idx)), dtype=np.int32) 81 | ans_nodes = [_locate_paragraph_node_by_line_idx(node.root, x) for x in unique_ans_line_idx] 82 | if section_level: 83 | navigation_dists = np.asarray([navigation_dist(node, x.parent) for x in ans_nodes], dtype=np.int32) 84 | else: 85 | navigation_dists = np.asarray([navigation_dist(node, x) for x in ans_nodes], dtype=np.int32) 86 | # taking the node at minimum distance with minimum line index 87 | # (unless its at 0 distance, then we take the node itself) 88 | min_dist = min(navigation_dists) 89 | if node.line in unique_ans_line_idx: 90 | closest_node = [x for x in ans_nodes if x.line == node.line][0] 91 | else: 92 | min_dist_line_idx = min(unique_ans_line_idx[navigation_dists == min_dist]) 93 | min_idx = np.where(unique_ans_line_idx == min_dist_line_idx)[0][0] 94 | closest_node = ans_nodes[min_idx] 95 | 96 | return closest_node, min_dist 97 | 98 | 99 | def get_node_line_span(node): 100 | # get start line 101 | span_s = node.line 102 | 103 | # get end line by walking to the bottom most right child 104 | tmp = node 105 | while tmp.height > 0: 106 | tmp = tmp.children[-1] 107 | span_e = tmp.line 108 | 109 | return span_s, span_e 110 | 111 | 112 | def is_ans_in_subtree(node, line_idx): 113 | span_s, span_e = get_node_line_span(node) 114 | for idx in line_idx: 115 | if span_s <= idx <= span_e: 116 | return True 117 | return False 118 | 119 | 120 | def get_evidence_title(sample): 121 | return sample['OrigEvidenceFilename'].rsplit('.', 1)[0] 122 | 123 | 124 | def get_evidence_idx_title(sample, evidence_idx): 125 | return sample['EntityPages'][evidence_idx]['Filename'].rsplit('.', 1)[0] 126 | 127 | 128 | def get_evidence_tree_height(evidence): 129 | return evidence['tree'].height 130 | 131 | 132 | def get_evidence_length(evidence): 133 | # + 1 for evidence title 134 | return len(evidence["tokens"]) + 1 135 | 136 | 137 | def get_evidence(evidence_dict, sample): 138 | evidence_title = get_evidence_title(sample) 139 | 140 | if isinstance(evidence_dict, EvidenceRpcClient): 141 | return evidence_dict.call('--' + evidence_title) 142 | else: 143 | return evidence_dict[evidence_title] 144 | 145 | 146 | def init_step(evidence, encoder, seq_len, observ_len, token_len, step): 147 | root = evidence["tree"] 148 | observ_w, observ_c, props = get_node_observ_props(root, step, evidence, encoder, seq_len, observ_len, token_len) 149 | 150 | return root, observ_w, observ_c, props 151 | 152 | 153 | def init_step_random_balanced(evidence, encoder, seq_len, observ_len, token_len, max_episode_steps): 154 | if np.random.rand() < 0.2: 155 | nodes = [x for x in PreOrderIter(evidence['tree']) if not x.dummy] 156 | else: 157 | nodes = [x for x in PreOrderIter(evidence['tree']) if not x.dummy and not hasattr(x, 'sidx')] 158 | node = random.choice(nodes) 159 | 160 | if 
node.is_root: 161 | step = 0 162 | else: 163 | min_step = node.depth + get_node_dist(node.parent)[0] 164 | max_step = max_episode_steps-1 # minus 1 because counting from 0 165 | if min_step >= max_step: 166 | step = max_step 167 | else: 168 | step = np.random.randint(min_step, max_step) 169 | 170 | observ_w, observ_c, props = get_node_observ_props(node, step, evidence, encoder, seq_len, observ_len, token_len) 171 | 172 | return node, observ_w, observ_c, props, step 173 | 174 | 175 | def init_step_random_answer_radius(evidence, encoder, ans_line_idx, seq_len, observ_len, token_len, max_episode_steps, 176 | ans_radius, ans_dist_prob): 177 | if np.random.rand() > ans_dist_prob: 178 | return init_step_random_balanced(evidence, encoder, seq_len, observ_len, token_len, max_episode_steps) 179 | 180 | unique_ans_line_idx = list(set(ans_line_idx)) 181 | random_ans_line = random.choice(unique_ans_line_idx) 182 | ans_node = _locate_paragraph_node_by_line_idx(evidence['tree'], random_ans_line) 183 | node = get_random_nearby_node(ans_node, max_actions=ans_radius, evidence=evidence) 184 | 185 | if node.is_root: 186 | step = 0 187 | else: 188 | min_step = node.depth + get_node_dist(node.parent)[0] 189 | max_step = max_episode_steps-1 # minus 1 because counting from 0 190 | if min_step >= max_step: 191 | step = max_step 192 | else: 193 | step = np.random.randint(min_step, max_step) 194 | 195 | observ_w, observ_c, props = get_node_observ_props(node, step, evidence, encoder, seq_len, observ_len, token_len) 196 | 197 | return node, observ_w, observ_c, props, step 198 | 199 | 200 | def get_random_nearby_node(start_node, max_actions, evidence): 201 | num_actions = np.random.randint(max_actions) 202 | 203 | node = start_node 204 | for _ in range(num_actions): 205 | legal_actions = get_legal_actions(node) 206 | legal_actions.remove(ACTIONS["STOP"]) 207 | if len(legal_actions) == 0: 208 | print("unexpected behavior - only STOP is a leagal action:\n\nstart_node: {}\t\nnode: {}".format( 209 | start_node, node)) 210 | return node 211 | action = random.choice(legal_actions) 212 | node = _make_step(evidence, node, action) 213 | 214 | return node 215 | 216 | 217 | def get_non_dummy_parent(node, evidence): 218 | if not node.dummy: 219 | return node 220 | 221 | else: 222 | if node.parent is None: 223 | print("Unexpected error, root is a dummy node: {}".format(evidence["title_tokens"])) 224 | return None 225 | 226 | elif not node.parent.dummy: 227 | return node.parent 228 | 229 | else: 230 | if node.parent.parent is None: 231 | print("Unexpected error, root is a dummy node: {}".format(evidence["title_tokens"])) 232 | return None 233 | 234 | elif not node.parent.parent.dummy: 235 | return node.parent.parent 236 | 237 | else: 238 | if node.parent.parent.parent is None: 239 | print("Unexpected error, root is a dummy node: {}".format(evidence["title_tokens"])) 240 | return None 241 | 242 | elif not node.parent.parent.parent.dummy: 243 | return node.parent.parent.parent 244 | 245 | else: 246 | print("Unexpected error, 4 dummy nodes: {}".format(evidence["title_tokens"])) 247 | return None 248 | 249 | 250 | def get_non_dummy_child(node, evidence): 251 | if not node.dummy: 252 | return node 253 | 254 | else: 255 | if node.children == (): 256 | print("Unexpected error, dummy node without children: {}".format(evidence["title_tokens"])) 257 | return None 258 | 259 | elif not node.children[0].dummy: 260 | return node.children[0] 261 | 262 | else: 263 | if node.children[0].children == (): 264 | print("Unexpected error, dummy node 
without children: {}".format(evidence["title_tokens"])) 265 | return None 266 | 267 | elif not node.children[0].children[0].dummy: 268 | return node.children[0].children[0] 269 | 270 | else: 271 | if node.children[0].children[0].children == (): 272 | print("Unexpected error, dummy node without children: {}".format(evidence["title_tokens"])) 273 | return None 274 | 275 | elif not node.children[0].children[0].children[0].dummy: 276 | return node.children[0].children[0].children[0] 277 | 278 | else: 279 | print("Unexpected error, 4 dummy nodes: {}".format(evidence["title_tokens"])) 280 | return None 281 | 282 | 283 | def make_step_upl(evidence, node): 284 | new_node = node 285 | 286 | if not node.is_root: 287 | new_node = make_step_left(evidence, node.parent) 288 | if new_node == node.parent: 289 | new_node = get_non_dummy_parent(new_node, evidence) 290 | 291 | return new_node 292 | 293 | 294 | def make_step_upr(evidence, node): 295 | new_node = node 296 | 297 | if not node.is_root: 298 | new_node = make_step_right(evidence, node.parent) 299 | if new_node == node.parent: 300 | new_node = get_non_dummy_parent(new_node, evidence) 301 | 302 | return new_node 303 | 304 | 305 | def make_step_down(evidence, node): 306 | new_node = node 307 | 308 | if node.children != (): 309 | new_node = get_non_dummy_child(node.children[0], evidence) 310 | if new_node is None: 311 | new_node = node 312 | 313 | return new_node 314 | 315 | 316 | def make_step_left(evidence, node): 317 | new_node = node 318 | 319 | if not node.is_root: 320 | children = node.parent.children 321 | node_idx = children.index(node) 322 | if node_idx > 0: 323 | new_node = get_non_dummy_child(children[node_idx-1], evidence) 324 | if new_node is None: 325 | new_node = node 326 | 327 | return new_node 328 | 329 | 330 | def make_step_right(evidence, node): 331 | new_node = node 332 | 333 | if not node.is_root: 334 | children = node.parent.children 335 | node_idx = children.index(node) 336 | if node_idx < len(children)-1: 337 | new_node = get_non_dummy_child(children[node_idx+1], evidence) 338 | if new_node is None: 339 | new_node = node 340 | 341 | return new_node 342 | 343 | 344 | def make_step(evidence, encoder, node, action, seq_len, observ_len, token_len, step): 345 | if action == ACTIONS['UPL']: 346 | new_node = make_step_upl(evidence, node) 347 | 348 | elif action == ACTIONS['UPR']: 349 | new_node = make_step_upr(evidence, node) 350 | 351 | elif action == ACTIONS['DOWN']: 352 | new_node = make_step_down(evidence, node) 353 | 354 | elif action == ACTIONS['LEFT']: 355 | new_node = make_step_left(evidence, node) 356 | 357 | elif action == ACTIONS['RIGHT']: 358 | new_node = make_step_right(evidence, node) 359 | 360 | # STOP 361 | else: 362 | observ_w, observ_c, props = get_node_observ_props(node, step, evidence, encoder, seq_len, observ_len, token_len) 363 | done = True 364 | return node, observ_w, observ_c, props, done 365 | 366 | observ_w, observ_c, props = get_node_observ_props(new_node, step, evidence, encoder, seq_len, observ_len, token_len) 367 | done = False 368 | 369 | return new_node, observ_w, observ_c, props, done 370 | 371 | 372 | def _make_step(evidence, node, action): 373 | if action == ACTIONS['UPL']: 374 | new_node = make_step_upl(evidence, node) 375 | 376 | elif action == ACTIONS['UPR']: 377 | new_node = make_step_upr(evidence, node) 378 | 379 | elif action == ACTIONS['DOWN']: 380 | new_node = make_step_down(evidence, node) 381 | 382 | elif action == ACTIONS['LEFT']: 383 | new_node = make_step_left(evidence, node) 384 | 385 
| elif action == ACTIONS['RIGHT']: 386 | new_node = make_step_right(evidence, node) 387 | 388 | # STOP 389 | else: 390 | return node 391 | 392 | return new_node 393 | 394 | 395 | def get_node_observ_props(node, step, evidence, encoder, seq_len, observ_len, token_len): 396 | observ_w, observ_c = get_node_observ(node, evidence, encoder, seq_len, observ_len, token_len) 397 | observ_w = np.reshape(observ_w, (1, observ_len)) 398 | observ_c = np.reshape(observ_c, (1, observ_len, token_len)) 399 | 400 | props = get_node_props(node, step) 401 | props = np.reshape(props, (1, len(props))) 402 | 403 | return observ_w, observ_c, props 404 | 405 | 406 | def get_node_observ(node, evidence, encoder, seq_len, observ_len, token_len): 407 | try: 408 | # add root observation 409 | if node.is_root: 410 | res_w, res_c = encoder.encode_seq(evidence['title_tokens'][:seq_len]) 411 | res_c = encoder.pad_idx_seq_2dim(res_c, token_len, PADDING_IDX) 412 | 413 | # add paragraph observation 414 | elif is_paragraph(node): 415 | res_w_anc, res_c_anc = add_ancestors_observs(node, evidence, encoder, seq_len, token_len) 416 | res_w, res_c = encoder.encode_seq(evidence['tokens'][node.line][0][:seq_len]) 417 | res_w = np.concatenate([res_w_anc, res_w]) 418 | res_c = np.concatenate([res_c_anc, encoder.pad_idx_seq_2dim(res_c, token_len, PADDING_IDX)]) 419 | 420 | # add sentence observation 421 | elif is_sentence(node): 422 | res_w_anc, res_c_anc = add_ancestors_observs(node.parent, evidence, encoder, seq_len, token_len) 423 | res_w, res_c = encoder.encode_seq(evidence['tokens'][node.line][node.sidx][:seq_len]) 424 | res_w = np.concatenate([res_w_anc, res_w]) 425 | res_c = np.concatenate([res_c_anc, encoder.pad_idx_seq_2dim(res_c, token_len, PADDING_IDX)]) 426 | 427 | # add section observations 428 | else: 429 | # TODO: handle multi-sentence sections 430 | res_w_anc, res_c_anc = add_ancestors_observs(node, evidence, encoder, seq_len, token_len) 431 | res_w, res_c = encoder.encode_seq(evidence['tokens'][node.line][0][:seq_len]) 432 | res_w = np.concatenate([res_w_anc, res_w]) 433 | res_c = np.concatenate([res_c_anc, encoder.pad_idx_seq_2dim(res_c, token_len, PADDING_IDX)]) 434 | 435 | observ_w = encoder.pad_idx_seq_1dim(res_w, observ_len, PADDING_IDX) 436 | observ_c = encoder.concate_pad_seq(res_c, observ_len, token_len, PADDING_IDX) 437 | 438 | return observ_w, observ_c 439 | 440 | except Exception as e: 441 | msg = "Error during 'get_node_observ': {}\nevidence: {}\nnode: {}\nnode_line: {}\nevidence_tokens_length: {}".format( 442 | e, evidence["tree"].name, node, node.line, len(evidence["tokens"])) 443 | print(msg) 444 | exit() 445 | 446 | 447 | def add_ancestors_observs(node, evidence, encoder, seq_len, token_len): 448 | # assumes node is a title/headline node, namely not a paragraph nor a sentence 449 | observ_w, observ_c = [], [] 450 | for anc_node in node.ancestors: 451 | if anc_node.dummy or is_paragraph(anc_node): 452 | continue 453 | if anc_node.is_root: 454 | res_w, res_c = encoder.encode_seq(evidence['title_tokens'][:seq_len]) 455 | res_c = encoder.pad_idx_seq_2dim(res_c, token_len, PADDING_IDX) 456 | else: 457 | res_w, res_c = encoder.encode_seq(evidence['tokens'][anc_node.line][0][:seq_len]) 458 | res_c = encoder.pad_idx_seq_2dim(res_c, token_len, PADDING_IDX) 459 | 460 | observ_w.extend(res_w) 461 | observ_c.extend(res_c) 462 | 463 | return observ_w, observ_c 464 | 465 | 466 | def get_node_dist(node): 467 | dist_start, dist_end = 0, 0 468 | 469 | if not node.is_root: 470 | dist_start = 
node.parent.children.index(node) 471 | dist_end = len(node.parent.children) - 1 - node.parent.children.index(node) 472 | 473 | return dist_start, dist_end 474 | 475 | 476 | def get_node_props(node, step): 477 | dist_start, dist_end = get_node_dist(node) 478 | if not node.is_root: 479 | dist_up_start, dist_up_end = get_node_dist(node.parent) 480 | else: 481 | dist_up_start, dist_up_end = 0, 0 482 | 483 | height = [0] * (MAX_HEIGHT+1) 484 | height[node.height] = 1 485 | depth = [0] * (MAX_HEIGHT+1) 486 | depth[node.depth] = 1 487 | 488 | return [step] + depth + height + [dist_start, dist_end, dist_up_start, dist_up_end] 489 | 490 | 491 | def _locate_paragraph_node_by_line_idx(root, idx): 492 | nodes = [x for x in PreOrderIter(root) if not hasattr(x, 'sidx')] 493 | node_lines = [x.line if hasattr(x, 'line') else -1 for x in nodes] 494 | 495 | node_idx = len(node_lines) - 1 - node_lines[::-1].index(idx) 496 | ans_node_idx = node_idx if is_paragraph(nodes[node_idx]) else nodes.index(nodes[node_idx].parent) 497 | 498 | return nodes[ans_node_idx] 499 | 500 | 501 | def get_node_section_idx(node): 502 | if node.is_root: 503 | return -1 504 | 505 | if node.parent.is_root: 506 | return node.parent.children.index(node) 507 | 508 | return node.anchestors[0].children.index(node.anchestors[1]) 509 | 510 | 511 | def is_illegal_move(node, action): 512 | if node.is_root: 513 | if action not in [ACTIONS['DOWN'], ACTIONS['STOP']]: 514 | return True 515 | return False 516 | 517 | if is_sentence(node) and action == ACTIONS['DOWN']: 518 | return True 519 | 520 | if is_leftmost(node) and action == ACTIONS['LEFT']: 521 | return True 522 | 523 | if is_rightmost(node) and action == ACTIONS['RIGHT']: 524 | return True 525 | 526 | return False 527 | 528 | 529 | def get_legal_actions(node): 530 | return [ACTIONS[x] for x in ACTIONS if not is_illegal_move(node, ACTIONS[x])] 531 | 532 | 533 | def navigation_dist(node1, node2): 534 | # navigation distance in steps from node1 to node2, including dummy nodes 535 | # it is valid to include them here because we look at the paths and not navigating 536 | # assuming the two node object are from the same tree instance 537 | path1 = np.asarray(get_node_idx_path(node1), dtype=np.int32) 538 | path2 = np.asarray(get_node_idx_path(node2), dtype=np.int32) 539 | return navigation_dist_idx_path(path1, path2) 540 | 541 | 542 | def navigation_dist_idx_path(path1, path2): 543 | # navigation distance in steps from node at path1 to node at path2, including dummy nodes 544 | path1 = np.asarray(path1, dtype=np.int32) 545 | path2 = np.asarray(path2, dtype=np.int32) 546 | distance = 0 547 | n1 = len(path1) 548 | n2 = len(path2) 549 | 550 | if n1 <= n2: 551 | common_diff = np.abs(path1 - path2[:n1]) 552 | else: 553 | common_diff = np.abs(path1[:n2] - path2) 554 | 555 | diff = np.argwhere(common_diff > 0) 556 | if len(diff) == 0: 557 | first_uncommon_idx = len(common_diff) 558 | else: 559 | first_diff_idx = diff[0][0] 560 | distance += common_diff[first_diff_idx] 561 | first_uncommon_idx = first_diff_idx + 1 562 | 563 | uncommon_path1 = path1[first_uncommon_idx:] 564 | uncommon_path2 = path2[first_uncommon_idx:] 565 | 566 | if len(uncommon_path1) > 0: 567 | distance += len(uncommon_path1) 568 | if len(uncommon_path2) > 0: 569 | distance += sum(uncommon_path2) + len(uncommon_path2) 570 | 571 | return distance 572 | 573 | 574 | def get_node_idx_path(node): 575 | if node.is_root: 576 | return [0] 577 | 578 | return [get_node_dist(x)[0] for x in node.anchestors] + [get_node_dist(node)[0]] 579 | 
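The action-masking helpers above can be exercised on a toy tree without any TriviaQA-NoP data. The following is an illustrative sketch, not part of the repository files; it assumes the repo's dependencies are installed and that the module is importable as `utils.tree_navigation_c` (matching the import style used throughout the code).
```python
from anytree import Node
from utils.tree_navigation_c import ACTIONS, get_legal_actions

# Toy tree: a root ("title") with two paragraph children. Attributes such
# as .dummy or .line are not needed by the action-masking helpers.
root = Node("title")
p0 = Node("p0", parent=root)
p1 = Node("p1", parent=root)

# From the root only DOWN and STOP are legal; at the leftmost child LEFT
# is masked out, and at the rightmost child RIGHT is masked out.
assert set(get_legal_actions(root)) == {ACTIONS['DOWN'], ACTIONS['STOP']}
assert ACTIONS['LEFT'] not in get_legal_actions(p0)
assert ACTIONS['RIGHT'] not in get_legal_actions(p1)
```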
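`PROPS_DIM = (MAX_HEIGHT + 1) * 2 + 5 = 23` matches the vector built by `get_node_props`: the step counter, one-hot depth and height (9 entries each), and four sibling-distance features. A minimal sketch of that layout, again using a hypothetical toy tree rather than repository data:
```python
from anytree import Node
from utils.tree_navigation_c import PROPS_DIM, get_node_props

root = Node("title")
sec0 = Node("s0", parent=root)
sec1 = Node("s1", parent=root)

props = get_node_props(sec1, step=2)
# [step] + one-hot depth + one-hot height + [dist_start, dist_end,
#  dist_up_start, dist_up_end]
assert len(props) == PROPS_DIM == 23
assert props[-4:] == [1, 0, 0, 0]   # sec1 is the second of two siblings
```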
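Because `navigation_dist_idx_path` operates directly on index paths (as produced by `get_node_idx_path`), the distance computation can be checked with plain lists. An illustrative, hypothetical example:
```python
from utils.tree_navigation_c import navigation_dist_idx_path

# Index paths: each entry is a node's position among its siblings, with a
# leading 0 contributed by the root.
path_a = [0, 1, 0]   # hypothetical node: first child of the root's 2nd section
path_b = [0, 3]      # hypothetical node: the root's 4th section

# The paths first differ at the section level (|1 - 3| = 2 sideways steps),
# and the extra level in path_a adds one step up, giving a distance of 3.
assert navigation_dist_idx_path(path_a, path_b) == 3
```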
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anytree==2.4.3 2 | numpy 3 | pandas==0.21.0 4 | pika 5 | six 6 | requests 7 | tqdm 8 | simplejson 9 | --------------------------------------------------------------------------------