├── .gitignore
├── LICENSE.txt
├── README.md
├── adam.py
├── base_gru.py
├── convert_story.py
├── display
│   ├── __init__.py
│   ├── display_graph.js
│   ├── display_graph.py
│   ├── generate_images.js
│   └── tolcolormap.py
├── do_babi_run.py
├── fix_old_file_list.py
├── ggtnn_graph_parse.py
├── ggtnn_train.py
├── graceful_interrupt.py
├── graph_state.py
├── layer.py
├── main.py
├── metadata-display.py
├── model.py
├── run_harness.py
├── strength_weighted_gru.py
├── task_generators
│   ├── automaton.py
│   ├── forth.py
│   ├── graph_tools.py
│   ├── ngram_next.py
│   └── turing.py
├── train_exit_status.py
├── transformation_modules
│   ├── README.md
│   ├── __init__.py
│   ├── aggregate_representation.py
│   ├── aggregate_representation_softmax.py
│   ├── direct_reference_update.py
│   ├── edge_state_update.py
│   ├── input_sequence_direct.py
│   ├── new_nodes_inform.py
│   ├── new_nodes_vote.py
│   ├── node_state_update.py
│   ├── output_category.py
│   ├── output_sequence.py
│   ├── output_set.py
│   ├── propagation.py
│   └── sequence_aggregate_summary.py
├── update_cache_compatibility.py
└── util.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------

MIT License

Copyright (c) 2016 Daniel D. Johnson

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Graphical State Transitions Framework

This is the code supporting my paper ["Learning Graphical State Transitions"][gpaper], published at ICLR 2017.
It consists of

- an implementation of each of the graph transformations described in the paper (in the `transformation_modules` subdirectory)
- an implementation of the Gated Graph Transformer Neural Network model (`model.py`)
- a tool to convert a folder of tasks written in a textual form with JSON graphs into a series of python pickle files with appropriate metadata (`ggtnn_graph_parse.py`)
- a harness to train multiple tasks in sequence (`run_harness.py`)
- a helper executable to train on the sequence of bAbI tasks (`do_babi_run.py`)
- a set of task generators that produce data for particular tasks, such as the Turing machine and automaton tasks discussed in the paper (in the `task_generators` subdirectory)
- a tool to enable visualization of the produced graphs, either interactively in [Jupyter][] (by importing `display_graph.py` and calling `graph_display`) or noninteractively using [phantomjs][] (by running `display_graph.js`)
- other miscellaneous utilities

Note that the bAbI task generation code was also modified in order to extend the tasks with graph information. For those changes, see [this repository][babi-mine].

To use the model, you will need [python 3.5][] or later with [Theano][] installed. If you wish to visualize the results, you will also need [matplotlib][] and [Jupyter][], and will need [phantomjs][] to generate images noninteractively. Additionally, you will need to have `floatX=float32` in your `.theanorc` file to compile the models correctly.

[gpaper]: http://openreview.net/pdf?id=HJ0NvFzxl
[babi-mine]: https://github.com/hexahedria/bAbI-tasks
[python 3.5]: https://www.python.org/
[Theano]: http://deeplearning.net/software/theano/
[matplotlib]: http://matplotlib.org/
[Jupyter]: http://jupyter.org/
[phantomjs]: http://phantomjs.org/

## Quick start

This section is intended to be a brief guide to reproducing some of the results from the paper. For more details, see the following sections, or the code itself.

The first step is to create the actual training files. This can be done by cloning the [modified bAbI tasks repo][babi-mine], and then running the script `generate_graph_tasks.sh`, which will produce a large number of graphs in the `output` directory of the repo.

The next step is to train the model using the `do_babi_run.py` helper script. For example, you might run
```
python3 do_babi_run.py path/to/bAbI-tasks/output ./model_results
```
which will train the model on all of the bAbI tasks with default parameters, and save the results into `./model_results`.

Arguments accepted by `do_babi_run.py`:

- `--including-only TASK_ID_1 TASK_ID_2 ...` will cause it to only train the model on the specific tasks given (where each TASK_ID is the numerical index of the desired task).
- `--dataset-sizes SIZE_1 SIZE_2 ...` will cause it to only train the model with the specified sizes of dataset. To train only with the full dataset, pass `--dataset-sizes 1000`. The default is equivalent to `--dataset-sizes 50 100 250 500 1000`.
- `--direct-reference` and `--no-direct-reference` will cause it to only train with or without direct reference, respectively. By default, it will train both types of model; these flags force it to train only one.

The `do_babi_run.py` script sets specific parameters based on each task.
In particular, it uses the appropriate output format for each network, and also enables or disables the intermediate propagation step depending on the complexity of the task. For each task, it then sets up multiple training runs with differently sized subsets of the input dataset, and also configures versions of the model with direct reference enabled and disabled. Internally, it then defers to the `run_harness.py` module, which runs all of those tasks in sequence.

Additional arguments passed to `do_babi_run.py` are forwarded unchanged to the `main.py` script; these arguments are described below. Note that the `do_babi_run.py` script automatically sets many of these arguments itself, so be careful to avoid conflicts.

## Non-bAbI graphs

The model also has support for tasks that are not from the bAbI dataset. However, the training process is somewhat more complex.

First, the graphs have to be obtained in the correct textual format. The parser script expects to see a file that is divided into multiple stories, where each story represents one training example and all stories are independent. A story consists of a series of sequentially numbered lines, which are either statements or queries, and ends with a final query.

A *statement* should be of the form
```
<line number> <words of statement>={"nodes":[node1name,node2name,...], "edges":[{ "type":e1type,"from":e1sourcename,"to":e1dest}, { "type":e2type,"from":e2sourcename,"to":e2dest},...]}
```
where the names, types, sources, and destinations are all strings, and the sources and destinations match nodes in the node list. Note that each node name must be distinct. During processing, each of the words, node names, and edge types seen in the dataset will be mapped to unique integer indices. If your graph should have multiple nodes of the same type, the nodes can be disambiguated using a "#" and a suffix. For instance, "cell#0" and "cell#1" will both be mapped to the "cell" node type. The graph should represent the desired state of the network after processing this statement.

A *query* should be of the form
```
<line number> <words of query>[tab]<answer word(s)>[tab]<optional extra content>
```
The query is given by the words before the first tab character, and the network answer is given by the word or words after it. Content after the second tab character is ignored, but is allowed for compatibility with the bAbI task format. (If the task has no meaningful query, then a simple empty string should be used for both the words and the answer.)

In order for direct reference to work correctly, the name of the node in the graph should be the same as the word that refers to that node in the statement or query. So a node that will be addressed by "Mary" in the statements and queries should be called "Mary" (or "Mary#suffix") in the graph node list. If you do not want to use direct reference, the names of the graph elements are arbitrary.

As an example, this file (excerpted) is taken from the bAbI graph dataset:
```
1 Mary journeyed to the bathroom.={"nodes":["Mary","bathroom"],"edges":[{ "type":"actor_is_in_location","from":"Mary","to":"bathroom"}]}
2 Mary moved to the hallway.={"nodes":["Mary","hallway"],"edges":[{ "type":"actor_is_in_location","from":"Mary","to":"hallway"}]}
3 Where is Mary?	hallway	2
4 Sandra travelled to the hallway.={"nodes":["Mary","hallway","Sandra"],"edges":[{ "type":"actor_is_in_location","from":"Mary","to":"hallway"},{ "type":"actor_is_in_location","from":"Sandra","to":"hallway"}]}
5 Daniel travelled to the office.={"nodes":["Mary","Daniel","Sandra","hallway","office"],"edges":[{ "type":"actor_is_in_location","from":"Daniel","to":"office"},{ "type":"actor_is_in_location","from":"Mary","to":"hallway"},{ "type":"actor_is_in_location","from":"Sandra","to":"hallway"}]}
6 Where is Daniel?	office	5
1 Mary travelled to the bathroom.={"nodes":["Mary","bathroom"],"edges":[{ "type":"actor_is_in_location","from":"Mary","to":"bathroom"}]}
2 Mary travelled to the garden.={"nodes":["Mary","garden"],"edges":[{ "type":"actor_is_in_location","from":"Mary","to":"garden"}]}
(etc)
```

See the `task_generators` subdirectory for a collection of scripts that generate output in this format, including the Turing machine and automaton generators.

After obtaining a correctly formatted story file, that file must be parsed and preprocessed to allow training. You can parse the file by running
```
python3 ggtnn_graph_parse.py path_to_file
```
which will create a directory and populate it with preprocessed training data.

Note that in the process, the parser scans the file and uses it to determine the mapping of words to indices that will be used by the network, as well as the maximum lengths of various components of the model, which it stores in a metadata file. Models trained with one metadata file may not be able to correctly run on examples with a different metadata file. If you would prefer that the parser not recompute the metadata and instead use an existing metadata file (for example, to ensure that the training and testing sets both use the same metadata), you can pass one with the `--metadata-file` argument. You can also view a metadata file using the `metadata-display.py` helper script.

Finally, you can actually train the model on the dataset. You will need to pass a large number of parameters to completely configure the model for your task. For example, to train the model on the automaton task, you might run

```
python3 main.py task_output/automaton_30_train category 20 --outputdir output_auto30 --validation task_output/automaton_30_valid --mutable-nodes --dynamic-nodes --propagate-intermediate --direct-reference --num-updates 100000 --batch-size 100 --batch-adjust 28000000 --resume-auto --no-query
```

In the next section, I will describe some of the parameters of `main.py` that can be set, and what their uses are.

## Overview of parameters for main.py

The `main.py` script is the actual entry point for training a model. It has a large number of parameters, which are used to configure the model and determine what to do. For a full overview of all of the options available, you can run `python3 main.py -h`. In this section, I will summarize the most important commands and their usage.

The simplest form of the invocation is
```
python3 main.py task_dir output_form state_width
```
where

- `task_dir` is the path to the directory created by the preprocessing step
- `output_form` gives the form of the output: it should be `category` if every answer is a single word; `set` if answers can have multiple words, where each word can appear at most once and order does not matter; and `sequence` if answers can have multiple words, order does matter, and there can be repeats.
- `state_width` determines the size of the state vector at each node. Since every node has a different state vector, making this parameter large can make it easier to overfit, but also allows more complex processing.

In addition to these, you can also pass parameters to configure other aspects of the model and the run process.

### Model parameters

These parameters determine how the model is actually configured and run.

```
--process-repr-size PROCESS_REPR_SIZE
                      Width of intermediate representations (default: 50)
--mutable-nodes       Make nodes mutable (default: False)
--wipe-node-state     Wipe node state before the query (default: False)
--direct-reference    Use direct reference for input, based on node names
                      (default: False)
--dynamic-nodes       Create nodes after each sentence. (Otherwise, create
                      unique nodes at the beginning) (default: False)
--propagate-intermediate
                      Run a propagation step after each sentence (default:
                      False)
--no-graph            Don't train using graph supervision
--no-query            Don't train using query supervision
```

Although not enabled by default, you will likely want to use `--mutable-nodes` and `--dynamic-nodes` for tasks with any complex processing involved; this creates the equivalent of the GGT-NN model in the paper. Otherwise, nodes will not be created at each step, and existing nodes will not update their states. You may also want to use `--direct-reference`, as it tends to increase performance. The `--propagate-intermediate` argument should be used if nodes need to exchange information in order to update their intermediate states correctly (for example, if the placement of new nodes depends on edges between other nodes). The `--no-query` argument can be passed if the task does not have a meaningful query, and will disable the query processing in the model.

### Training parameters

These parameters affect the model training process. Most should be self-explanatory.
```
--num-updates NUM_UPDATES
                      How many iterations to train (default: 10000)
--batch-size BATCH_SIZE
                      Batch size to use (default: 10)
--learning-rate LEARNING_RATE
                      Use this learning rate (default: None)
--dropout-keep DROPOUT_KEEP
                      Use dropout, with this keep chance (default: 1)
--restrict-dataset NUM_STORIES
                      Restrict size of dataset to this (default: None)
--validation VALIDATION_DIR
                      Parsed directory of validation tasks (default: None)
--validation-interval VALIDATION_INTERVAL
                      Check validation after this many iterations (default:
                      1000)
--stop-at-accuracy STOP_AT_ACCURACY
                      Stop training once it reaches this accuracy on
                      validation set (default: None)
--stop-at-loss STOP_AT_LOSS
                      Stop training once it reaches this loss on validation
                      set (default: None)
--stop-at-overfitting STOP_AT_OVERFITTING
                      Stop training once validation loss is this many times
                      higher than train loss (default: None)
--batch-adjust BATCH_ADJUST
                      If set, ensure that size of edge matrix does not
                      exceed this (default: None)
```

The `--batch-adjust` argument can be used to prevent out-of-memory errors on large datasets. It uses a heuristic based on the size of the edge matrix to adjust the batch size to the length of the input data. Good values should be determined by trial and error (for the bAbI tasks, I found a value of about 28000000 to work on my machine).

### IO Parameters

These parameters configure how the script performs I/O operations.
```
--outputdir OUTPUTDIR
                      Directory to save output in (default: output)
--save-params-interval TRAIN_SAVE_PARAMS
                      Save parameters after this many iterations (default:
                      1000)
--final-params-only   Don't save parameters while training, only at the end.
                      (default: None)
--set-exit-status     Give info about training status in the exit status
                      (default: False)
--autopickle PICKLEDIR
                      Automatically cache model in this directory (default:
                      None)
--pickle-model MODELFILE
                      Save the compiled model to a file (default: None)
--unpickle-model MODELFILE
                      Load the model from a file instead of compiling it
                      from scratch (default: None)
--interrupt-file INTERRUPT_FILE
                      Interrupt training if this file appears (default:
                      None)
--resume TIMESTEP PARAMFILE
                      Where to restore from: timestep, and file to load
                      (default: None)
--resume-auto         Automatically restore from a previous run using output
                      directory (default: False)
```

To speed up repeated uses of the model, I recommend using the `--autopickle` argument with a particular model-cache directory. The script will automatically determine a unique hash for each model version, and will then try to load a cached model based on this hash. If it fails to find one, it will compile the model as normal and then save it into the cache directory under that hash; a sketch of this scheme is shown below.

Additionally, if the training process is interrupted, the `--resume-auto` parameter will allow the training process to pick up where it left off. Otherwise, it will start over from iteration 0. You can also explicitly set a starting timestep and parameter file using `--resume TIMESTEP PARAMFILE`.

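As a rough sketch of that caching behavior (this is an illustration only, with hypothetical function and file names, not the actual `main.py` implementation):

```python
import hashlib
import os
import pickle

def load_or_compile_model(pickledir, model_description, compile_fn):
    # Derive a unique cache name from a string describing the model configuration.
    version_hash = hashlib.sha1(model_description.encode("utf-8")).hexdigest()
    cache_path = os.path.join(pickledir, "model_{}.p".format(version_hash))
    if os.path.exists(cache_path):
        # Cache hit: skip the (slow) Theano compilation entirely.
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    # Cache miss: compile as usual, then store the result for next time.
    model = compile_fn()
    os.makedirs(pickledir, exist_ok=True)
    with open(cache_path, "wb") as f:
        pickle.dump(model, f)
    return model
```
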
### Alternate execution modes

The `main.py` script can also do other things in addition to training a model.
```
--check-nan           Check for NaN. Slows execution (default: None)
--check-debug         Debug mode. Slows execution (default: None)
--just-compile        Don't run the model, just compile it (default: False)
--visualize [BUCKET,STORY]
                      Visualise current state instead of training. Optional
                      parameter selects a particular story to visualize, and
                      should be of the form bucketnum,index (default: False)
--visualize-snap      In visualization mode, snap to best option at each
                      timestep (default: False)
--visualization-test  Like visualize, but use the correct graph instead of
                      the model's graph (default: False)
--evaluate-accuracy   Evaluate accuracy of model (default: False)
```

The first two arguments are useful only if you are experiencing either NaN issues or an unexpected Theano error. The `--just-compile` argument is useful in conjunction with `--autopickle`, in that it compiles and saves a model for later training.

The `--visualize` family of commands runs the model on the input and generates visualization files, which can be converted into a diagram. If `--visualize` is used alone, the model will produce nodes whose strengths vary according to the strengths output by the model, producing "fuzzy" partial nodes. If `--visualize-snap` is also passed, the most likely option at each timestep will be selected instead, and the model will be forced to choose its actions with full strength.

The `--visualization-test` option is of limited use, and simply produces the visualization files corresponding to the correct graph structure from the dataset, but with the states from the model. (If you simply wish to visualize the correct graph structure, it is easier to use the `convert_story.py` script, which takes a story file and produces the graph visualization files.)

The `--evaluate-accuracy` argument evaluates the accuracy of the model over the dataset. In this mode, as in `--visualize-snap`, the most likely option at each timestep will be selected, and the model will be forced to choose its actions with full strength. If the output exactly matches the correct result in the dataset, that sample is marked as a success; otherwise, it is a failure. The script then prints out the fraction of samples that were successes. (When using this, pass the test dataset as the `task_dir` parameter.)

## Visualizing the results

After generating the visualization files, there are two ways to visualize them.

### Interactive visualization

To interactively visualize the output, you need to use Jupyter. In Jupyter, run
```
import numpy as np
from display.display_graph import graph_display, setup_graph_display
setup_graph_display()
def do_vis(direc, correct_graph=False, options={}):
    global results
    results = [np.load("{}/result_{}.npy".format(direc,i)) for i in (range(4) if correct_graph else range(1,5))]
    return graph_display(results,options)
```
Then, to visualize a particular output, you can run `do_vis("path/to/vis/files")`. The `correct_graph` flag should be true if you are visualizing something that came from `convert_story.py`, and false if you are visualizing a network output.

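For example, assuming the visualization files were saved to a (hypothetical) directory `output_auto30/visualization`, you might call:

```python
do_vis("output_auto30/visualization", options={"width": 800, "height": 600})
```
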
`options` should be a dictionary that contains various options for the visualization:

- `width` sets the width of the visualization
- `height` sets the height of the visualization
- `jitter` determines whether any jitter is applied to the nodes
- `timestep` determines what timestep to start on (this is configurable interactively)
- `linkDistance` determines how long edges are (this is configurable interactively)
- `linkStrength` determines how much edges resist changes in length (this is configurable interactively)
- `gravity` determines how strongly nodes are pulled toward the center of the display (this is configurable interactively)
- `charge` determines how much nodes repel each other up close (this is configurable interactively)
- `extra_snap_specs` is a list of items of the form `{"id":0,"axis":"y","value":40,"strength":0.6}`, where `id` specifies an individual node type, `axis` is "x" or "y", `value` gives a position, and `strength` determines how strongly the nodes of that type are attracted to that position. This can be used to impose constraints on the visualization based on the task (so that automaton cells line up in the middle in between values, for example).
- `edge_strength_adjust` is a list of numbers, one for each edge type, which specify how strongly each type of edge resists being stretched.
- `noninteractive` will make the graph non-interactive, so that it doesn't animate or respond to user input.
- `fullAlphaTicks` determines how long to simulate forces in the graph before halting it, if running noninteractively.

### Noninteractive visualization

If you install phantomjs, you can also directly export visualizations as images. First, make a file called `options.py` in the same directory as the visualization files, of the form
```
options = { ... }
```
where the dictionary has the options specified above (a sketch of such a file is given at the end of this README). Then run
```
phantomjs display/generate_images.js path/to/vis/directory 1
```
where the trailing number can be increased to scale the size of the image produced.

## Training the extended sequential model (from the appendix)

The extended sequential model can be enabled using a few extra parameters. To train using `main.py` directly, use this argument:

```
--sequence-aggregate-repr
                      Compute the query aggregate representation from the
                      sequence of graphs instead of just the last one
                      (default: False)
```

If you want to run the two bAbI tasks used in the paper with the extended model, you can pass `--run-sequential-set` to `do_babi_run.py`, which will have the same effect.

Note that when generating the dataset for this model, the additional history nodes should not be added. Thus, in the files `WhereWasObject.lua` and `WhoWhatGave.lua`, the line including "augment_with_value_histories" should be commented out before generating the dataset for those tasks.

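As mentioned in the visualization section above, noninteractive rendering reads its settings from an `options.py` file. A minimal example might look like the following (the particular values here are arbitrary illustrations, not recommended settings):

```python
# Example options.py, placed in the same directory as the result_*.npy files.
# All keys are optional; see the option list in the README for their meanings.
options = {
    "width": 500,
    "height": 500,
    "timestep": 0,
    "linkDistance": 80,
    "linkStrength": 0.8,
    "gravity": 0.05,
    "charge": 100,
    "fullAlphaTicks": 300,
}
```
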
-------------------------------------------------------------------------------- /adam.py: -------------------------------------------------------------------------------- 1 | """ 2 | The MIT License (MIT) 3 | Copyright (c) 2015 Alec Radford 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 13 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 14 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 15 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 16 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 17 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 18 | SOFTWARE. 19 | """ 20 | import theano 21 | import theano.tensor as T 22 | import numpy as np 23 | 24 | def Adam(cost, params, lr=0.0002, b1=0.1, b2=0.001, e=1e-8): 25 | updates = [] 26 | grads = T.grad(cost, params) 27 | i = theano.shared(np.array(0., theano.config.floatX)) 28 | i_t = i + 1. 29 | fix1 = 1. - (1. - b1)**i_t 30 | fix2 = 1. - (1. - b2)**i_t 31 | lr_t = lr * (T.sqrt(fix2) / fix1) 32 | for p, g in zip(params, grads): 33 | m = theano.shared(p.get_value() * 0.) 34 | v = theano.shared(p.get_value() * 0.) 35 | m_t = (b1 * g) + ((1. - b1) * m) 36 | v_t = (b2 * T.sqr(g)) + ((1. 
- b2) * v) 37 | g_t = m_t / (T.sqrt(v_t) + e) 38 | p_t = p - (lr_t * g_t) 39 | updates.append((m, m_t)) 40 | updates.append((v, v_t)) 41 | updates.append((p, p_t)) 42 | updates.append((i, i_t)) 43 | return updates -------------------------------------------------------------------------------- /base_gru.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | 7 | class BaseGRULayer( object ): 8 | """ 9 | Implements a GRU layer 10 | """ 11 | 12 | def __init__(self, input_width, output_width, activation_shift=0.0, name=None, dropout_keep=1, dropout_input=False, dropout_output=True): 13 | """ 14 | Params: 15 | input_width: Width of input 16 | output_width: Width of the GRU output 17 | activation_shift: How to shift the biases of the activation 18 | """ 19 | self._input_width = input_width 20 | self._output_width = output_width 21 | 22 | prefix = "" if name is None else name + "_" 23 | 24 | self._reset_W = theano.shared(init_params([input_width + output_width, output_width]), prefix+"reset_W") 25 | self._reset_b = theano.shared(init_params([output_width], shift=1.0), prefix+"reset_b") 26 | 27 | self._update_W = theano.shared(init_params([input_width + output_width, output_width]), prefix+"update_W") 28 | self._update_b = theano.shared(init_params([output_width], shift=1.0), prefix+"update_b") 29 | 30 | self._activation_W = theano.shared(init_params([input_width + output_width, output_width]), prefix+"activation_W") 31 | self._activation_b = theano.shared(init_params([output_width], shift=activation_shift), prefix+"activation_b") 32 | 33 | self._dropout_keep = dropout_keep 34 | self._dropout_input = dropout_input 35 | self._dropout_output = dropout_output 36 | 37 | @property 38 | def input_width(self): 39 | return self._input_width 40 | 41 | @property 42 | def output_width(self): 43 | return self._output_width 44 | 45 | @property 46 | def params(self): 47 | return [self._reset_W, self._reset_b, self._update_W, self._update_b, self._activation_W, self._activation_b] 48 | 49 | def initial_state(self, batch_size): 50 | """ 51 | The initial state of the network 52 | Params: 53 | batch_size: The batch size to construct the initial state for 54 | """ 55 | return T.zeros([batch_size, self.output_width]) 56 | 57 | def dropout_masks(self, srng, use_output=None): 58 | if self._dropout_keep == 1: 59 | return [] 60 | else: 61 | masks = [] 62 | if self._dropout_input: 63 | masks.append(make_dropout_mask((self._input_width,), self._dropout_keep, srng)) 64 | if self._dropout_output: 65 | if use_output is not None: 66 | masks.append(use_output) 67 | else: 68 | masks.append(make_dropout_mask((self._output_width,), self._dropout_keep, srng)) 69 | return masks 70 | 71 | def split_dropout_masks(self, dropout_masks): 72 | if dropout_masks is None: 73 | return [], None 74 | idx = (self._dropout_keep != 1) * (self._dropout_input + self._dropout_output) 75 | return dropout_masks[:idx], dropout_masks[idx:] 76 | 77 | def step(self, ipt, state, dropout_masks=Ellipsis): 78 | """ 79 | Perform a single step of the network 80 | 81 | Params: 82 | ipt: The current input. Should be an int tensor of shape (n_batch, self.input_width) 83 | state: The previous state. 
Should be a float tensor of shape (n_batch, self.output_width) 84 | dropout_masks: Masks from get_dropout_masks 85 | 86 | Returns: The next output state 87 | """ 88 | if dropout_masks is Ellipsis: 89 | dropout_masks = None 90 | append_masks = False 91 | else: 92 | append_masks = True 93 | 94 | if self._dropout_keep != 1 and self._dropout_input and dropout_masks is not None: 95 | ipt_masks = dropout_masks[0] 96 | ipt = apply_dropout(ipt, ipt_masks) 97 | dropout_masks = dropout_masks[1:] 98 | 99 | cat_ipt_state = T.concatenate([ipt, state], 1) 100 | reset = do_layer( T.nnet.sigmoid, cat_ipt_state, 101 | self._reset_W, self._reset_b ) 102 | update = do_layer( T.nnet.sigmoid, cat_ipt_state, 103 | self._update_W, self._update_b ) 104 | candidate_act = do_layer( T.tanh, T.concatenate([ipt, (reset * state)], 1), 105 | self._activation_W, self._activation_b ) 106 | 107 | newstate = update * state + (1-update) * candidate_act 108 | 109 | if self._dropout_keep != 1 and self._dropout_output and dropout_masks is not None: 110 | newstate_masks = dropout_masks[0] 111 | newstate = apply_dropout(newstate, newstate_masks) 112 | dropout_masks = dropout_masks[1:] 113 | 114 | if append_masks: 115 | return newstate, dropout_masks 116 | else: 117 | return newstate 118 | -------------------------------------------------------------------------------- /convert_story.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import ggtnn_graph_parse 5 | from ggtnn_graph_parse import PreppedStory 6 | import gzip 7 | import pickle 8 | 9 | def convert(story): 10 | # import pdb; pdb.set_trace() 11 | sentence_arr, graphs, query_arr, answer_arr = story 12 | node_id_w = graphs[2].shape[2] 13 | edge_type_w = graphs[3].shape[3] 14 | 15 | all_node_strengths = [np.zeros([1])] 16 | all_node_ids = [np.zeros([1,node_id_w])] 17 | for num_new_nodes, new_node_strengths, new_node_ids, _ in zip(*graphs): 18 | last_strengths = all_node_strengths[-1] 19 | last_ids = all_node_ids[-1] 20 | 21 | cur_strengths = np.concatenate([last_strengths, new_node_strengths], 0) 22 | cur_ids = np.concatenate([last_ids, new_node_ids], 0) 23 | 24 | all_node_strengths.append(cur_strengths) 25 | all_node_ids.append(cur_ids) 26 | 27 | all_edges = graphs[3] 28 | full_n_nodes = all_edges.shape[1] 29 | all_node_strengths = np.stack([np.pad(x, ((0, full_n_nodes-x.shape[0])), 'constant') for x in all_node_strengths[1:]]) 30 | all_node_ids = np.stack([np.pad(x, ((0, full_n_nodes-x.shape[0]), (0, 0)), 'constant') for x in all_node_ids[1:]]) 31 | all_node_states = np.zeros([len(all_node_strengths), full_n_nodes,0]) 32 | 33 | return tuple(x[np.newaxis,...] 
for x in (all_node_strengths, all_node_ids, all_node_states, all_edges)) 34 | 35 | def main(storyfile, outputdir): 36 | 37 | with gzip.open(storyfile,'rb') as f: 38 | story, sents, query, ans = pickle.load(f) 39 | 40 | with open(os.path.join(outputdir,'story.txt'),'w') as f: 41 | f.write("{}\n{}\n{}".format("\n".join(" ".join(s) for s in sents), " ".join(query), " ".join(ans))) 42 | 43 | results = convert(story) 44 | if not os.path.exists(outputdir): 45 | os.makedirs(outputdir) 46 | for i,res in enumerate(results): 47 | np.save(os.path.join(outputdir,'result_{}.npy'.format(i)), res) 48 | 49 | parser = argparse.ArgumentParser(description='Convert a story to graph') 50 | parser.add_argument("storyfile", help="Story filename") 51 | parser.add_argument("outputdir", help="Output directory") 52 | 53 | if __name__ == '__main__': 54 | args = vars(parser.parse_args()) 55 | main(**args) 56 | 57 | 58 | -------------------------------------------------------------------------------- /display/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danieldjohnson/gated-graph-transformer-network/1fd7df8b5629152a7afa9a2a8a67346d52bf0d94/display/__init__.py -------------------------------------------------------------------------------- /display/display_graph.js: -------------------------------------------------------------------------------- 1 | require.config({ 2 | paths: { 3 | d3: 'http://cdnjs.cloudflare.com/ajax/libs/d3/4.1.0/d3.min', 4 | dat: 'http://cdnjs.cloudflare.com/ajax/libs/dat-gui/0.5.1/dat.gui.min' 5 | } 6 | }); 7 | 8 | require(['d3','dat'], function(d3,_ignored){ 9 | function _graph_display(states,colormap,el,batch,options){ 10 | var node_strengths = states[0]; 11 | var node_ids = states[1]; 12 | var node_states = states[2]; 13 | var edge_strengths = states[3]; 14 | var max_time = node_strengths[batch].length; 15 | 16 | var width = options.width || 500; 17 | var height = options.height || 500; 18 | 19 | var svg = d3.select(el).append("svg").attr("width",width).attr("height",height); 20 | 21 | if(!options) 22 | options = {} 23 | 24 | var node_map = {}; 25 | var data_nodes = []; 26 | var data_edges = []; 27 | var display_edges = []; 28 | var selection_map, selection_options; 29 | var force = d3.forceSimulation() 30 | .force("charge", d3.forceManyBody()) 31 | .force("link", d3.forceLink()) 32 | .force("gravityX", d3.forceX(width/2)) 33 | .force("gravityY", d3.forceY(height/2)); 34 | 35 | var extra_snap_specs = options.extra_snap_specs || []; 36 | var extra_forces = []; 37 | for(var i=0; i0.1){ 183 | tmp_edges.push(c_edge); 184 | } 185 | if(eff_str>0.03){ 186 | tmp_display_edges.push(c_edge); 187 | } 188 | } 189 | } 190 | for (var i=0; i20) div_w = 20; 270 | focus_detail = focus_detail.data(datalist) 271 | focus_detail.exit().remove(); 272 | focus_detail = focus_detail.enter().append("rect") 273 | .merge(focus_detail) 274 | .attr('fill',function(d){return d3.interpolateViridis(d).toString()}) 275 | .attr('width',div_w) 276 | .attr('height',20) 277 | .attr('x',function(d,i){return div_w*i}) 278 | .attr('y',height-20); 279 | } 280 | function do_focus(d){ 281 | if(options.noninteractive) 282 | return; 283 | console.log("Focusing on ", d) 284 | update_focus(d.data); 285 | } 286 | 287 | var gui = new dat.GUI({ autoPlace: false }); 288 | 289 | if(!options.noninteractive) 290 | el.insertBefore(gui.domElement, el.firstChild); 291 | 292 | gui.add(params,"linkDistance").min(0).max(200).onChange(function(value) { 293 | 
force.force("link").distance(value); 294 | force.alpha(1).restart(); 295 | }); 296 | gui.add(params,"linkStrength").min(0).max(1).onChange(function(value) { 297 | force.force("link").strength(function(link,i){ 298 | return value*link.link_force_strength; 299 | }) 300 | force.alpha(1).restart(); 301 | }); 302 | 303 | gui.add(params,"gravity").min(0).max(0.2).onChange(function(value) { 304 | force.force("gravityX").strength(value); 305 | force.force("gravityY").strength(value); 306 | force.alpha(1).restart(); 307 | }); 308 | gui.add(params,"charge").min(0).max(200).onChange(function(value) { 309 | force.alpha(1).restart(); 310 | }); 311 | 312 | if(options.jitter){ 313 | gui.add(params,"jitterScale").min(0).max(200); 314 | } 315 | 316 | var last_timestep = params.timestep; 317 | gui.add(params,"timestep").min(0).max(max_time-1).step(1).onChange(function(value) { 318 | if(value != last_timestep){ 319 | update_state(value); 320 | last_timestep = value; 321 | } 322 | }); 323 | 324 | function noninteractive_update(){ 325 | force.stop(); 326 | var startTicks = options.fullAlphaTicks || 0; 327 | for(var i=0; i force.alphaMin()){ 332 | force.tick(); 333 | } 334 | redraw(); 335 | } 336 | if(options.noninteractive){ 337 | noninteractive_update(); 338 | return function(){ 339 | params.timestep++; 340 | if(params.timestep < max_time){ 341 | update_state(params.timestep); 342 | noninteractive_update(); 343 | return true; 344 | } else 345 | return false; 346 | } 347 | } 348 | } 349 | window._graph_display = _graph_display; 350 | console.log("Loaded display_graph"); 351 | if(element) 352 | element.text("Done!"); 353 | }); -------------------------------------------------------------------------------- /display/display_graph.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from IPython.display import Javascript 4 | import json 5 | import itertools 6 | import pickle 7 | 8 | from IPython.core.display import HTML 9 | 10 | from .tolcolormap import cm_rainbow 11 | 12 | def prep_graph_display(states, options={}): 13 | clean_states = [x.tolist() for x in states] 14 | nstr, nid, nstate, estr = states 15 | flat_nid = nid.reshape([-1,nid.shape[-1]]) 16 | flat_estr = estr.reshape([-1,estr.shape[-1]]) 17 | flat_estr = flat_estr / (np.linalg.norm(flat_estr, axis=1, keepdims=True) + 1e-8) 18 | 19 | num_unique_colors = nid.shape[-1] + estr.shape[-1] 20 | id_denom = max(nid.shape[-1] - 1, 1) 21 | id_project_mat = np.array([list(cm_rainbow(i/id_denom)[:3]) for i in range(0,nid.shape[-1])]) 22 | estr_denom = estr.shape[-1] 23 | estr_project_mat = np.array([list(cm_rainbow((i+0.37)/estr_denom)[:3]) for i in range(estr.shape[-1])]) 24 | node_colors = np.dot(flat_nid, id_project_mat) 25 | edge_colors = np.dot(flat_estr, estr_project_mat) 26 | 27 | colormap = { 28 | "node_id": node_colors.reshape(nid.shape[:-1] + (3,)).tolist(), 29 | "edge_type": edge_colors.reshape(estr.shape[:-1] + (3,)).tolist(), 30 | } 31 | 32 | return json.dumps({ 33 | "states":clean_states, 34 | "colormap":colormap, 35 | "options":options 36 | }) 37 | 38 | def graph_display(states, options={}): 39 | stuff = prep_graph_display(states,options) 40 | return Javascript("var tmp={}; window.nonint_next = window._graph_display(tmp.states, tmp.colormap, element[0], 0, tmp.options);".format(stuff)) 41 | 42 | def noninteractive_next(): 43 | return Javascript("window.nonint_next()") 44 | 45 | def setup_graph_display(): 46 | with open(os.path.join(os.path.dirname(__file__), 
"display_graph.js"), 'r') as f: 47 | JS_SETUP_STRING = f.read() 48 | return Javascript(JS_SETUP_STRING) 49 | 50 | def main(visdir): 51 | results = [] 52 | has_answer = os.path.isfile("{}/result_{}.npy".format(visdir,4)) 53 | the_range = range(1,5) if has_answer else range(4) 54 | results = [np.load("{}/result_{}.npy".format(visdir,i)) for i in the_range] 55 | import importlib.machinery 56 | try: 57 | options_mod = importlib.machinery.SourceFileLoader('options',os.path.join(visdir,"options.py")).load_module() 58 | options = options_mod.options 59 | except FileNotFoundError: 60 | options = {} 61 | print(prep_graph_display(results,options)) 62 | 63 | import argparse 64 | parser = argparse.ArgumentParser(description='Convert a visualization to JSON format') 65 | parser.add_argument("visdir", help="Directory to visualization files") 66 | 67 | if __name__ == '__main__': 68 | args = vars(parser.parse_args()) 69 | main(**args) 70 | -------------------------------------------------------------------------------- /display/generate_images.js: -------------------------------------------------------------------------------- 1 | #! phantomjs 2 | function assert(condition, message) { 3 | if (!condition) { 4 | console.error(message); 5 | message = message || "Assertion failed"; 6 | if (typeof Error !== "undefined") { 7 | throw new Error(message); 8 | } 9 | throw message; // Fallback 10 | } 11 | } 12 | 13 | var system = require('system'); 14 | var visdir = system.args[1]; 15 | var scale = system.args[2]; 16 | if(scale === undefined) 17 | scale = 1; 18 | else 19 | scale = parseFloat(scale); 20 | var process = require("child_process"); 21 | var webPage = require('webpage'); 22 | var fs = require('fs'); 23 | 24 | console.log("Running python display script..."); 25 | process.execFile("python3", ["-m", "display.display_graph", visdir], null, function(err, stdout, stderr){ 26 | console.log("Parsing..."); 27 | var params_obj = JSON.parse(stdout); 28 | if(visdir.charAt(visdir.length-1) == fs.separator) 29 | visdir = visdir.substr(0, visdir.length-1); 30 | var imgdir = visdir + fs.separator + "generated_images" 31 | console.log("Creating images directory "+imgdir); 32 | assert(fs.isDirectory(imgdir) || fs.makeDirectory(imgdir), "Failed to make directory!"); 33 | 34 | params_obj.options.noninteractive = true; 35 | params_obj.options.timestep = 0; 36 | 37 | console.log("Starting image generation..."); 38 | var page = webPage.create(); 39 | page.viewportSize = { width: (params_obj.options.width || 500), height: (params_obj.options.height || 500) }; 40 | page.zoomFactor = scale; 41 | page.onConsoleMessage = function(msg) { 42 | // console.log("Page says: ", msg); 43 | if(msg == "Loaded display_graph"){ 44 | page.evaluate(function(params_obj){ 45 | window.next_fn = window._graph_display(params_obj.states, params_obj.colormap, document.body, 0, params_obj.options); 46 | }, params_obj); 47 | page.render(imgdir + fs.separator + '0.png'); 48 | for(var i=0; true; i++){ 49 | console.log("Writing image ", i); 50 | page.render(imgdir + fs.separator + i + '.png'); 51 | var has_more = page.evaluate(function(){ 52 | return window.next_fn(); 53 | }); 54 | if(!has_more) 55 | break; 56 | } 57 | phantom.exit(); 58 | } 59 | } 60 | page.includeJs("https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js", function(){ 61 | page.injectJs("display/display_graph.js"); 62 | }) 63 | }) -------------------------------------------------------------------------------- /display/tolcolormap.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2.7 2 | # encoding: utf-8 3 | # 4 | # matplotlib-ref-density.py -- matplotlib example script 5 | # Copyright (C) 2011 Tim van Werkhoven (t.i.m.vanwerkhoven@xs4all.nl) 6 | # 7 | # This work is licensed under the Creative Commons Attribution-Share Alike 8 | # 3.0 Unported License. To view a copy of this license, visit 9 | # http://creativecommons.org/licenses/by-sa/3.0/ or send a letter to Creative 10 | # Commons, 171 Second Street, Suite 300, San Francisco, California,94105, USA. 11 | 12 | # import pylab as plt 13 | import numpy as N 14 | import matplotlib 15 | 16 | # Make colormap based on Paul Tol's best visibility gradients. See 17 | # for more info on these colors. Also see 18 | # 19 | # and on some 20 | # matplotlib examples 21 | 22 | # Deviation around zero colormap (blue--red) 23 | cols = [] 24 | for x in N.linspace(0,1, 256): 25 | rcol = 0.237 - 2.13*x + 26.92*x**2 - 65.5*x**3 + 63.5*x**4 - 22.36*x**5 26 | gcol = ((0.572 + 1.524*x - 1.811*x**2)/(1 - 0.291*x + 0.1574*x**2))**2 27 | bcol = 1/(1.579 - 4.03*x + 12.92*x**2 - 31.4*x**3 + 48.6*x**4 - 23.36*x**5) 28 | cols.append((rcol, gcol, bcol)) 29 | 30 | cm_plusmin = matplotlib.colors.LinearSegmentedColormap.from_list("PaulT_plusmin", cols) 31 | 32 | # Linear colormap (white--red) 33 | from scipy.special import erf 34 | 35 | cols = [] 36 | for x in N.linspace(0,1, 256): 37 | rcol = (1 - 0.392*(1 + erf((x - 0.869)/ 0.255))) 38 | gcol = (1.021 - 0.456*(1 + erf((x - 0.527)/ 0.376))) 39 | bcol = (1 - 0.493*(1 + erf((x - 0.272)/ 0.309))) 40 | cols.append((rcol, gcol, bcol)) 41 | 42 | cm_linear = matplotlib.colors.LinearSegmentedColormap.from_list("PaulT_linear", cols) 43 | 44 | # Linear colormap (rainbow) 45 | cols = [] 46 | # cols = [(0,0,0)] 47 | for x in N.linspace(0,1, 256): 48 | rcol = (0.472-0.567*x+4.05*x**2)/(1.+8.72*x-19.17*x**2+14.1*x**3) 49 | gcol = 0.108932-1.22635*x+27.284*x**2-98.577*x**3+163.3*x**4-131.395*x**5+40.634*x**6 50 | bcol = 1./(1.97+3.54*x-68.5*x**2+243*x**3-297*x**4+125*x**5) 51 | cols.append((rcol, gcol, bcol)) 52 | 53 | # cols.append((1,1,1)) 54 | cm_rainbow = matplotlib.colors.LinearSegmentedColormap.from_list("PaulT_rainbow", cols) 55 | 56 | if __name__ == '__main__': 57 | # Plot examples 58 | import matplotlib.pyplot as plt 59 | plt.ion() 60 | tmpim = N.arange(256).reshape(1,-1) 61 | plt.close() 62 | plt.title("www.sron.nl/~pault variation around zero colormap") 63 | plt.imshow(tmpim, cmap=plt.get_cmap(cm_plusmin), aspect='auto') 64 | plt.savefig("matplotlib-ref-plusmin.pdf") 65 | 66 | plt.close() 67 | plt.title("www.sron.nl/~pault linear colormap") 68 | plt.imshow(tmpim, cmap=plt.get_cmap(cm_linear), aspect='auto') 69 | plt.savefig("matplotlib-ref-linear.pdf") 70 | 71 | plt.close() 72 | plt.title("www.sron.nl/~pault rainbow colormap") 73 | plt.imshow(tmpim, cmap=plt.get_cmap(cm_rainbow), aspect='auto') 74 | plt.savefig("matplotlib-ref-rainbow.pdf") 75 | 76 | 77 | # EOF 78 | -------------------------------------------------------------------------------- /do_babi_run.py: -------------------------------------------------------------------------------- 1 | import run_harness 2 | import argparse 3 | import os 4 | import shlex 5 | 6 | def main(tasks_dir, output_dir, excluding=[], including_only=None, run_sequential_set=False, just_setup=False, stop_on_error=False, extra_args=[], dataset_sizes=None, direct_ref_enabled=None): 7 | base_params = " ".join([ 8 | "20", 9 | "--mutable-nodes", 10 | "--dynamic-nodes", 
11 | "--num-updates 3000", 12 | "--batch-size 100", 13 | "--final-params-only", 14 | "--learning-rate 0.002", 15 | "--save-params-interval 100", 16 | "--validation-interval 100", 17 | "--batch-adjust 16000000"] 18 | + [shlex.quote(s) for s in extra_args]) 19 | 20 | intermediate_propagate_tasks = {3,5} 21 | alt_sequential_set = {3,5} 22 | 23 | output_types = [ 24 | "category", # [1]='WhereIsActor', 25 | "category", # [2]='WhereIsObject', 26 | "category", # [3]='WhereWasObject', 27 | "category", # [4]='IsDir', 28 | "category", # [5]='WhoWhatGave', 29 | "category", # [6]='IsActorThere', 30 | "category", # [7]='Counting', 31 | "subset", # [8]='Listing', 32 | "category", # [9]='Negation', 33 | "category", # [10]='Indefinite', 34 | "category", # [11]='BasicCoreference', 35 | "category", # [12]='Conjunction', 36 | "category", # [13]='CompoundCoreference', 37 | "category", # [14]='Time', 38 | "category", # [15]='Deduction', 39 | "category", # [16]='Induction', 40 | "category", # [17]='PositionalReasoning', 41 | "category", # [18]='Size', 42 | "sequence", # [19]='PathFinding', 43 | "category", # [20]='Motivations' 44 | ] 45 | 46 | restrict_sizes=[50,100,250,500,1000] if dataset_sizes is None else dataset_sizes 47 | 48 | tasks_and_outputs = list(zip(range(1,21),output_types)) 49 | if run_sequential_set: 50 | base_params = base_params + " --sequence-aggregate-repr" 51 | tasks_and_outputs = [tasks_and_outputs[x-1] for x in alt_sequential_set] 52 | intermediate_propagate_tasks = set() 53 | 54 | if just_setup: 55 | base_params = base_params + " --just-compile" 56 | restrict_sizes = [1000] 57 | 58 | direct_ref_options = (True,False) if direct_ref_enabled is None else (direct_ref_enabled,) 59 | 60 | specs = [run_harness.TaskSpec( "task_{}".format(task_i), 61 | str(rsize) + ("-direct" if direct_ref else ""), 62 | "{} --restrict-dataset {} --stop-at-accuracy {} {} {} {}".format( 63 | output_type, 64 | rsize, 65 | "1.0" if rsize==1000 else "0.95", 66 | "--propagate-intermediate" if task_i in intermediate_propagate_tasks else "", 67 | "" if rsize==1000 else "--stop-at-overfitting 5", 68 | ("--direct-reference" if direct_ref else ""))) 69 | for rsize in reversed(restrict_sizes) 70 | for direct_ref in direct_ref_options 71 | for task_i, output_type in tasks_and_outputs] 72 | 73 | specs = [x for x in specs if x.task_name[5:] not in excluding] 74 | if including_only is not None: 75 | specs = [x for x in specs if x.task_name[5:] in including_only] 76 | # from pprint import pprint; pprint(specs); return 77 | run_harness.run(tasks_dir, output_dir, base_params, specs, stop_on_error=stop_on_error, skip_complete=just_setup) 78 | 79 | parser = argparse.ArgumentParser(description="Train all bAbI tasks.") 80 | parser.add_argument('tasks_dir', help="Directory with tasks") 81 | parser.add_argument('output_dir', help="Directory to save output to") 82 | parser.add_argument('--excluding', nargs='+', default=[], help="Tasks to exclude") 83 | parser.add_argument('--including-only', nargs='+', default=None, help="Tasks to include, if given, else all tasks") 84 | parser.add_argument('--run-sequential-set', action="store_true", help="Run tasks with sequential output instead, and only run tasks that need it") 85 | parser.add_argument('--just-setup', action="store_true", help="Just setup the tasks, don't actually run them") 86 | parser.add_argument('--stop-on-error', action="store_true", help="Stop if execution hits an error") 87 | parser.add_argument('--dataset-sizes', nargs='+', default=None, type=int, help="Run the model on these 
sizes of input") 88 | parser.add_argument('--direct-reference', action="store_true", dest="direct_ref_enabled", default=None, help="Only train with direct reference") 89 | parser.add_argument('--no-direct-reference', action="store_false", dest="direct_ref_enabled", default=None, help="Only train without direct reference") 90 | 91 | if __name__ == '__main__': 92 | namespace, extra = parser.parse_known_args() 93 | args = vars(namespace) 94 | main(extra_args=extra, **args) 95 | 96 | -------------------------------------------------------------------------------- /fix_old_file_list.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import argparse 4 | import sys 5 | 6 | def main(task_dir, dry_run=False): 7 | with open(os.path.join(task_dir,'file_list.p'),'rb') as f: 8 | bucketed = pickle.load(f) 9 | if dry_run: 10 | print("Got {} (for example)".format(bucketed[0][0])) 11 | bucketed = [['./bucket_' + x.split('bucket_')[1] for x in b] for b in bucketed] 12 | if dry_run: 13 | print("Converting to {} (for example)".format(bucketed[0][0])) 14 | print("Will resolve to {} (for example)".format(os.path.normpath(os.path.join(task_dir,bucketed[0][0])))) 15 | else: 16 | with open(os.path.join(task_dir,'file_list.p'),'wb') as f: 17 | pickle.dump(bucketed, f) 18 | 19 | parser = argparse.ArgumentParser(description='Fix the file list of a parsed directory.') 20 | parser.add_argument('task_dir', help="Directory of parsed files") 21 | parser.add_argument('--dry-run', action="store_true", help="Don't overwrite files") 22 | 23 | if __name__ == '__main__': 24 | args = vars(parser.parse_args()) 25 | main(**args) 26 | -------------------------------------------------------------------------------- /ggtnn_graph_parse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import collections 5 | import numpy as np 6 | import scipy 7 | import json 8 | import itertools 9 | import pickle 10 | import gc 11 | import gzip 12 | import argparse 13 | 14 | def tokenize(sent): 15 | '''Return the tokens of a sentence including punctuation. 16 | >>> tokenize('Bob dropped the apple. Where is the apple?') 17 | ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?'] 18 | ''' 19 | return re.findall('(?:\w+)|\S',sent) 20 | 21 | def list_to_map(l): 22 | '''Convert a list of values to a map from values to indices''' 23 | return {val:i for i,val in enumerate(l)} 24 | 25 | def parse_stories(lines): 26 | ''' 27 | Parse stories provided in the bAbi tasks format, with knowledge graph. 
28 | ''' 29 | data = [] 30 | story = [] 31 | for line in lines: 32 | if line[-1] == "\n": 33 | line = line[:-1] 34 | nid, line = line.split(' ', 1) 35 | nid = int(nid) 36 | if nid == 1: 37 | story = [] 38 | questions = [] 39 | if '\t' in line: 40 | q, apre = line.split('\t')[:2] 41 | a = apre.split(',') 42 | q = tokenize(q) 43 | substory = [x for x in story if x] 44 | data.append((substory, q, a)) 45 | story.append('') 46 | else: 47 | line, graph = line.split('=', 1) 48 | sent = tokenize(line) 49 | graph_parsed = json.loads(graph) 50 | story.append((sent, graph_parsed)) 51 | return data 52 | 53 | def get_stories(taskname): 54 | with open(taskname, 'r') as f: 55 | lines = f.readlines() 56 | return parse_stories(lines) 57 | 58 | def get_max_sentence_length(stories): 59 | return max((max((len(sentence) for (sentence, graph) in sents_graphs)) for (sents_graphs, query, answer) in stories)) 60 | 61 | def get_max_query_length(stories): 62 | return max((len(query) for (sents_graphs, query, answer) in stories)) 63 | 64 | def get_max_num_queries(stories): 65 | return max((len(queries) for (sents_graphs, query, answer) in stories)) 66 | 67 | def get_max_nodes_per_iter(stories): 68 | result = 0 69 | for (sents_graphs, query, answer) in stories: 70 | prev_nodes = set() 71 | for (sentence, graph) in sents_graphs: 72 | cur_nodes = set(graph["nodes"]) 73 | new_nodes = len(cur_nodes - prev_nodes) 74 | if new_nodes > result: 75 | result = new_nodes 76 | prev_nodes = cur_nodes 77 | return result 78 | 79 | def get_buckets(stories, max_ignore_unbatched=100, max_pad_amount=25): 80 | sentencecounts = [len(sents_graphs) for (sents_graphs, query, answer) in stories] 81 | countpairs = sorted(collections.Counter(sentencecounts).items()) 82 | 83 | buckets = [] 84 | smallest_left_val = 0 85 | num_unbatched = max_ignore_unbatched 86 | for val,ct in countpairs: 87 | num_unbatched += ct 88 | if val - smallest_left_val > max_pad_amount or num_unbatched > max_ignore_unbatched: 89 | buckets.append(val) 90 | smallest_left_val = val 91 | num_unbatched = 0 92 | if buckets[-1] != countpairs[-1][0]: 93 | buckets.append(countpairs[-1][0]) 94 | 95 | return buckets 96 | 97 | PAD_WORD = "" 98 | 99 | def get_wordlist(stories): 100 | words = [PAD_WORD] + sorted(list(set((word 101 | for (sents_graphs, query, answer) in stories 102 | for wordbag in itertools.chain((s for s,g in sents_graphs), [query]) 103 | for word in wordbag )))) 104 | wordmap = list_to_map(words) 105 | return words, wordmap 106 | 107 | def get_answer_list(stories): 108 | words = sorted(list(set(word for (sents_graphs, query, answer) in stories for word in answer))) 109 | wordmap = list_to_map(words) 110 | return words, wordmap 111 | 112 | def pad_story(story, num_sentences, sentence_length): 113 | def pad(lst,dlen,pad): 114 | return lst + [pad]*(dlen - len(lst)) 115 | 116 | sents_graphs, query, answer = story 117 | padded_sents_graphs = [(pad(s,sentence_length,PAD_WORD), g) for s,g in sents_graphs] 118 | padded_query = pad(query,sentence_length,PAD_WORD) 119 | 120 | sentgraph_padding = (pad([],sentence_length,PAD_WORD), padded_sents_graphs[-1][1]) 121 | return (pad(padded_sents_graphs, num_sentences, sentgraph_padding), padded_query, answer) 122 | 123 | def get_unqualified_id(s): 124 | return s.split("#")[0] 125 | 126 | def get_graph_lists(stories): 127 | node_words = sorted(list(set(get_unqualified_id(node) 128 | for (sents_graphs, query, answer) in stories 129 | for sent,graph in sents_graphs 130 | for node in graph["nodes"]))) 131 | nodemap = 
132 |     edge_words = sorted(list(set(get_unqualified_id(edge["type"])
133 |                     for (sents_graphs, query, answer) in stories
134 |                     for sent,graph in sents_graphs
135 |                     for edge in graph["edges"])))
136 |     edgemap = list_to_map(edge_words)
137 |     return node_words, nodemap, edge_words, edgemap
138 | 
139 | def convert_graph(graphs, nodemap, edgemap, new_nodes_per_iter, dynamic=True):
140 |     num_node_ids = len(nodemap)
141 |     num_edge_types = len(edgemap)
142 | 
143 |     full_size = len(graphs)*new_nodes_per_iter + 1
144 | 
145 |     prev_size = 1
146 |     processed_nodes = []
147 |     index_map = {}
148 |     all_num_nodes = []
149 |     all_node_ids = []
150 |     all_node_strengths = []
151 |     all_edges = []
152 |     if not dynamic:
153 |         processed_nodes = list(nodemap.keys())
154 |         index_map = nodemap.copy()
155 |         prev_size = num_node_ids
156 |         full_size = prev_size
157 |         new_nodes_per_iter = 0
158 |     for g in graphs:
159 |         active_nodes = g["nodes"]
160 |         active_edges = g["edges"]
161 | 
162 |         new_nodes = [e for e in active_nodes if e not in processed_nodes]
163 | 
164 |         num_new_nodes = len(new_nodes)
165 |         if not dynamic:
166 |             assert num_new_nodes == 0, "Cannot create more nodes in non-dynamic mode!\n{}".format(graphs)
167 | 
168 |         new_node_strengths = np.zeros([new_nodes_per_iter], np.float32)
169 |         new_node_strengths[:num_new_nodes] = 1.0
170 | 
171 |         new_node_ids = np.zeros([new_nodes_per_iter, num_node_ids], np.float32)
172 |         for i, node in enumerate(new_nodes):
173 |             new_node_ids[i,nodemap[get_unqualified_id(node)]] = 1.0
174 |             index_map[node] = prev_size + i
175 | 
176 |         next_edges = np.zeros([full_size, full_size, num_edge_types], np.float32)
177 |         for edge in active_edges:
178 |             next_edges[index_map[edge["from"]],
179 |                        index_map[edge["to"]],
180 |                        edgemap[get_unqualified_id(edge["type"])]] = 1.0
181 | 
182 |         processed_nodes.extend(new_nodes)
183 |         prev_size += new_nodes_per_iter
184 | 
185 |         all_num_nodes.append(num_new_nodes)
186 |         all_node_ids.append(new_node_ids)
187 |         all_edges.append(next_edges)
188 |         all_node_strengths.append(new_node_strengths)
189 | 
190 |     return np.stack(all_num_nodes), np.stack(all_node_strengths), np.stack(all_node_ids), np.stack(all_edges)
191 | 
192 | def convert_story(story, wordmap, answer_map, graph_node_map, graph_edge_map, new_nodes_per_iter, dynamic=True):
193 |     """
194 |     Converts a story in format
195 |         ([(sentence, graph)], query, answer)
196 |     to a consolidated story in format
197 |         (sentence_arr, (num_new_nodes, new_node_strengths, new_node_ids, next_edges), query_arr, answer_arr)
198 |     and also replaces words according to the input maps
199 |     """
200 |     sents_graphs, query, answer = story
201 | 
202 |     sentence_arr = [[wordmap[w] for w in s] for s,g in sents_graphs]
203 |     graphs = convert_graph([g for s,g in sents_graphs], graph_node_map, graph_edge_map, new_nodes_per_iter, dynamic)
204 |     query_arr = [wordmap[w] for w in query]
205 |     answer_arr = [answer_map[w] for w in answer]
206 |     return (sentence_arr, graphs, query_arr, answer_arr)
207 | 
208 | def process_story(s, bucket_len, wordmap, answer_map, graph_node_map, graph_edge_map, sentence_length, new_nodes_per_iter, dynamic=True):
209 |     return convert_story(pad_story(s, bucket_len, sentence_length), wordmap, answer_map, graph_node_map, graph_edge_map, new_nodes_per_iter, dynamic)
210 | 
211 | def bucket_stories(stories, buckets, wordmap, answer_map, graph_node_map, graph_edge_map, sentence_length, new_nodes_per_iter, dynamic=True):
212 |     return [ [process_story(story, bmax, wordmap, answer_map, graph_node_map, graph_edge_map, sentence_length, new_nodes_per_iter, dynamic) for story in stories if bstart < len(story[0]) <= bmax]
213 |              for bstart, bmax in zip([0]+buckets,buckets)]
214 | 
215 | def prepare_stories(stories, dynamic=True):
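    """
    Compute vocabulary and size metadata for `stories`, then convert and
    bucket all of them in memory. Returns (sentence_length,
    new_nodes_per_iter, buckets, wordlist, anslist, graph_node_list,
    graph_edge_list, bucketed).
    """
216 |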
sentence_length = max(get_max_sentence_length(stories), get_max_query_length(stories)) 217 | buckets = get_buckets(stories) 218 | wordlist, wordmap = get_wordlist(stories) 219 | anslist, ansmap = get_answer_list(stories) 220 | new_nodes_per_iter = get_max_nodes_per_iter(stories) 221 | 222 | graph_node_list, graph_node_map, graph_edge_list, graph_edge_map = get_graph_lists(stories) 223 | bucketed = bucket_stories(stories, buckets, wordmap, ansmap, graph_node_map, graph_edge_map, sentence_length, new_nodes_per_iter, dynamic) 224 | return sentence_length, new_nodes_per_iter, buckets, wordlist, anslist, graph_node_list, graph_edge_list, bucketed 225 | 226 | def print_batch(story, wordlist, anslist, file=sys.stdout): 227 | sents, query, answer = story 228 | for batch,(s,q,a) in enumerate(zip(sents,query,answer)): 229 | file.write("Story {}\n".format(batch)) 230 | for sent in s: 231 | file.write(" ".join([wordlist[word] for word in sent]) + "\n") 232 | file.write(" ".join(wordlist[word] for word in q) + "\n") 233 | file.write(" ".join(anslist[word] for word in a.nonzero()[1]) + "\n") 234 | 235 | MetadataList = collections.namedtuple("MetadataList", ["sentence_length", "new_nodes_per_iter", "buckets", "wordlist", "anslist", "graph_node_list", "graph_edge_list"]) 236 | PreppedStory = collections.namedtuple("PreppedStory", ["converted", "sentences", "query", "answer"]) 237 | def generate_metadata(stories, dynamic=True): 238 | sentence_length = max(get_max_sentence_length(stories), get_max_query_length(stories)) 239 | buckets = get_buckets(stories) 240 | wordlist, wordmap = get_wordlist(stories) 241 | anslist, ansmap = get_answer_list(stories) 242 | new_nodes_per_iter = get_max_nodes_per_iter(stories) 243 | graph_node_list, graph_node_map, graph_edge_list, graph_edge_map = get_graph_lists(stories) 244 | metadata = MetadataList(sentence_length, new_nodes_per_iter, buckets, wordlist, anslist, graph_node_list, graph_edge_list) 245 | return metadata 246 | 247 | def preprocess_stories(stories, savedir, dynamic=True, metadata_file=None): 248 | if metadata_file is None: 249 | metadata = generate_metadata(stories, dynamic) 250 | else: 251 | with open(metadata_file,'rb') as f: 252 | metadata = pickle.load(f) 253 | 254 | buckets = get_buckets(stories) 255 | sentence_length, new_nodes_per_iter, old_buckets, wordlist, anslist, graph_node_list, graph_edge_list = metadata 256 | metadata = metadata._replace(buckets=buckets) 257 | 258 | if not os.path.exists(savedir): 259 | os.makedirs(savedir) 260 | 261 | with open(os.path.join(savedir,'metadata.p'),'wb') as f: 262 | pickle.dump(metadata, f) 263 | 264 | bucketed_files = [[] for _ in buckets] 265 | 266 | for i,story in enumerate(stories): 267 | bucket_idx, cur_bucket = next(((i,bmax) for (i,(bstart, bmax)) in enumerate(zip([0]+buckets,buckets)) 268 | if bstart < len(story[0]) <= bmax), (None,None)) 269 | assert cur_bucket is not None, "Couldn't put story of length {} into buckets {}".format(len(story[0]), buckets) 270 | bucket_dir = os.path.join(savedir, "bucket_{}".format(cur_bucket)) 271 | if not os.path.exists(bucket_dir): 272 | os.makedirs(bucket_dir) 273 | story_fn = os.path.join(bucket_dir, "story_{}.pz".format(i)) 274 | 275 | sents_graphs, query, answer = story 276 | sents = [s for s,g in sents_graphs] 277 | cvtd = convert_story(pad_story(story, cur_bucket, sentence_length), list_to_map(wordlist), list_to_map(anslist), list_to_map(graph_node_list), list_to_map(graph_edge_list), new_nodes_per_iter, dynamic) 278 | 279 | prepped = PreppedStory(cvtd, sents, 
query, answer) 280 | 281 | with gzip.open(story_fn, 'wb') as zf: 282 | pickle.dump(prepped, zf) 283 | 284 | bucketed_files[bucket_idx].append(os.path.relpath(story_fn, savedir)) 285 | gc.collect() # we don't want to use too much memory, so try to clean it up 286 | 287 | with open(os.path.join(savedir,'file_list.p'),'wb') as f: 288 | pickle.dump(bucketed_files, f) 289 | 290 | def main(file, dynamic, metadata_file=None): 291 | stories = get_stories(file) 292 | dirname, ext = os.path.splitext(file) 293 | preprocess_stories(stories, dirname, dynamic, metadata_file) 294 | 295 | if __name__ == '__main__': 296 | parser = argparse.ArgumentParser(description='Parse a graph file') 297 | parser.add_argument("file", help="Graph file to parse") 298 | parser.add_argument("--static", dest="dynamic", action="store_false", help="Don't use dynamic nodes") 299 | parser.add_argument("--metadata-file", default=None, help="Use this particular metadata file instead of building it from scratch") 300 | args = vars(parser.parse_args()) 301 | main(**args) 302 | -------------------------------------------------------------------------------- /ggtnn_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import pickle 4 | import model 5 | import random 6 | import ggtnn_graph_parse 7 | import convert_story 8 | import gzip 9 | from enum import Enum 10 | from ggtnn_graph_parse import MetadataList, PreppedStory 11 | from graceful_interrupt import GracefulInterruptHandler 12 | from pprint import pformat 13 | import util 14 | from train_exit_status import TrainExitStatus 15 | from functools import reduce 16 | 17 | BATCH_SIZE = 10 18 | 19 | def convert_answer(answer, num_words, format_spec, maxlen): 20 | """ 21 | Convert an answer into an appropriate answer matrix given 22 | a ModelOutputFormat. 
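    As a small illustrative example: with num_words = 4 and maxlen = 2,
    answer [2] becomes the (1, 4) one-hot row [[0, 0, 1, 0]] in category
    format; subset format instead sets a 1 for every listed word in a
    single (1, 4) row; sequence format yields a (3, 4) matrix whose unused
    trailing rows are one-hot on the final "stop" word (index num_words - 1).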
23 | 24 | num_words should be after processing with get_effective_answer_words, 25 | so that the last word is the "stop" word 26 | """ 27 | assert format_spec in model.ModelOutputFormat 28 | if format_spec == model.ModelOutputFormat.subset: 29 | ans_mat = np.zeros((1,num_words), np.float32) 30 | for word in answer: 31 | ans_mat[0, word] = 1.0 32 | elif format_spec == model.ModelOutputFormat.category: 33 | ans_mat = np.zeros((1,num_words), np.float32) 34 | ans_mat[0,answer[0]] = 1.0 35 | elif format_spec == model.ModelOutputFormat.sequence: 36 | ans_mat = np.zeros((maxlen+1,num_words), np.float32) 37 | for i,word in enumerate(answer+[num_words-1]*(maxlen+1-len(answer))): 38 | ans_mat[i, word] = 1.0 39 | return ans_mat 40 | 41 | def get_effective_answer_words(answer_words, format_spec): 42 | """ 43 | If needed, modify answer_words using format spec to add padding chars 44 | """ 45 | if format_spec == model.ModelOutputFormat.sequence: 46 | return answer_words + [""] 47 | else: 48 | return answer_words 49 | 50 | def sample_batch(matching_stories, batch_size, num_answer_words, format_spec): 51 | chosen_stories = [random.choice(matching_stories) for _ in range(batch_size)] 52 | return assemble_batch(chosen_stories, num_answer_words, format_spec) 53 | 54 | def assemble_batch(story_fns, num_answer_words, format_spec): 55 | stories = [] 56 | for sfn in story_fns: 57 | with gzip.open(sfn,'rb') as f: 58 | cvtd_story, _, _, _ = pickle.load(f) 59 | stories.append(cvtd_story) 60 | sents, graphs, queries, answers = zip(*stories) 61 | cvtd_sents = np.array(sents, np.int32) 62 | cvtd_queries = np.array(queries, np.int32) 63 | max_ans_len = max(len(a) for a in answers) 64 | cvtd_answers = np.stack([convert_answer(answer, num_answer_words, format_spec, max_ans_len) for answer in answers]) 65 | num_new_nodes, new_node_strengths, new_node_ids, next_edges = zip(*graphs) 66 | num_new_nodes = np.stack(num_new_nodes) 67 | new_node_strengths = np.stack(new_node_strengths) 68 | new_node_ids = np.stack(new_node_ids) 69 | next_edges = np.stack(next_edges) 70 | return cvtd_sents, cvtd_queries, cvtd_answers, num_new_nodes, new_node_strengths, new_node_ids, next_edges 71 | 72 | def assemble_correct_graphs(story_fns): 73 | correct_strengths, correct_ids, correct_edges = [], [], [] 74 | for sfn in story_fns: 75 | with gzip.open(sfn,'rb') as f: 76 | cvtd_story, _, _, _ = pickle.load(f) 77 | strengths, ids, _, edges = convert_story.convert(cvtd_story) 78 | correct_strengths.append(strengths) 79 | correct_ids.append(ids) 80 | correct_edges.append(edges) 81 | return tuple(np.concatenate(l,0) for l in (correct_strengths, correct_ids, correct_edges)) 82 | 83 | def visualize(m, story_buckets, wordlist, answerlist, output_format, outputdir, batch_size=1, seq_len=5, debugmode=False, snap=False): 84 | cur_bucket = random.choice(story_buckets) 85 | sampled_batch = sample_batch(cur_bucket, batch_size, len(answerlist), output_format) 86 | part_sampled_batch = sampled_batch[:3] 87 | with open(os.path.join(outputdir,'stories.txt'),'w') as f: 88 | ggtnn_graph_parse.print_batch(part_sampled_batch, wordlist, answerlist, file=f) 89 | with open(os.path.join(outputdir,'answer_list.txt'),'w') as f: 90 | f.write('\n'.join(answerlist) + '\n') 91 | if debugmode: 92 | args = sampled_batch 93 | fn = m.debug_test_fn 94 | else: 95 | args = part_sampled_batch[:2] + ((seq_len,) if output_format == model.ModelOutputFormat.sequence else ()) 96 | fn = m.snap_test_fn if snap else m.fuzzy_test_fn 97 | results = fn(*args) 98 | for i,result in 
enumerate(results): 99 | np.save(os.path.join(outputdir,'result_{}.npy'.format(i)), result) 100 | 101 | def test_accuracy(m, story_buckets, bucket_sizes, num_answer_words, format_spec, batch_size, batch_auto_adjust=None, test_graph=False): 102 | correct = 0 103 | out_of = 0 104 | for bucket, bucket_size in zip(story_buckets, bucket_sizes): 105 | cur_batch_size = adj_size(m, bucket_size, batch_size, batch_auto_adjust) 106 | for start_idx in range(0, len(bucket), cur_batch_size): 107 | stories = bucket[start_idx:start_idx+cur_batch_size] 108 | batch = assemble_batch(stories, num_answer_words, format_spec) 109 | answers = batch[2] 110 | args = batch[:2] + ((answers.shape[1],) if format_spec == model.ModelOutputFormat.sequence else ()) 111 | 112 | if test_graph: 113 | _, batch_close, _ = m.eval(*batch, with_accuracy=True) 114 | else: 115 | out_answers, out_strengths, out_ids, out_states, out_edges = m.snap_test_fn(*args) 116 | close = np.isclose(out_answers, answers) 117 | batch_close = np.all(close, (1,2)) 118 | 119 | print(batch_close) 120 | 121 | batch_correct = np.sum(batch_close).tolist() 122 | batch_out_of = len(stories) 123 | correct += batch_correct 124 | out_of += batch_out_of 125 | 126 | return correct/out_of 127 | 128 | def adj_size(m, cur_bucket_size, batch_size, batch_auto_adjust): 129 | if batch_auto_adjust is not None: 130 | # Adjust batch size for this bucket 131 | edge_size = (cur_bucket_size**3) * (m.new_nodes_per_iter**2) * m.num_edge_types 132 | if m.sequence_representation: 133 | # In sequence representation mode, we are doing stuff with all objects at the same time 134 | # so add a multiple of the edge size to get a nice bound 135 | edge_size = edge_size * 4 136 | max_batch_size = batch_auto_adjust//edge_size 137 | return min(batch_size, max_batch_size) 138 | else: 139 | return batch_size 140 | 141 | def train(m, story_buckets, bucket_sizes, len_answers, output_format, num_updates, outputdir, start=0, batch_size=BATCH_SIZE, validation_buckets=None, validation_bucket_sizes=None, stop_at_accuracy=None, stop_at_loss=None, stop_at_overfitting=None, save_params=1000, validation_interval=1000, batch_auto_adjust=None, interrupt_file=None): 142 | with GracefulInterruptHandler() as interrupt_h: 143 | for i in range(start+1,num_updates+1): 144 | exit_with = None 145 | cur_bucket, cur_bucket_size = random.choice(list(zip(story_buckets, bucket_sizes))) 146 | cur_batch_size = adj_size(m, cur_bucket_size, batch_size, batch_auto_adjust) 147 | sampled_batch = sample_batch(cur_bucket, cur_batch_size, len_answers, output_format) 148 | loss, info = m.train(*sampled_batch) 149 | if np.any(np.isnan(loss)): 150 | print("Loss at timestep {} was nan! 
Aborting".format(i)) 151 | return TrainExitStatus.nan_loss # Don't bother saving 152 | with open(os.path.join(outputdir,'data.csv'),'a') as f: 153 | if i == 1: 154 | f.seek(0) 155 | f.truncate() 156 | keylist = "iter, loss, " + ", ".join(k for k,v in sorted(info.items())) + "\n" 157 | f.write(keylist) 158 | if validation_buckets is not None: 159 | with open(os.path.join(outputdir,'valid.csv'),'w') as f2: 160 | f2.write(keylist) 161 | f.write("{}, {},".format(i,loss) + ", ".join(str(v) for k,v in sorted(info.items())) + "\n") 162 | if i % 1 == 0: 163 | print("update {}: {}\n{}".format(i,loss,pformat(info))) 164 | if i % validation_interval == 0: 165 | if validation_buckets is not None: 166 | cur_bucket, cur_bucket_size = random.choice(list(zip(validation_buckets, validation_bucket_sizes))) 167 | cur_batch_size = adj_size(m, cur_bucket_size, batch_size, batch_auto_adjust) 168 | sampled_batch = sample_batch(cur_bucket, cur_batch_size, len_answers, output_format) 169 | valid_loss, valid_info = m.eval(*sampled_batch) 170 | print("validation at {}: {}\n{}".format(i,valid_loss,pformat(valid_info))) 171 | with open(os.path.join(outputdir,'valid.csv'),'a') as f: 172 | f.write("{}, {}, ".format(i,valid_loss) + ", ".join(str(v) for k,v in sorted(valid_info.items())) + "\n") 173 | valid_accuracy = test_accuracy(m, validation_buckets, validation_bucket_sizes, len_answers, output_format, batch_size, batch_auto_adjust, (not m.train_with_query)) 174 | print("Best-choice accuracy at {}: {}".format(i,valid_accuracy)) 175 | with open(os.path.join(outputdir,'valid_acc.csv'),'a') as f: 176 | f.write("{}, {}\n".format(i,valid_accuracy)) 177 | if stop_at_accuracy is not None and valid_accuracy >= stop_at_accuracy: 178 | print("Accuracy reached threshold! Stopping training") 179 | exit_with = TrainExitStatus.success 180 | if stop_at_loss is not None and valid_loss <= stop_at_loss: 181 | print("Loss reached threshold! Stopping training") 182 | exit_with = TrainExitStatus.success 183 | if stop_at_overfitting is not None and valid_loss/loss > stop_at_overfitting: 184 | print("Model appears to be overfitting! Stopping training") 185 | exit_with = TrainExitStatus.overfitting 186 | if exit_with is None and (interrupt_h.interrupted or (interrupt_file is not None and os.path.isfile(interrupt_file))): 187 | exit_with = TrainExitStatus.interrupted 188 | if (save_params is not None and i % save_params == 0) or (exit_with is not None) or (i==num_updates): 189 | util.save_params(m.params, open(os.path.join(outputdir, 'params{}.p'.format(i)), 'wb')) 190 | if exit_with is not None: 191 | return exit_with 192 | return TrainExitStatus.reached_update_limit 193 | -------------------------------------------------------------------------------- /graceful_interrupt.py: -------------------------------------------------------------------------------- 1 | # From http://stackoverflow.com/questions/1112343/how-do-i-capture-sigint-in-python 2 | 3 | import signal 4 | 5 | class GracefulInterruptHandler(object): 6 | 7 | def __init__(self, sig=signal.SIGINT): 8 | self.sig = sig 9 | 10 | def __enter__(self): 11 | 12 | self.interrupted = False 13 | self.released = False 14 | 15 | self.original_handler = signal.getsignal(self.sig) 16 | 17 | def handler(signum, frame): 18 | self.release() 19 | self.interrupted = True 20 | print("(Caught interrupt. Terminating when safe.... 
Press Ctrl-C again to force stop)") 21 | 22 | signal.signal(self.sig, handler) 23 | 24 | return self 25 | 26 | def __exit__(self, type, value, tb): 27 | self.release() 28 | 29 | def release(self): 30 | 31 | if self.released: 32 | return False 33 | 34 | signal.signal(self.sig, self.original_handler) 35 | 36 | self.released = True 37 | 38 | return True -------------------------------------------------------------------------------- /graph_state.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | from util import * 5 | 6 | from collections import namedtuple 7 | 8 | GraphStateSpec = namedtuple("GraphStateSpec", ["num_node_ids", "node_state_size", "num_edge_types"]) 9 | 10 | class GraphState( object ): 11 | """ 12 | A class representing the state of a graph. Wrapper for a few theano tensors 13 | """ 14 | def __init__(self, node_strengths, node_ids, node_states, edge_strengths): 15 | """ 16 | Create a graph state directly from existing nodes and edges. 17 | 18 | node_strengths: Tensor of shape (batch, n_nodes) 19 | node_ids: Tensor of shape (batch, n_nodes, num_node_ids) 20 | node_states: Tensor of shape (batch, n_nodes, node_state_size) 21 | edge_strengths: Tensor of shape (batch, n_nodes, n_nodes, num_edge_types) 22 | """ 23 | self._node_strengths = node_strengths 24 | self._node_ids = node_ids 25 | self._node_states = node_states 26 | self._edge_strengths = edge_strengths 27 | 28 | @classmethod 29 | def create_empty(cls, batch_size, num_node_ids, node_state_size, num_edge_types): 30 | """ 31 | Create an empty graph state with the specified sizes. Note that this 32 | will contain one zero-strength element to prevent nasty GPU errors 33 | from a dimension with 0 in it. 34 | 35 | batch_size: Number of batches 36 | num_node_ids: An integer giving size of node id 37 | node_state_size: An integer giving size of node state 38 | num_edge_types: An integer giving number of edge types 39 | """ 40 | return cls( T.unbroadcast(T.zeros([batch_size, 1]), 1), 41 | T.unbroadcast(T.zeros([batch_size, 1, num_node_ids]), 1), 42 | T.unbroadcast(T.zeros([batch_size, 1, node_state_size]), 1), 43 | T.unbroadcast(T.zeros([batch_size, 1, 1, num_edge_types]), 1, 2)) 44 | 45 | @classmethod 46 | def create_empty_from_spec(cls, batch_size, spec): 47 | """ 48 | Create an empty graph state from a spec 49 | 50 | batch_size: Number of batches 51 | spec: Instance of GraphStateSpec 52 | """ 53 | return cls.create_empty(batch_size, spec.num_node_ids, spec.node_state_size, spec.num_edge_types) 54 | 55 | @classmethod 56 | def create_full_unique(cls, batch_size, num_node_ids, node_state_size, num_edge_types): 57 | """ 58 | Create a 'full unique' graph state (i.e. a graph state where every id has exactly one node) from a spec 59 | 60 | batch_size: Number of batches 61 | num_node_ids: An integer giving size of node id 62 | node_state_size: An integer giving size of node state 63 | num_edge_types: An integer giving number of edge types 64 | """ 65 | return cls( T.ones([batch_size, num_node_ids]), 66 | T.tile(T.shape_padleft(T.eye(num_node_ids)), (batch_size,1,1)), 67 | T.zeros([batch_size, num_node_ids, node_state_size]), 68 | T.zeros([batch_size, num_node_ids, num_node_ids, num_edge_types])) 69 | 70 | @classmethod 71 | def create_full_unique_from_spec(cls, batch_size, spec): 72 | """ 73 | Create a 'full unique' graph state (i.e. 
a graph state where every id has exactly one node) from a spec 74 | 75 | batch_size: Number of batches 76 | spec: Instance of GraphStateSpec 77 | """ 78 | return cls.create_full_unique(batch_size, spec.num_node_ids, spec.node_state_size, spec.num_edge_types) 79 | 80 | @property 81 | def node_strengths(self): 82 | return self._node_strengths 83 | 84 | @property 85 | def node_states(self): 86 | return self._node_states 87 | 88 | @property 89 | def node_ids(self): 90 | return self._node_ids 91 | 92 | @property 93 | def edge_strengths(self): 94 | return self._edge_strengths 95 | 96 | @property 97 | def n_batch(self): 98 | return self.node_states.shape[0] 99 | 100 | @property 101 | def n_nodes(self): 102 | return self.node_states.shape[1] 103 | 104 | @property 105 | def num_node_ids(self): 106 | return self.node_ids.shape[2] 107 | 108 | @property 109 | def node_state_size(self): 110 | return self.node_states.shape[2] 111 | 112 | @property 113 | def num_edge_types(self): 114 | return self.edge_strengths.shape[3] 115 | 116 | def flatten(self): 117 | return [self.node_strengths, self.node_ids, self.node_states, self.edge_strengths] 118 | 119 | @classmethod 120 | def unflatten(cls, vals): 121 | return cls(*vals) 122 | 123 | @classmethod 124 | def const_flattened_length(cls): 125 | return 5 126 | 127 | def flatten_to_const_size(self, const_n_nodes): 128 | exp_node_strengths = pad_to(self.node_strengths, [self.n_batch, const_n_nodes]) 129 | exp_node_ids = pad_to(self.node_ids, [self.n_batch, const_n_nodes, self.num_node_ids]) 130 | exp_node_states = pad_to(self.node_states, [self.n_batch, const_n_nodes, self.node_state_size]) 131 | exp_edge_strengths = pad_to(self.edge_strengths, [self.n_batch, const_n_nodes, const_n_nodes, self.num_edge_types]) 132 | return [exp_node_strengths, exp_node_ids, exp_node_states, exp_edge_strengths, self.n_nodes] 133 | 134 | @classmethod 135 | def unflatten_from_const_size(cls, vals): 136 | exp_node_strengths, exp_node_ids, exp_node_states, exp_edge_strengths, n_nodes = vals 137 | return cls( exp_node_strengths[:,:n_nodes], 138 | exp_node_ids[:,:n_nodes,:], 139 | exp_node_states[:,:n_nodes,:], 140 | exp_edge_strengths[:,:n_nodes,:n_nodes,:]) 141 | 142 | def with_updates(self, node_strengths=None, node_ids=None, node_states=None, edge_strengths=None): 143 | """ 144 | Helper function to generate a new state with changes applied. Params like in constructor, or None 145 | to use current values 146 | 147 | Returns: A new graph state with the changes 148 | """ 149 | node_strengths = self.node_strengths if node_strengths is None else node_strengths 150 | node_ids = self.node_ids if node_ids is None else node_ids 151 | node_states = self.node_states if node_states is None else node_states 152 | edge_strengths = self.edge_strengths if edge_strengths is None else edge_strengths 153 | cls = type(self) 154 | return cls(node_strengths, node_ids, node_states, edge_strengths) 155 | 156 | def with_additional_nodes(self, new_node_strengths, new_node_ids, new_node_states=None): 157 | """ 158 | Helper function to generate a new state with new nodes added. 
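        New strengths, ids, and states are concatenated along the node axis,
        and the edge strength matrix is zero-padded up to the enlarged node
        count, so existing edges are kept and edges touching the new nodes
        start at strength zero.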
159 | 
160 |         Params:
161 |             new_node_strengths: Tensor of shape (n_batch, n_new_nodes)
162 |             new_node_ids: Tensor of shape (n_batch, n_new_nodes, num_node_ids)
163 |             new_node_states: (Optional) Tensor of shape (n_batch, n_new_nodes, node_state_size)
164 |                 If not provided, will be zero
165 | 
166 |         Returns: A new graph state with the changes
167 |         """
168 |         if new_node_states is None:
169 |             new_node_states = T.zeros([self.n_batch, new_node_strengths.shape[1], self.node_state_size])
170 | 
171 |         next_node_strengths = T.concatenate([self.node_strengths, new_node_strengths], 1)
172 |         next_node_ids = T.concatenate([self.node_ids, new_node_ids], 1)
173 |         next_node_states = T.concatenate([self.node_states, new_node_states], 1)
174 |         next_n_nodes = next_node_strengths.shape[1]
175 | 
176 |         next_edge_strengths = pad_to(self.edge_strengths, [self.n_batch, next_n_nodes, next_n_nodes, self.num_edge_types])
177 | 
178 |         cls = type(self)
179 |         return cls(next_node_strengths, next_node_ids, next_node_states, next_edge_strengths)
180 | 
181 | 
182 | 
-------------------------------------------------------------------------------- /layer.py: --------------------------------------------------------------------------------
1 | import theano
2 | import theano.tensor as T
3 | import numpy as np
4 | from util import *
5 | import itertools
6 | class Layer(object):
7 | 
8 |     def __init__(self, input_size, output_size, bias_shift=0.0, name='layer', activation=identity, dropout_keep=1):
9 |         self.input_size = input_size
10 |         self.output_size = output_size
11 |         self.activation = activation
12 |         self.name = name if name is not None else get_unique_name(type(self))
13 |         self._W = theano.shared(init_params([input_size, output_size]), self.name+"_W")
14 |         self._b = theano.shared(init_params([output_size], shift=bias_shift), self.name+"_b")
15 |         self.dropout_keep = dropout_keep
16 | 
17 |     @property
18 |     def params(self):
19 |         return [self._W, self._b]
20 | 
21 |     def dropout_masks(self, srng):
22 |         if self.dropout_keep == 1:
23 |             return []
24 |         else:
25 |             return [make_dropout_mask((self.input_size,), self.dropout_keep, srng)]
26 | 
27 |     def split_dropout_masks(self, dropout_masks):
28 |         if dropout_masks is None:
29 |             return [], None
30 |         idx = (self.dropout_keep != 1)
31 |         return dropout_masks[:idx], dropout_masks[idx:]
32 | 
33 |     def process(self, ipt, dropout_masks=Ellipsis):
34 |         if dropout_masks is Ellipsis:
35 |             dropout_masks = None
36 |             append_masks = False
37 |         else:
38 |             append_masks = True
39 |         if self.dropout_keep != 1 and dropout_masks not in ([], None):
40 |             ipt = apply_dropout(ipt, dropout_masks[0])
41 |             dropout_masks = dropout_masks[1:]
42 |         xW = T.dot(ipt, self._W)
43 |         b = T.shape_padleft(self._b)
44 |         if append_masks:
45 |             return self.activation( xW + b ), dropout_masks
46 |         else:
47 |             return self.activation( xW + b )
48 | 
49 | class LayerStack(object):
50 |     def __init__(self, input_size, output_size, hidden_sizes=[], bias_shift=0.0, name=None, hidden_activation=T.tanh, activation=identity, dropout_keep=1, dropout_input=True, dropout_output=False):
51 |         self.input_size = input_size
52 |         self.output_size = output_size
53 |         self.name = name if name is not None else get_unique_name(type(self))
54 | 
55 |         self.dropout_keep = dropout_keep
56 |         self.dropout_output = dropout_output
57 | 
58 |         self.layers = []
59 |         for i, isize, osize in zip(itertools.count(),
60 |                                    [input_size]+hidden_sizes,
61 |                                    hidden_sizes+[output_size]):
62 |             cur_dropout_keep = 1 if (i==0 and not dropout_input) else dropout_keep
63 |             if i == len(hidden_sizes):
64 |                 # Last layer
65 | self.layers.append(Layer(isize, osize, bias_shift=bias_shift, name="{}[output]".format(self.name), activation=activation, dropout_keep=cur_dropout_keep)) 66 | else: 67 | self.layers.append(Layer(isize, osize, name="{}[hidden{}]".format(self.name,i), activation=hidden_activation, dropout_keep=cur_dropout_keep)) 68 | 69 | @property 70 | def params(self): 71 | return [param for layer in self.layers for param in layer.params] 72 | 73 | def dropout_masks(self, srng): 74 | masks = [mask for layer in self.layers for mask in layer.dropout_masks(srng)] 75 | if self.dropout_keep != 1 and self.dropout_output: 76 | masks.append(make_dropout_mask((self.output_size,), self.dropout_keep, srng)) 77 | return masks 78 | 79 | def split_dropout_masks(self, dropout_masks): 80 | if dropout_masks is None: 81 | return [], None 82 | used = [] 83 | for layer in self.layers: 84 | new_used, dropout_masks = layer.split_dropout_masks(dropout_masks) 85 | used.extend(new_used) 86 | if self.dropout_keep != 1 and self.dropout_output: 87 | used.append(dropout_masks[0]) 88 | dropout_masks = dropout_masks[1:] 89 | return used, dropout_masks 90 | 91 | def process(self, ipt, dropout_masks=Ellipsis): 92 | if dropout_masks is Ellipsis: 93 | dropout_masks = None 94 | append_masks = False 95 | else: 96 | append_masks = True 97 | val = ipt 98 | for layer in self.layers: 99 | val, dropout_masks = layer.process(val, dropout_masks) 100 | if self.dropout_keep != 1 and self.dropout_output and dropout_masks not in ([], None): 101 | val = apply_dropout(val, dropout_masks[0]) 102 | dropout_masks = dropout_masks[1:] 103 | if append_masks: 104 | return val, dropout_masks 105 | else: 106 | return val 107 | 108 | 109 | 110 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import argparse 4 | import shutil 5 | import math 6 | import sys 7 | 8 | import model 9 | import ggtnn_train 10 | import ggtnn_graph_parse 11 | from ggtnn_graph_parse import MetadataList, PreppedStory 12 | from util import * 13 | 14 | def helper_trim(bucketed, desired_total): 15 | """Trim bucketed fairly so that it has desired_total things total""" 16 | cur_total = sum(len(b) for b in bucketed) 17 | keep_frac = desired_total/cur_total 18 | if keep_frac > 1.0: 19 | print("WARNING: Asked to trim to {} items, but was already only {} items. 
Keeping original length.".format(desired_total, cur_total)) 20 | return bucketed 21 | keep_amts = [math.floor(len(b) * keep_frac) for b in bucketed] 22 | tmp_total = sum(keep_amts) 23 | addtl_to_add = desired_total - tmp_total 24 | assert addtl_to_add >= 0 25 | keep_amts = [x + (1 if i < addtl_to_add else 0) for i,x in enumerate(keep_amts)] 26 | assert sum(keep_amts) == desired_total 27 | trimmed_bucketed = [b[:amt] for b,amt in zip(bucketed, keep_amts)] 28 | return trimmed_bucketed 29 | 30 | def main(task_dir, output_format_str, state_width, process_repr_size, dynamic_nodes, mutable_nodes, wipe_node_state, direct_reference, propagate_intermediate, sequence_aggregate_repr, old_aggregate, train_with_graph, train_with_query, outputdir, num_updates, batch_size, learning_rate, dropout_keep, resume, resume_auto, visualize, visualize_snap, visualization_test, validation, validation_interval, evaluate_accuracy, check_mode, stop_at_accuracy, stop_at_loss, stop_at_overfitting, restrict_dataset, train_save_params, batch_adjust, set_exit_status, just_compile, autopickle, pickle_model, unpickle_model, interrupt_file): 31 | output_format = model.ModelOutputFormat[output_format_str] 32 | 33 | with open(os.path.join(task_dir,'metadata.p'),'rb') as f: 34 | metadata = pickle.load(f) 35 | with open(os.path.join(task_dir,'file_list.p'),'rb') as f: 36 | bucketed = pickle.load(f) 37 | bucketed = [[os.path.join(task_dir,x) for x in b] for b in bucketed] 38 | if restrict_dataset is not None: 39 | bucketed = helper_trim(bucketed, restrict_dataset) 40 | 41 | sentence_length, new_nodes_per_iter, bucket_sizes, wordlist, anslist, graph_node_list, graph_edge_list = metadata 42 | eff_anslist = ggtnn_train.get_effective_answer_words(anslist, output_format) 43 | 44 | if validation is None: 45 | validation_buckets = None 46 | validation_bucket_sizes = None 47 | else: 48 | with open(os.path.join(validation,'metadata.p'),'rb') as f: 49 | validation_metadata = pickle.load(f) 50 | with open(os.path.join(validation,'file_list.p'),'rb') as f: 51 | validation_buckets = pickle.load(f) 52 | validation_buckets = [[os.path.join(validation,x) for x in b] for b in validation_buckets] 53 | validation_bucket_sizes = validation_metadata[2] 54 | 55 | if direct_reference: 56 | word_node_mapping = {wi:ni for wi,word in enumerate(wordlist) 57 | for ni,node in enumerate(graph_node_list) 58 | if word == node} 59 | else: 60 | word_node_mapping = {} 61 | 62 | model_kwargs = dict(num_input_words=len(wordlist), 63 | num_output_words=len(eff_anslist), 64 | num_node_ids=len(graph_node_list), 65 | node_state_size=state_width, 66 | num_edge_types=len(graph_edge_list), 67 | input_repr_size=100, 68 | output_repr_size=100, 69 | propose_repr_size=process_repr_size, 70 | propagate_repr_size=process_repr_size, 71 | new_nodes_per_iter=new_nodes_per_iter, 72 | output_format=output_format, 73 | final_propagate=5, 74 | word_node_mapping=word_node_mapping, 75 | dynamic_nodes=dynamic_nodes, 76 | nodes_mutable=mutable_nodes, 77 | wipe_node_state=wipe_node_state, 78 | intermediate_propagate=(5 if propagate_intermediate else 0), 79 | sequence_representation=sequence_aggregate_repr, 80 | dropout_keep=dropout_keep, 81 | use_old_aggregate=old_aggregate, 82 | best_node_match_only=True, 83 | train_with_graph=train_with_graph, 84 | train_with_query=train_with_query, 85 | setup=True, 86 | check_mode=check_mode) 87 | 88 | model_kwargs = get_compatible_kwargs(model.Model, model_kwargs) 89 | 90 | if autopickle is not None: 91 | if not os.path.exists(autopickle): 92 | 
os.makedirs(autopickle) 93 | model_hash = object_hash(model_kwargs) 94 | model_filename = os.path.join(autopickle, "model_{}.p".format(model_hash)) 95 | print("Looking for cached model at {}".format(model_filename)) 96 | if os.path.isfile(model_filename): 97 | print("Loading model from cache") 98 | m, stored_kwargs = pickle.load(open(model_filename, 'rb')) 99 | assert model_kwargs == stored_kwargs, "Hash collision between models!\nCurrent: {}\nStored: {}".format(model_kwargs,stored_kwargs) 100 | else: 101 | print("Building model from scratch") 102 | m = model.Model(**model_kwargs) 103 | print("Saving model to cache") 104 | sys.setrecursionlimit(100000) 105 | pickle.dump((m,model_kwargs), open(model_filename,'wb'), protocol=pickle.HIGHEST_PROTOCOL) 106 | elif unpickle_model is not None: 107 | print("Unpickling model...") 108 | m = pickle.load(open(unpickle_model, 'rb')) 109 | else: 110 | m = model.Model(**model_kwargs) 111 | 112 | if pickle_model is not None: 113 | sys.setrecursionlimit(100000) 114 | print("Pickling model...") 115 | pickle.dump(m, open(pickle_model,'wb'), protocol=pickle.HIGHEST_PROTOCOL) 116 | 117 | if just_compile: 118 | return 119 | 120 | if learning_rate is not None: 121 | m.set_learning_rate(learning_rate) 122 | 123 | if not os.path.exists(outputdir): 124 | os.makedirs(outputdir) 125 | 126 | if resume_auto: 127 | result = find_recent_params(outputdir) 128 | if result is not None: 129 | start_idx, paramfile = result 130 | print("Automatically resuming from {} after iteration {}.".format(paramfile, start_idx)) 131 | resume = result 132 | else: 133 | print("Didn't find anything to resume. Starting from the beginning...") 134 | 135 | if resume is not None: 136 | start_idx, paramfile = resume 137 | start_idx = int(start_idx) 138 | load_params(m.params, open(paramfile, "rb") ) 139 | else: 140 | start_idx = 0 141 | 142 | if visualize is not False: 143 | if visualize is True: 144 | source = bucketed 145 | else: 146 | bucket, story = visualize 147 | source = [[bucketed[bucket][story]]] 148 | print("Starting to visualize...") 149 | ggtnn_train.visualize(m, source, wordlist, eff_anslist, output_format, outputdir, snap=visualize_snap) 150 | print("Wrote visualization files to {}.".format(outputdir)) 151 | elif evaluate_accuracy: 152 | print("Evaluating accuracy...") 153 | acc = ggtnn_train.test_accuracy(m, bucketed, bucket_sizes, len(eff_anslist), output_format, batch_size, batch_adjust, (not train_with_query)) 154 | print("Obtained accuracy of {}".format(acc)) 155 | elif visualization_test: 156 | print("Starting visualization test...") 157 | ggtnn_train.visualize(m, bucketed, wordlist, eff_anslist, output_format, outputdir, debugmode=True) 158 | print("Wrote visualization files to {}.".format(outputdir)) 159 | else: 160 | print("Starting to train...") 161 | status = ggtnn_train.train(m, bucketed, bucket_sizes, len(eff_anslist), output_format, num_updates, outputdir, start_idx, batch_size, validation_buckets, validation_bucket_sizes, stop_at_accuracy, stop_at_loss, stop_at_overfitting, train_save_params, validation_interval, batch_adjust, interrupt_file) 162 | if set_exit_status: 163 | sys.exit(status.value) 164 | 165 | parser = argparse.ArgumentParser(description='Train a graph memory network model.', formatter_class=argparse.ArgumentDefaultsHelpFormatter) 166 | parser.add_argument('task_dir', help="Parsed directory for the task to load") 167 | parser.add_argument('output_format_str', choices=[x.name for x in model.ModelOutputFormat], help="Output format for the task") 168 | 
parser.add_argument('state_width', type=int, help="Width of node state") 169 | parser.add_argument('--process-repr-size', type=int, default=50, help="Width of intermediate representations") 170 | parser.add_argument('--mutable-nodes', action="store_true", help="Make nodes mutable") 171 | parser.add_argument('--wipe-node-state', action="store_true", help="Wipe node state before the query") 172 | parser.add_argument('--direct-reference', action="store_true", help="Use direct reference for input, based on node names") 173 | parser.add_argument('--dynamic-nodes', action="store_true", help="Create nodes after each sentence. (Otherwise, create unique nodes at the beginning)") 174 | parser.add_argument('--propagate-intermediate', action="store_true", help="Run a propagation step after each sentence") 175 | parser.add_argument('--sequence-aggregate-repr', action="store_true", help="Compute the query aggregate representation from the sequence of graphs instead of just the last one") 176 | parser.add_argument('--old-aggregate', action="store_true", help="Use the old, incorrect aggregate function") 177 | parser.add_argument('--no-graph', dest='train_with_graph', action="store_false", help="Don't train using graph supervision") 178 | parser.add_argument('--no-query', dest='train_with_query', action="store_false", help="Don't train using query supervision") 179 | parser.add_argument('--outputdir', default="output", help="Directory to save output in") 180 | parser.add_argument('--num-updates', default="10000", type=int, help="How many iterations to train") 181 | parser.add_argument('--batch-size', default="10", type=int, help="Batch size to use") 182 | parser.add_argument('--learning-rate', type=float, default=None, help="Use this learning rate") 183 | parser.add_argument('--dropout-keep', default=1, type=float, help="Use dropout, with this keep chance") 184 | parser.add_argument('--restrict-dataset', metavar="NUM_STORIES", type=int, default=None, help="Restrict size of dataset to this") 185 | parser.add_argument('--save-params-interval', type=int, default=1000, dest="train_save_params", help="Save parameters after this many iterations") 186 | parser.add_argument('--final-params-only', action="store_const", const=None, dest="train_save_params", help="Don't save parameters while training, only at the end.") 187 | parser.add_argument('--validation', metavar="VALIDATION_DIR", default=None, help="Parsed directory of validation tasks") 188 | parser.add_argument('--validation-interval', type=int, default=1000, help="Check validation after this many iterations") 189 | parser.add_argument('--check-nan', dest="check_mode", action="store_const", const="nan", help="Check for NaN. Slows execution") 190 | parser.add_argument('--check-debug', dest="check_mode", action="store_const", const="debug", help="Debug mode. Slows execution") 191 | parser.add_argument('--visualize', nargs="?", const=True, default=False, metavar="BUCKET,STORY", type=lambda s:[int(x) for x in s.split(',')], help="Visualise current state instead of training. 
Optional parameter selects a particular story to visualize, and should be of the form bucketnum,index") 192 | parser.add_argument('--visualize-snap', action="store_true", help="In visualization mode, snap to best option at each timestep") 193 | parser.add_argument('--visualization-test', action="store_true", help="Like visualize, but use the correct graph instead of the model's graph") 194 | parser.add_argument('--evaluate-accuracy', action="store_true", help="Evaluate accuracy of model") 195 | parser.add_argument('--stop-at-accuracy', type=float, default=None, help="Stop training once it reaches this accuracy on validation set") 196 | parser.add_argument('--stop-at-loss', type=float, default=None, help="Stop training once it reaches this loss on validation set") 197 | parser.add_argument('--stop-at-overfitting', type=float, default=None, help="Stop training once validation loss is this many times higher than train loss") 198 | parser.add_argument('--batch-adjust', type=int, default=None, help="If set, ensure that size of edge matrix does not exceed this") 199 | parser.add_argument('--set-exit-status', action="store_true", help="Give info about training status in the exit status") 200 | parser.add_argument('--just-compile', action="store_true", help="Don't run the model, just compile it") 201 | parser.add_argument('--autopickle', metavar="PICKLEDIR", default=None, help="Automatically cache model in this directory") 202 | parser.add_argument('--pickle-model', metavar="MODELFILE", default=None, help="Save the compiled model to a file") 203 | parser.add_argument('--unpickle-model', metavar="MODELFILE", default=None, help="Load the model from a file instead of compiling it from scratch") 204 | parser.add_argument('--interrupt-file', default=None, help="Interrupt training if this file appears") 205 | resume_group = parser.add_mutually_exclusive_group() 206 | resume_group.add_argument('--resume', nargs=2, metavar=('TIMESTEP', 'PARAMFILE'), default=None, help='Where to restore from: timestep, and file to load') 207 | resume_group.add_argument('--resume-auto', action='store_true', help='Automatically restore from a previous run using output directory') 208 | 209 | if __name__ == '__main__': 210 | np.set_printoptions(linewidth=shutil.get_terminal_size((80, 20)).columns) 211 | args = vars(parser.parse_args()) 212 | main(**args) 213 | -------------------------------------------------------------------------------- /metadata-display.py: -------------------------------------------------------------------------------- 1 | from ggtnn_graph_parse import MetadataList 2 | from pprint import pprint 3 | import pickle 4 | import sys 5 | 6 | def main(file): 7 | with open(file,'rb') as f: 8 | metadata = pickle.load(f) 9 | pprint(dict(metadata._asdict())) 10 | 11 | if __name__ == '__main__': 12 | main(sys.argv[1]) -------------------------------------------------------------------------------- /run_harness.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import subprocess 4 | import shutil 5 | import shlex 6 | import collections 7 | from train_exit_status import TrainExitStatus 8 | from graceful_interrupt import GracefulInterruptHandler 9 | from termcolor import colored 10 | 11 | TaskSpec = collections.namedtuple("TaskSpec", ["task_name", "variant_name", "run_params"]) 12 | 13 | def run(tasks_dir, output_dir, base_params, specs, stop_on_error=False, skip_complete=False): 14 | base_params_split = shlex.split(base_params) 15 | for spec in specs: 
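        # For each spec: parse the train/validation text files into parsed
        # directories if they are missing, skip specs already marked as
        # completed, then launch main.py as a subprocess and record its
        # TrainExitStatus in completed.txt.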
16 | print(colored("### Task {} ({}) ###".format(spec.task_name, spec.variant_name), "yellow")) 17 | run_params_split = shlex.split(spec.run_params) 18 | 19 | task_folder_train = os.path.join(tasks_dir, "{}_train".format(spec.task_name)) 20 | if not os.path.isdir(task_folder_train): 21 | print(colored("Train directory doesn't exist. Parsing text file...", attrs=["dark"])) 22 | textfile = task_folder_train + ".txt" 23 | subprocess.run(["python3","ggtnn_graph_parse.py",textfile], check=True) 24 | 25 | task_folder_valid = os.path.join(tasks_dir, "{}_valid".format(spec.task_name)) 26 | if not os.path.isdir(task_folder_valid): 27 | print(colored("Validation directory doesn't exist. Parsing text file...", attrs=["dark"])) 28 | textfile = task_folder_valid + ".txt" 29 | try: 30 | subprocess.run(["python3","ggtnn_graph_parse.py",textfile,"--metadata-file",os.path.join(task_folder_train,"metadata.p")], check=True) 31 | except subprocess.CalledProcessError: 32 | print(colored("Could not parse validation set! Skipping. You may need to regenerate the training set.","magenta")) 33 | continue 34 | 35 | task_output_dir = os.path.join(output_dir, spec.task_name, spec.variant_name) 36 | if not os.path.isdir(task_output_dir): 37 | os.makedirs(task_output_dir) 38 | 39 | completed_file = os.path.join(task_output_dir, "completed.txt") 40 | if os.path.exists(completed_file): 41 | with open(completed_file,'r') as f: 42 | reason = f.readline().strip() 43 | reason = colored(reason, "green" if (reason == "SUCCESS") else "red" if ("FAIL" in reason) else "magenta") 44 | print("Task is already completed, with result {}. Skipping...".format(reason)) 45 | continue 46 | 47 | stdout_fn = os.path.join(task_output_dir, "stdout.txt") 48 | 49 | all_params = ["python3", "-u", "main.py", task_folder_train] + run_params_split + base_params_split 50 | all_params.extend(["--outputdir", task_output_dir]) 51 | all_params.extend(["--validation", task_folder_valid]) 52 | all_params.extend(["--set-exit-status"]) 53 | all_params.extend(["--resume-auto"]) 54 | all_params.extend(["--autopickle", os.path.join(output_dir, "model_cache")]) 55 | print("Running command: " + " ".join(all_params)) 56 | with open(stdout_fn, 'a', 1) as stdout_file: 57 | proc = subprocess.Popen(all_params, bufsize=1, universal_newlines=True, stdout=stdout_file, stderr=subprocess.STDOUT) 58 | with GracefulInterruptHandler() as handler: 59 | returncode = proc.wait() 60 | interrupted = handler.interrupted 61 | 62 | task_status = None 63 | was_error = False 64 | if returncode < 0: 65 | print(colored("Process was killed by a signal!","magenta")) 66 | was_error = True 67 | elif skip_complete: 68 | print(colored("Skipping saving the result (skip_complete=True)")) 69 | else: 70 | task_status = TrainExitStatus(returncode) 71 | 72 | if task_status == TrainExitStatus.success: 73 | print(colored("SUCCESS! Reached desired correctness.","green")) 74 | with open(completed_file,'w') as f: 75 | f.write("SUCCESS\n") 76 | elif task_status == TrainExitStatus.reached_update_limit: 77 | print(colored("FAIL! Reached update limit without attaining desired correctness.","red")) 78 | with open(completed_file,'w') as f: 79 | f.write("FAIL_UPDATE_LIMIT\n") 80 | elif task_status == TrainExitStatus.overfitting: 81 | print(colored("FAIL! Detected overfitting.","red")) 82 | with open(completed_file,'w') as f: 83 | f.write("FAIL_OVERFITTING\n") 84 | elif task_status in (TrainExitStatus.error, TrainExitStatus.malformed_command): 85 | print(colored("Got an error; skipping for now. 
See {} for details.".format(stdout_fn),"magenta")) 86 | was_error = True 87 | elif task_status == TrainExitStatus.nan_loss: 88 | print(colored("NaN loss detected; skipping for now.","magenta")) 89 | was_error = True 90 | 91 | if task_status == TrainExitStatus.interrupted or interrupted: 92 | print(colored("Process was interrupted! Stopping...","cyan")) 93 | break 94 | 95 | if was_error and stop_on_error: 96 | print(colored("Got an error. Exiting...","cyan")) 97 | break 98 | -------------------------------------------------------------------------------- /strength_weighted_gru.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | 7 | class StrengthWeightedGRULayer( object ): 8 | """ 9 | Implements a strength-weighted GRU layer 10 | """ 11 | 12 | def __init__(self, input_width, output_width, activation_shift=0.0, name=None): 13 | """ 14 | Params: 15 | input_width: Width of input. 16 | output_width: Width of the GRU output 17 | activation_shift: How to shift the biases of the activation 18 | """ 19 | self._input_width = input_width 20 | self._output_width = output_width 21 | 22 | prefix = "" if name is None else name + "_" 23 | 24 | self._reset_W = theano.shared(init_params([input_width + output_width, output_width]), prefix+"reset_W") 25 | self._reset_b = theano.shared(init_params([output_width], shift=1.0), prefix+"reset_b") 26 | 27 | self._update_W = theano.shared(init_params([input_width + output_width, output_width+1]), prefix+"update_W") 28 | self._update_b = theano.shared(init_params([output_width+1], shift=1.0), prefix+"update_b") 29 | 30 | self._activation_W = theano.shared(init_params([input_width + output_width, output_width]), prefix+"activation_W") 31 | self._activation_b = theano.shared(init_params([output_width], shift=activation_shift), prefix+"activation_b") 32 | 33 | self._strength_W = theano.shared(init_params([input_width + output_width, 1]), prefix+"strength_W") 34 | self._strength_b = theano.shared(init_params([1], shift=1.0), prefix+"strength_b") 35 | 36 | @property 37 | def input_width(self): 38 | return self._input_width 39 | 40 | @property 41 | def output_width(self): 42 | return self._output_width 43 | 44 | @property 45 | def params(self): 46 | return [self._reset_W, self._reset_b, self._update_W, self._update_b, self._activation_W, self._activation_b, self._strength_W, self._strength_b] 47 | 48 | @property 49 | def num_dropout_masks(self): 50 | return 2 51 | 52 | def get_dropout_masks(self, srng, keep_frac): 53 | """ 54 | Get dropout masks for the GRU. 55 | """ 56 | return [T.shape_padleft(T.cast(srng.binomial((self._input_width,), p=keep_frac), 'float32') / keep_frac), 57 | T.shape_padleft(T.cast(srng.binomial((self._output_width,), p=keep_frac), 'float32') / keep_frac)] 58 | 59 | def step(self, ipt, state, state_strength, dropout_masks=None): 60 | """ 61 | Perform a single step of the network 62 | 63 | Params: 64 | ipt: The current input. Should be an int tensor of shape (n_batch, self.input_width) 65 | state: The previous state. Should be a float tensor of shape (n_batch, self.output_width) 66 | state_strength: Strength of the previous state. 
Should be a float tensor of shape 67 | (n_batch) 68 | dropout_masks: Masks from get_dropout_masks 69 | 70 | Returns: The next output state, and the next output strength 71 | """ 72 | if dropout_masks is not None: 73 | ipt_masks, state_masks = dropout_masks 74 | ipt = ipt*ipt_masks 75 | state = state*state_masks 76 | 77 | obs_state = state * T.shape_padright(state_strength) 78 | cat_ipt_state = T.concatenate([ipt, obs_state], 1) 79 | reset = do_layer( T.nnet.sigmoid, cat_ipt_state, 80 | self._reset_W, self._reset_b ) 81 | update = do_layer( T.nnet.sigmoid, cat_ipt_state, 82 | self._update_W, self._update_b ) 83 | update_state = update[:,:-1] 84 | update_strength = update[:,-1] 85 | 86 | cat_reset_ipt_state = T.concatenate([ipt, (reset * obs_state)], 1) 87 | candidate_act = do_layer( T.tanh, cat_reset_ipt_state, 88 | self._activation_W, self._activation_b ) 89 | candidate_strength = do_layer( T.nnet.sigmoid, cat_reset_ipt_state, 90 | self._strength_W, self._strength_b ).reshape(state_strength.shape) 91 | 92 | newstate = update_state * state + (1-update_state) * candidate_act 93 | newstrength = update_strength * state_strength + (1-update_strength) * candidate_strength 94 | 95 | return newstate, newstrength 96 | -------------------------------------------------------------------------------- /task_generators/automaton.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import json 4 | import sys 5 | 6 | def simulate(cells, rules): 7 | assert rules[(0,0,0)] == 0 8 | old_cells = [0,0] + cells + [0,0] 9 | new_cells = [] 10 | for i in range(len(cells)+2): 11 | cur_block = tuple(old_cells[i:i+3]) 12 | new_cells.append(rules[cur_block]) 13 | return new_cells 14 | 15 | def int_to_bintuple(val,width): 16 | val = tuple(int(x) for x in bin(val)[2:]) 17 | while len(val) < width: 18 | val = (0,) + val 19 | return val 20 | 21 | def decode_rules(rule_idx): 22 | keys = [int_to_bintuple(i,3) for i in reversed(range(8))] 23 | values = int_to_bintuple(rule_idx,8) 24 | return dict(zip(keys,values)) 25 | 26 | def generate(num_seqs, init_len, run_len, rule_idx, start_with=None): 27 | assert init_len > 0 28 | rules = decode_rules(rule_idx) 29 | result = [] 30 | for _ in range(num_seqs): 31 | story = [] 32 | cell_ptrs = [] 33 | cell_values = [] 34 | nodes = [] 35 | connect_edges = [] 36 | value_edges = [] 37 | if start_with is None: 38 | val_sequence = [random.choice([0,1]) for _ in range(init_len)] 39 | else: 40 | val_sequence = [int(x) for x in start_with] 41 | for i,val in enumerate(val_sequence): 42 | cell_values.append(val) 43 | val_node = str(val) 44 | if val_node not in nodes: 45 | nodes.append(val_node) 46 | cell_node = "cell_init#"+str(i) 47 | nodes.append(cell_node) 48 | value_edges.append({"type":"value","from":cell_node,"to":val_node}) 49 | if len(cell_ptrs) > 0: 50 | connect_edges.append({"type":"next_r","from":cell_ptrs[-1],"to":cell_node}) 51 | cell_ptrs.append(cell_node) 52 | 53 | graph_str = json.dumps({ 54 | "nodes":nodes, 55 | "edges":connect_edges + value_edges, 56 | }) 57 | story.append("init {}={}".format(val,graph_str)) 58 | for i in range(run_len): 59 | new_cell_values = simulate(cell_values, rules) 60 | cell_left = "cell_left#"+str(i) 61 | cell_right = "cell_right#"+str(i) 62 | connect_edges.append({"type":"next_r","from":cell_left,"to":cell_ptrs[0]}) 63 | connect_edges.append({"type":"next_r","from":cell_ptrs[-1],"to":cell_right}) 64 | nodes.extend([cell_left,cell_right]) 65 | cell_ptrs = [cell_left] + cell_ptrs + 
[cell_right]
66 |             value_edges = []
67 |             for cell_ptr,val in zip(cell_ptrs,new_cell_values):
68 |                 val_node = str(val)
69 |                 if val_node not in nodes:
70 |                     nodes.append(val_node)
71 |                 value_edges.append({"type":"value","from":cell_ptr,"to":val_node})
72 |             cell_values = new_cell_values
73 |             graph_str = json.dumps({
74 |                 "nodes":nodes,
75 |                 "edges":connect_edges + value_edges,
76 |             })
77 |             story.append("simulate={}".format(graph_str))
78 |         story.append("\t")
79 |         result.extend(["{} {}".format(i+1,s) for i,s in enumerate(story)])
80 |     return "\n".join(result)+"\n"
81 | 
82 | def main(num_seqs, init_len, run_len, rule_idx, file, start_with):
83 |     generated = generate(num_seqs, init_len, run_len, rule_idx, start_with)
84 |     file.write(generated)
85 | 
86 | parser = argparse.ArgumentParser(description='Generate a cellular automaton task')
87 | parser.add_argument("rule_idx", type=int, help="Which automaton rule to use")
88 | parser.add_argument("file", nargs="?", default=sys.stdout, type=argparse.FileType('w'), help="Output file")
89 | parser.add_argument("--num-seqs", type=int, default=1, help="Number of sequences to generate")
90 | parser.add_argument("--init-len", type=int, default=5, help="Length of initial cells")
91 | parser.add_argument("--run-len", type=int, default=5, help="Number of simulate steps")
92 | parser.add_argument("--start-with", default=None, help="Start with this exact input")
93 | 
94 | if __name__ == '__main__':
95 |     args = vars(parser.parse_args())
96 |     main(**args)
97 | 
98 | 
99 | 
100 | 
101 | 
102 | 
103 | 
104 | 
105 | 
106 | 
-------------------------------------------------------------------------------- /task_generators/forth.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import json
4 | import sys
5 | import graph_tools
6 | 
7 | def build_sequence(forth_sequence, run_steps=0):
8 |     story = graph_tools.Story()
9 |     graph = story.graph
10 | 
11 |     # Start
12 |     n_pc = graph.make('pc')
13 |     n_start = graph.make('c_START')
14 |     n_pc.executing = n_start
15 |     n_head = n_start
16 |     next_cmd_edge = "next_cmd"
17 |     scope_stack = [n_start]
18 |     story.add_line("[START]")
19 | 
20 |     # Compiling
21 |     for command in forth_sequence.split(' '):
22 |         basic_cmds = ["NOP", "ZERO", "INC", "DEC", "DUP", "SWAP", "NOT", "POP", "HALT"]
23 |         if command in basic_cmds:
24 |             n_new = graph.make("c_{}".format(command))
25 |             n_head[next_cmd_edge] = n_new
26 |             n_head = n_new
27 |             next_cmd_edge = "next_cmd"
28 |         elif command == "IF":
29 |             n_if = graph.make("c_IF")
30 |             n_head[next_cmd_edge] = n_if
31 |             n_head = n_if
32 |             scope_stack[-1].next_scope = n_if
33 |             scope_stack.append(n_if)
34 |             next_cmd_edge = "next_if_true"
35 |         elif command == "ELSE":
36 |             n_if = scope_stack[-1]
37 |             assert "c_IF" == n_if.type
38 |             assert "c_IF" != n_head.type
39 |             n_head.scope_end_cmd = n_if
40 |             n_head = n_if
41 |             next_cmd_edge = "next_if_false"
42 |         elif command == "THEN":
43 |             n_if = scope_stack.pop()
44 |             scope_stack[-1].next_scope = None
45 |             assert "c_IF" == n_if.type
46 |             assert "c_IF" != n_head.type
47 |             n_head.scope_end_cmd = n_if
48 |             n_head = n_if
49 |             next_cmd_edge = "next_then"
50 |         elif command == "WHILE":
51 |             n_while = graph.make("c_WHILE")
52 |             n_head[next_cmd_edge] = n_while
53 |             n_head = n_while
54 |             scope_stack[-1].next_scope = n_while
55 |             scope_stack.append(n_while)
56 |             next_cmd_edge = "next_if_true"
57 |         elif command == "REPEAT":
58 |             n_while = scope_stack.pop()
59 |             scope_stack[-1].next_scope = None
60 |             assert "c_WHILE" == n_while.type
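            # A REPEAT must close a WHILE (not an IF), and its body must be
            # non-empty; the body's last command is linked back to the WHILE
            # node via scope_end_cmd below.
61 |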
assert "c_WHILE" != n_head.type 62 | n_head.scope_end_cmd = n_while 63 | n_head = n_while 64 | story.add_line(command) 65 | assert len(scope_stack) == 1 66 | 67 | # Running 68 | data_stack = [] 69 | is_returning_to_if = False 70 | for i in range(run_steps): 71 | command = n_pc.executing.identifier[2:] 72 | if command == "IF" or command == "WHILE": 73 | assert len(data_stack) > 0 74 | if is_returning_to_if: 75 | n_pc.executing = n_pc.executing.next_then 76 | is_returning_to_if = False 77 | elif data_stack[-1].value is not None: 78 | n_pc.executing = n_pc.executing.next_if_true 79 | elif n_pc.executing.next_if_false is not None: 80 | n_pc.executing = n_pc.executing.next_if_false 81 | else: 82 | n_pc.executing = n_pc.executing.next_then 83 | elif command == "HALT": 84 | pass 85 | else: 86 | if command == "NOP": 87 | pass 88 | elif command == "ZERO": 89 | n_stacknode = graph.make("stacknode") 90 | if len(data_stack) > 0: 91 | n_stacknode.prev = data_stack[-1] 92 | data_stack.append(n_stacknode) 93 | n_pc.stack_top = n_stacknode 94 | elif command == "INC": 95 | assert len(data_stack) > 0 96 | n_stacknode = data_stack[-1] 97 | n_counter = graph.make("counter") 98 | n_counter.successor = n_stacknode.value 99 | n_stacknode.value = n_counter 100 | elif command == "DEC": 101 | assert len(data_stack) > 0 102 | n_stacknode = data_stack[-1] 103 | if n_stacknode.value is not None: 104 | n_stacknode.value = n_stacknode.value.successor 105 | elif command == "DUP": 106 | n_stacknode = graph.make("stacknode") 107 | if len(data_stack) > 0: 108 | n_stacknode.prev = data_stack[-1] 109 | n_stacknode.value = n_stacknode.prev.value 110 | data_stack.append(n_stacknode) 111 | n_pc.stack_top = n_stacknode 112 | elif command == "SWAP": 113 | assert len(data_stack) >= 2 114 | n_node1, n_node2 = data_stack[-2:] 115 | data_stack[-2:] = n_node2, n_node1 116 | n_node1.prev, n_node2.prev = n_node2.prev, n_node1.prev 117 | elif command == "POP": 118 | assert len(data_stack) > 0 119 | n_stacknode = data_stack.pop() 120 | n_stacknode.prev = None 121 | n_pc.stack_top = data_stack[-1] 122 | elif command == "NOT": 123 | assert len(data_stack) > 0 124 | n_stack_top = data_stack[-1] 125 | n_stacknode = graph.make("stacknode") 126 | n_stacknode.prev = data_stack[-1] 127 | if n_stack_top.value is None: 128 | n_counter = graph.make("counter") 129 | n_stacknode.value = n_counter 130 | data_stack.append(n_stacknode) 131 | n_pc.stack_top = n_stacknode 132 | if n_pc.executing.next_cmd is not None: 133 | n_pc.executing = n_pc.executing.next_cmd 134 | else: 135 | if n_pc.executing.scope_end_cmd.type == "c_IF": 136 | is_returning_to_if = True 137 | n_pc.executing = n_pc.executing.scope_end_cmd 138 | assert n_pc.executing is not None 139 | story.add_line("[RUN]") 140 | 141 | def _build_forth_string(max_len, stacklen=0): 142 | if max_len == 0: 143 | return [], stacklen 144 | chances = { 145 | "NOP":5, 146 | "ZERO":10, 147 | "INC":10 if stacklen>0 else 0, 148 | "DEC":3 if stacklen>0 else 0, 149 | "DUP":5 if stacklen>0 else 0, 150 | "SWAP":7 if stacklen>=2 else 0, 151 | "POP":3 if stacklen>0 else 0, 152 | "NOT":3 if stacklen>0 else 0, 153 | "HALT":2, 154 | "IF_THEN": 5 if max_len >=3 and stacklen>0 else 0, 155 | "IF_ELSE_THEN": 5 if max_len >=5 and stacklen>0 else 0, 156 | "WHILE_REPEAT": 10 if max_len >=3 and stacklen>0 else 0, 157 | } 158 | stack_deltas = { 159 | "ZERO":1, 160 | "DUP":1, 161 | "POP":-1, 162 | } 163 | chance_sum = sum(v for k,v in chances.items()) 164 | while True: 165 | val = random.randrange(chance_sum) 166 | for cmd, chance 
in chances.items(): 167 | val -= chance 168 | if val < 0: 169 | chosen_command = cmd 170 | break 171 | if chosen_command == "IF_THEN": 172 | tot_allocation = max_len - 2 173 | true_allocation = random.randrange(1, tot_allocation+1) 174 | then_allocation = tot_allocation - true_allocation 175 | 176 | true_cmds, true_stacklen = _build_forth_string(true_allocation, stacklen) 177 | next_stacklen = min(stacklen, true_stacklen) 178 | then_cmds, final_stacklen = _build_forth_string(then_allocation, next_stacklen) 179 | return (["IF"] + true_cmds + ["THEN"] + then_cmds), final_stacklen 180 | 181 | elif chosen_command == "IF_ELSE_THEN": 182 | tot_allocation = max_len - 3 183 | cond_allocation = random.randrange(2, tot_allocation+1) 184 | then_allocation = tot_allocation - cond_allocation 185 | true_allocation = random.randrange(1, cond_allocation) 186 | false_allocation = cond_allocation - true_allocation 187 | 188 | true_cmds, true_stacklen = _build_forth_string(true_allocation, stacklen) 189 | false_cmds, false_stacklen = _build_forth_string(false_allocation, stacklen) 190 | next_stacklen = min(true_stacklen, false_stacklen) 191 | then_cmds, final_stacklen = _build_forth_string(then_allocation, next_stacklen) 192 | 193 | return (["IF"] + true_cmds \ 194 | + ["ELSE"] + false_cmds \ 195 | + ["THEN"] + then_cmds), final_stacklen 196 | elif chosen_command == "WHILE_REPEAT": 197 | tot_allocation = max_len - 2 198 | true_allocation = random.randrange(1, tot_allocation+1) 199 | then_allocation = tot_allocation - true_allocation 200 | 201 | while True: 202 | true_cmds, true_stacklen = _build_forth_string(true_allocation, stacklen) 203 | if true_stacklen >= stacklen: 204 | break 205 | then_cmds, final_stacklen = _build_forth_string(then_allocation, stacklen) 206 | return (["WHILE"] + true_cmds + ["REPEAT"] + then_cmds), final_stacklen 207 | else: 208 | if chosen_command in stack_deltas: 209 | next_stacklen = stacklen + stack_deltas[chosen_command] 210 | else: 211 | next_stacklen = stacklen 212 | rest_cmds, rest_stacklen = _build_forth_string(max_len-1, next_stacklen) 213 | return [chosen_command] + rest_cmds, rest_stacklen 214 | 215 | def build_forth_string(max_len): 216 | return " ".join(_build_forth_string(max_len)[0] + ["HALT"]) 217 | 218 | def generate(num_seqs, seq_length): 219 | for _ in range(num_seqs): 220 | forth_string = build_forth_string(seq_length) 221 | print(forth_string) 222 | 223 | 224 | 225 | -------------------------------------------------------------------------------- /task_generators/graph_tools.py: -------------------------------------------------------------------------------- 1 | import json 2 | import collections 3 | 4 | Edge = collections.namedtuple("Edge",["source","dest","type"]) 5 | 6 | class GraphHelper( object ): 7 | def __init__(self): 8 | self.counters = collections.defaultdict(lambda: 0) 9 | self.nodes = set() 10 | self.edges = set() 11 | 12 | def dumps(self): 13 | return json.dumps({ 14 | "nodes": sorted(self.nodes), 15 | "edges": [{"from":e.source,"to":e.dest,"type":e.type} for e in self.edges] 16 | }) 17 | 18 | def make(self, node_type): 19 | full_name = node_type + "#" + str(self.counters[node_type]) 20 | self.counters[node_type] += 1 21 | self.nodes.add(full_name) 22 | return Node(full_name, self) 23 | 24 | def make_unique(self, node_name): 25 | if not node_name in self.nodes: 26 | self.nodes.add(node_name) 27 | return Node(node_name, self) 28 | 29 | class BadEdgeError( Exception ): 30 | pass 31 | 32 | class Node( object ): 33 | def __init__(self, identifier, 
parent): 34 | object.__setattr__(self, 'identifier', identifier) 35 | object.__setattr__(self, 'parent', parent) 36 | 37 | def __getattr__(self, edgename): 38 | matching = set(e.dest for e in self.parent.edges 39 | if e.source == self.identifier 40 | and e.type == edgename) 41 | if len(matching) == 0: 42 | return None 43 | elif len(matching) > 1: 44 | raise BadEdgeError("Expected one result for {}.{}, got {}".format(self.identifier, edgename, matching)) 45 | return Node(matching.pop(), self.parent) 46 | 47 | def __getitem__(self, edgename): 48 | return self.__getattr__(edgename) 49 | 50 | def __setattr__(self, edgename, value): 51 | if edgename in ["identifier","parent"]: 52 | object.__setattr__(self, edgename, value) 53 | return 54 | matching = set(e for e in self.parent.edges 55 | if e.source == self.identifier 56 | and e.type == edgename) 57 | if len(matching) > 1: 58 | print("WARNING: Setting attr {} on {} clears old values, but has multiple edges {}".format(edgename,self.identifier,matching)) 59 | self.parent.edges -= matching 60 | if value is not None: 61 | self.parent.edges.add(Edge(self.identifier, value.identifier, edgename)) 62 | 63 | def __setitem__(self, edgename, value): 64 | return self.__setattr__(edgename, value) 65 | 66 | def getall(self, edgename): 67 | matching = frozenset(e.dest for e in self.parent.edges 68 | if e.source == self.identifier 69 | and e.type == edgename) 70 | return matching 71 | 72 | def add(self, edgename, dest): 73 | self.parent.edges.add(Edge(self.identifier, dest.identifier, edgename)) 74 | 75 | def remove(self, edgename=None, dest=None): 76 | matching = set(e for e in self.parent.edges 77 | if e.source == self.identifier 78 | and (dest is None or e.dest == dest.identifier) 79 | and (edgename is None or e.type == edgename)) 80 | self.parent.edges -= matching 81 | 82 | @property 83 | def type(self): 84 | return self.identifier.split("#")[0] 85 | 86 | 87 | class Story( object ): 88 | def __init__(self): 89 | self.graph = GraphHelper() 90 | self.counter = 1 91 | self.lines = [] 92 | 93 | def add_line(self, line_str): 94 | assert not "=" in line_str 95 | assert not "\t" in line_str 96 | self.lines.append("{} {}={}".format(self.counter, line_str, self.graph.dumps())) 97 | self.counter += 1 98 | 99 | def no_query(self): 100 | self.add_query("","") 101 | 102 | def add_query(self, query, answer): 103 | assert not "=" in query + answer 104 | assert not "\t" in query + answer 105 | self.lines.append("{} {}\t{}".format(self.counter, query, answer)) 106 | self.counter += 1 107 | 108 | 109 | 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /task_generators/ngram_next.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import json 4 | import sys 5 | 6 | def all_ngrams(seq, ngram_size): 7 | for i in range(len(seq)+1-ngram_size): 8 | yield tuple(seq[i:i+ngram_size]) 9 | 10 | def ngram_next_map(seq, ngram_size): 11 | the_map = {} 12 | for ngram in all_ngrams(seq, ngram_size+1): 13 | key = ngram[:-1] 14 | val = ngram[-1] 15 | if key in the_map and the_map[key] != val: 16 | # Don't want keys that appear twice 17 | the_map[key] = None 18 | else: 19 | the_map[key] = val 20 | return {k:v for k,v in the_map.items() if v is not None} 21 | 22 | ITEM_PTR = "$ITEM$" 23 | def generate(num_seqs, seq_length, ngram_size, symbols): 24 | assert ITEM_PTR not in symbols 25 | assert seq_length > ngram_size 26 | result = [] 27 | for _ in range(num_seqs): 28 | 
while True: #just in case we don't find a good query 29 | story = [] 30 | last_ptr = None 31 | values = [] 32 | nodes = [] 33 | edges = [] 34 | for i in range(seq_length): 35 | # Choose next number 36 | next_item = random.choice(symbols) 37 | if not next_item in nodes: 38 | nodes.append(next_item) 39 | cur_ptr = ITEM_PTR + "#" + str(i) 40 | nodes.append(cur_ptr) 41 | if last_ptr is not None: 42 | edges.append({"from":last_ptr,"to":cur_ptr,"type":"next"}) 43 | edges.append({"from":cur_ptr,"to":next_item,"type":"value"}) 44 | last_ptr = cur_ptr 45 | values.append(next_item) 46 | graph_str = json.dumps({ 47 | "nodes":nodes, 48 | "edges":edges, 49 | }) 50 | story.append("{} {}={}".format(i+1, next_item, graph_str)) 51 | possible_queries = ngram_next_map(values, ngram_size) 52 | if len(possible_queries) > 0: 53 | key, val = random.choice(list(possible_queries.items())) 54 | story.append("{} {}?\t{}".format(seq_length+1, ' '.join(key), val)) 55 | result.extend(story) 56 | break 57 | return "\n".join(result)+"\n" 58 | 59 | def main(num_seqs, seq_length, ngram_size, file): 60 | generated = generate(num_seqs, seq_length, ngram_size, [str(x) for x in range(10)]) 61 | file.write(generated) 62 | 63 | parser = argparse.ArgumentParser(description='Generate an ngrams task') 64 | parser.add_argument("file", nargs="?", default=sys.stdout, type=argparse.FileType('w'), help="Output file") 65 | parser.add_argument("--ngram-size", type=int, default=3, help="Size of ngrams") 66 | parser.add_argument("--num-seqs", type=int, default=1, help="Number of sequences to generate") 67 | parser.add_argument("--seq-length", type=int, default=10, help="Length of sequences to generate") 68 | 69 | if __name__ == '__main__': 70 | args = vars(parser.parse_args()) 71 | main(**args) 72 | -------------------------------------------------------------------------------- /task_generators/turing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import json 4 | import sys 5 | import graph_tools 6 | 7 | def make_turing_machine_rules(n_states, n_symbols): 8 | the_rules = [ [ (random.randrange(n_symbols), random.randrange(n_states), random.choice('LNR')) 9 | for symbol in range(n_symbols)] 10 | for state in range(n_states)] 11 | return the_rules 12 | 13 | def encode_turing_machine_rules(rules, starting_state=None, story=None): 14 | if story is None: 15 | story = graph_tools.Story() 16 | graph = story.graph 17 | if starting_state is None: 18 | starting_state = random.randrange(len(rules)) 19 | the_edges = [(cstate, read, write, nstate, direc) 20 | for (cstate, stuff) in enumerate(rules) 21 | for (read, (write, nstate, direc)) in enumerate(stuff)] 22 | random.shuffle(the_edges) 23 | for cstate, read, write, nstate, direc in the_edges: 24 | source = graph.make_unique('state_{}'.format(cstate)) 25 | dest = graph.make_unique('state_{}'.format(nstate)) 26 | edge_type = "rule_{}_{}_{}".format(read,write,direc) 27 | source[edge_type] = dest 28 | story.add_line("rule {} {} {} {} {}".format(source.type, read, write, dest.type, direc)) 29 | head = graph.make_unique('head') 30 | 31 | head.state = graph.make_unique('state_{}'.format(starting_state)) 32 | story.add_line("start {}".format(head.state.type)) 33 | return story 34 | 35 | def encode_turing_machine_process(rules, starting_state, iptlist, process_len, head_index=0, story=None, update_state=False): 36 | if story is None: 37 | story = graph_tools.Story() 38 | graph = story.graph 39 | last_input = None 40 | cells = [] 41 | 
for i,symbol in enumerate(iptlist): 42 | cell = graph.make('cell') 43 | cell.left = last_input 44 | cell.value = graph.make_unique('symbol_{}'.format(symbol)) 45 | cells.append(cell) 46 | last_input = cell 47 | if head_index == i: 48 | head = graph.make_unique('head') 49 | head.cell = cell 50 | story.add_line("input {} head".format(cell.value.type)) 51 | else: 52 | story.add_line("input {}".format(cell.value.type)) 53 | 54 | cstate = starting_state 55 | cell_values = iptlist[:] 56 | for _ in range(process_len): 57 | cell = cells[head_index] 58 | read = cell_values[head_index] 59 | write, nstate, direc = rules[cstate][read] 60 | cell_values[head_index] = write 61 | cell.value = graph.make_unique('symbol_{}'.format(write)) 62 | cstate = nstate 63 | if update_state: 64 | head.state = graph.make_unique('state_{}'.format(nstate)) 65 | 66 | if direc == "L": 67 | if head_index == 0: 68 | newcell = graph.make('cell') 69 | cells.insert(0, newcell) 70 | cells[1].left = newcell 71 | newcell.value = graph.make_unique('symbol_{}'.format(0)) 72 | cell_values.insert(0, 0) 73 | head_index += 1 74 | head_index -= 1 75 | head.cell = cells[head_index] 76 | elif direc == "R": 77 | if head_index == len(cells)-1: 78 | newcell = graph.make('cell') 79 | cells.append(newcell) 80 | newcell.left = cells[-2] 81 | newcell.value = graph.make_unique('symbol_{}'.format(0)) 82 | cell_values.append(0) 83 | head_index += 1 84 | head.cell = cells[head_index] 85 | story.add_line('[RUN]') 86 | story.no_query() 87 | return story 88 | 89 | def generate_universal(num_seqs, num_states, num_symbols, input_len, run_len): 90 | result = [] 91 | for _ in range(num_seqs): 92 | rules = make_turing_machine_rules(num_states, num_symbols) 93 | start_state = random.randrange(num_states) 94 | input_list = [random.choice(range(num_symbols)) for _ in range(input_len)] 95 | head_index = random.randrange(input_len) 96 | story = encode_turing_machine_rules(rules, start_state) 97 | story = encode_turing_machine_process(rules, start_state, input_list, run_len, head_index, story, True) 98 | result.extend(story.lines) 99 | return "\n".join(result)+"\n" 100 | 101 | def generate_busybeaver(alt=False): 102 | if alt: 103 | rules = [ 104 | [ # State A (0) 105 | (1,1,'R'), 106 | (1,3,'N'), 107 | ], 108 | [ # State B (1) 109 | (0,2,'R'), 110 | (1,1,'R'), 111 | ], 112 | [ # State C (2) 113 | (1,2,'L'), 114 | (1,0,'L'), 115 | ], 116 | [ # State HALT (3) 117 | (0,3,'N'), 118 | (1,3,'N'), 119 | ], 120 | ] 121 | else: 122 | rules = [ 123 | [ # State A (0) 124 | (1,1,'R'), 125 | (1,2,'L'), 126 | ], 127 | [ # State B (1) 128 | (1,0,'L'), 129 | (1,1,'R'), 130 | ], 131 | [ # State C (2) 132 | (1,1,'L'), 133 | (1,3,'N'), 134 | ], 135 | [ # State HALT (3) 136 | (0,3,'N'), 137 | (1,3,'N'), 138 | ], 139 | ] 140 | start_state = 0 141 | input_list = [0] 142 | head_index = 0 143 | story = encode_turing_machine_rules(rules, start_state) 144 | story = encode_turing_machine_process(rules, start_state, input_list, 16, head_index, story, True) 145 | return "\n".join(story.lines)+"\n" 146 | 147 | def main(num_seqs, num_states, num_symbols, input_len, run_len, file, busybeaver, busybeaver_alt): 148 | if busybeaver: 149 | generated = generate_busybeaver(busybeaver_alt) 150 | else: 151 | generated = generate_universal(num_seqs, num_states, num_symbols, input_len, run_len) 152 | file.write(generated) 153 | 154 | parser = argparse.ArgumentParser(description='Generate a universal turing machine task') 155 | parser.add_argument("file", nargs="?", default=sys.stdout, 
type=argparse.FileType('w'), help="Output file") 156 | parser.add_argument("--num-states", type=int, default=4, help="Number of states") 157 | parser.add_argument("--num-symbols", type=int, default=4, help="Number of symbols") 158 | parser.add_argument("--input-len", type=int, default=5, help="Length of input") 159 | parser.add_argument("--run-len", type=int, default=10, help="How many steps to simulate") 160 | parser.add_argument("--num-seqs", type=int, default=1, help="Number of sequences to generate") 161 | parser.add_argument("--busybeaver", action="store_true", help="Just generate the busy-beaver task") 162 | parser.add_argument("--busybeaver-alt", action="store_true", help="Generate alternate busy-beaver task") 163 | 164 | if __name__ == '__main__': 165 | args = vars(parser.parse_args()) 166 | main(**args) 167 | 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /train_exit_status.py: -------------------------------------------------------------------------------- 1 | 2 | from enum import Enum 3 | 4 | class TrainExitStatus( Enum ): 5 | success = 0 6 | error = 1 # for consistency with python's default error exit status 7 | malformed_command = 2 8 | reached_update_limit = 3 9 | interrupted = 4 10 | nan_loss = 5 11 | overfitting = 6 -------------------------------------------------------------------------------- /transformation_modules/README.md: -------------------------------------------------------------------------------- 1 | This directory contains modules implementing a set of different operations used by the model. 2 | 3 | First, the modules that implement the five classes of graph transformation are: 4 | 5 | - Node addition: implemented in [new_nodes_inform.py](new_nodes_inform.py) 6 | - Node state update: non-direct-reference update implemented in [node_state_update.py](node_state_update.py), and direct-reference update implemented in [direct_reference_update.py](direct_reference_update.py) 7 | - Edge update: implemented in [edge_state_update.py](edge_state_update.py) 8 | - Propagation: implemented in [propagation.py](propagation.py) 9 | - Aggregation: implemented in [aggregate_representation.py](aggregate_representation.py) 10 | 11 | Additional modules used by the model are: 12 | 13 | - [input_sequence_direct.py](input_sequence_direct.py) implements a GRU layer that scans through an input and creates the full sequence representation vector and the direct-reference input matrix (described in the paper in section 4). 14 | - [output_category.py](output_category.py) implements a simple output layer that chooses a single output out of a set of possibilities. This is used for most of the bAbI tasks. 15 | - [output_set.py](output_set.py) implements a simple output layer that independently decides, for each possibility, whether or not to include it in the output set. This is used for task 8. 16 | - [output_sequence.py](output_sequence.py) implements a GRU-based output layer that chooses a sequence of outputs. This is used for task 19. 17 | - [sequence_aggregate_summary.py](sequence_aggregate_summary.py) implements a GRU-based layer that aggregates information from a series of vectors, used in implementing the version of the model described in Appendix D. 18 | 19 | Finally, for compatibility with preliminary versions of the model, these (deprecated) modules are included: 20 | 21 | - [aggregate_representation_softmax.py](aggregate_representation_softmax.py) implements a version of the aggregation transformation that uses softmax to select attention targets. 
This was found to work equivalently well to using a sigmoid activation function for selecting attention, but was not used for the final experiments, for compatibility with the GG-NN model. 22 | - [new_nodes_vote.py](new_nodes_vote.py) implements a version of the node addition transformation that uses votes from existing nodes to determine the existence of new ones, where any existing node can "veto" a new node. This was found not to work very well in initial experiments. -------------------------------------------------------------------------------- /transformation_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .aggregate_representation import AggregateRepresentationTransformation 2 | from .aggregate_representation_softmax import AggregateRepresentationTransformationSoftmax 3 | from .edge_state_update import EdgeStateUpdateTransformation 4 | from .input_sequence_direct import InputSequenceDirectTransformation 5 | from .new_nodes_vote import NewNodesVoteTransformation 6 | from .new_nodes_inform import NewNodesInformTransformation 7 | from .node_state_update import NodeStateUpdateTransformation 8 | from .direct_reference_update import DirectReferenceUpdateTransformation 9 | from .output_category import OutputCategoryTransformation 10 | from .output_sequence import OutputSequenceTransformation 11 | from .output_set import OutputSetTransformation 12 | from .propagation import PropagationTransformation 13 | from .sequence_aggregate_summary import SequenceAggregateSummaryTransformation -------------------------------------------------------------------------------- /transformation_modules/aggregate_representation.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | from layer import * 7 | from graph_state import GraphState, GraphStateSpec 8 | 9 | class AggregateRepresentationTransformation( object ): 10 | """ 11 | Transforms a graph state into a single representation vector 12 | """ 13 | def __init__(self, representation_width, graph_spec, dropout_keep=1, dropout_output=True): 14 | self._representation_width = representation_width 15 | self._graph_spec = graph_spec 16 | 17 | self._representation_stack = LayerStack(graph_spec.num_node_ids + graph_spec.node_state_size, representation_width+1, name="aggregaterepr", dropout_keep=dropout_keep, dropout_input=False, dropout_output=dropout_output) 18 | 19 | @property 20 | def params(self): 21 | return self._representation_stack.params 22 | 23 | def dropout_masks(self, srng): 24 | return self._representation_stack.dropout_masks(srng) 25 | 26 | def process(self, gstate, dropout_masks=Ellipsis): 27 | """ 28 | Convert the graph state to a representation vector, using sigmoid attention to scale representations 29 | 30 | Params: 31 | gstate: A GraphState giving the current state 32 | 33 | Returns: A representation vector of shape (n_batch, representation_width) 34 | """ 35 | if dropout_masks is Ellipsis: 36 | dropout_masks = None 37 | append_masks = False 38 | else: 39 | append_masks = True 40 | 41 | flat_obs = T.concatenate([ 42 | gstate.node_ids.reshape([-1, self._graph_spec.num_node_ids]), 43 | gstate.node_states.reshape([-1, self._graph_spec.node_state_size])], 1) 44 | flat_activations, dropout_masks = self._representation_stack.process(flat_obs, dropout_masks) 45 | activations = flat_activations.reshape([gstate.n_batch, gstate.n_nodes, 
self._representation_width+1]) 46 | 47 | activation_strengths = activations[:,:,0] 48 | selector = T.shape_padright(T.nnet.sigmoid(activation_strengths) * gstate.node_strengths) 49 | representations = T.tanh(activations[:,:,1:]) 50 | 51 | result = T.tanh(T.sum(selector * representations, 1)) 52 | if append_masks: 53 | return result, dropout_masks 54 | else: 55 | return result 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /transformation_modules/aggregate_representation_softmax.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | from layer import * 7 | from graph_state import GraphState, GraphStateSpec 8 | 9 | class AggregateRepresentationTransformationSoftmax( object ): 10 | """ 11 | Old version of the AggregateRepresentationTransformation 12 | Transforms a graph state into a single representation vector 13 | """ 14 | def __init__(self, representation_width, graph_spec, dropout_keep=1, dropout_output=True): 15 | self._representation_width = representation_width 16 | self._graph_spec = graph_spec 17 | 18 | self._representation_stack = LayerStack(graph_spec.num_node_ids + graph_spec.node_state_size, representation_width+1, name="aggregaterepr", dropout_keep=dropout_keep, dropout_input=False, dropout_output=dropout_output) 19 | 20 | @property 21 | def params(self): 22 | return self._representation_stack.params 23 | 24 | def dropout_masks(self, srng): 25 | return self._representation_stack.dropout_masks(srng) 26 | 27 | def process(self, gstate, dropout_masks=Ellipsis): 28 | """ 29 | Convert the graph state to a representation vector, using softmax attention to scale representations 30 | 31 | Params: 32 | gstate: A GraphState giving the current state 33 | 34 | Returns: A representation vector of shape (n_batch, representation_width) 35 | """ 36 | if dropout_masks is Ellipsis: 37 | dropout_masks = None 38 | append_masks = False 39 | else: 40 | append_masks = True 41 | 42 | flat_obs = T.concatenate([ 43 | gstate.node_ids.reshape([-1, self._graph_spec.num_node_ids]), 44 | gstate.node_states.reshape([-1, self._graph_spec.node_state_size])], 1) 45 | flat_activations, dropout_masks = self._representation_stack.process(flat_obs, dropout_masks) 46 | activations = flat_activations.reshape([gstate.n_batch, gstate.n_nodes, self._representation_width+1]) 47 | 48 | activation_strengths = activations[:,:,0] 49 | existence_penalty = T.log(gstate.node_strengths + EPSILON) # TODO: consider removing epsilon here 50 | selector = T.shape_padright(T.nnet.softmax(activation_strengths + existence_penalty)) 51 | representations = T.tanh(activations[:,:,1:]) 52 | 53 | result = T.sum(selector * representations, 1) 54 | if append_masks: 55 | return result, dropout_masks 56 | else: 57 | return result 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /transformation_modules/direct_reference_update.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | from graph_state import GraphState, GraphStateSpec 7 | from base_gru import BaseGRULayer 8 | 9 | class DirectReferenceUpdateTransformation( object ): 10 | """ 11 | Transforms a graph state by updating node states, conditioned on a direct reference accumulation 12 | """ 13 | def __init__(self, 
input_width, graph_spec, dropout_keep=1): 14 | """ 15 | Params: 16 | input_width: Integer giving size of input 17 | graph_spec: Instance of GraphStateSpec giving graph spec 18 | """ 19 | self._input_width = input_width 20 | self._graph_spec = graph_spec 21 | 22 | self._update_gru = BaseGRULayer(input_width + graph_spec.num_node_ids, graph_spec.node_state_size, name="nodestateupdate", dropout_keep=dropout_keep) 23 | 24 | @property 25 | def params(self): 26 | return self._update_gru.params 27 | 28 | def dropout_masks(self, srng, state_mask=None): 29 | return self._update_gru.dropout_masks(srng, use_output=state_mask) 30 | 31 | def process(self, gstate, ref_matrix, dropout_masks=Ellipsis): 32 | """ 33 | Process a direct ref matrix and update the state accordingly. Each node runs a GRU step 34 | with previous state from the node state and input from the matrix. 35 | 36 | Params: 37 | gstate: A GraphState giving the current state 38 | ref_matrix: A tensor of the form (n_batch, num_node_ids, input_width) 39 | """ 40 | if dropout_masks is Ellipsis: 41 | dropout_masks = None 42 | append_masks = False 43 | else: 44 | append_masks = True 45 | 46 | # To process the input, we need to map from node id to node index 47 | # We can do this using the gstate.node_ids, of shape (n_batch, n_nodes, num_node_ids) 48 | prepped_input_vector = T.batched_dot(gstate.node_ids, ref_matrix) 49 | 50 | # prepped_input_vector is of shape (n_batch, n_nodes, input_width) 51 | # gstate.node_states is of shape (n_batch, n_nodes, node_state_width) 52 | # so they match nicely 53 | full_input = T.concatenate([gstate.node_ids, prepped_input_vector], 2) 54 | 55 | # we flatten to apply GRU 56 | flat_input = full_input.reshape([-1, self._input_width + self._graph_spec.num_node_ids]) 57 | flat_state = gstate.node_states.reshape([-1, self._graph_spec.node_state_size]) 58 | new_flat_state, dropout_masks = self._update_gru.step(flat_input, flat_state, dropout_masks) 59 | 60 | new_node_states = new_flat_state.reshape(gstate.node_states.shape) 61 | 62 | new_gstate = gstate.with_updates(node_states=new_node_states) 63 | if append_masks: 64 | return new_gstate, dropout_masks 65 | else: 66 | return new_gstate 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /transformation_modules/edge_state_update.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | from layer import * 7 | from graph_state import GraphState, GraphStateSpec 8 | 9 | class EdgeStateUpdateTransformation( object ): 10 | """ 11 | Transforms a graph state by updating edge states, conditioned on an input vector and nodes 12 | """ 13 | def __init__(self, input_width, graph_spec, dropout_keep=1): 14 | """ 15 | Params: 16 | input_width: Integer giving size of input 17 | graph_spec: Instance of GraphStateSpec giving graph spec 18 | """ 19 | self._input_width = input_width 20 | self._graph_spec = graph_spec 21 | self._process_input_size = input_width + 2*(graph_spec.num_node_ids + graph_spec.node_state_size) 22 | 23 | self._update_stack = LayerStack(self._process_input_size, 2*graph_spec.num_edge_types, [self._process_input_size], activation=T.nnet.sigmoid, bias_shift=-3.0, name="edge_update", dropout_keep=dropout_keep, dropout_input=False) 24 | 25 | @property 26 | def params(self): 27 | return self._update_stack.params 28 | 29 | def dropout_masks(self, srng): 30 | return 
self._update_stack.dropout_masks(srng) 31 | 32 | def process(self, gstate, input_vector, dropout_masks=Ellipsis): 33 | """ 34 | Process an input vector and update the edge strengths accordingly. For each pair of nodes and 35 | each edge type, a layer stack computes "set" and "clear" signals from the input vector and both nodes' ids and states. 36 | 37 | Params: 38 | gstate: A GraphState giving the current state 39 | input_vector: A tensor of the form (n_batch, input_width) 40 | """ 41 | if dropout_masks is Ellipsis: 42 | dropout_masks = None 43 | append_masks = False 44 | else: 45 | append_masks = True 46 | 47 | # gstate.edge_states is of shape (n_batch, n_nodes, n_nodes, id+state) 48 | # combined input should be broadcasted to (n_batch, n_nodes, n_nodes, X) 49 | input_vector_part = T.shape_padaxis(T.shape_padaxis(input_vector, 1), 2) 50 | source_state_part = T.shape_padaxis(T.concatenate([gstate.node_ids, gstate.node_states], 2), 2) 51 | dest_state_part = T.shape_padaxis(T.concatenate([gstate.node_ids, gstate.node_states], 2), 1) 52 | full_input = broadcast_concat([input_vector_part, source_state_part, dest_state_part], 3) 53 | 54 | # we flatten to process updates 55 | flat_input = full_input.reshape([-1, self._process_input_size]) 56 | flat_result, dropout_masks = self._update_stack.process(flat_input, dropout_masks) 57 | result = flat_result.reshape([gstate.n_batch, gstate.n_nodes, gstate.n_nodes, self._graph_spec.num_edge_types, 2]) 58 | should_set = result[:,:,:,:,0] 59 | should_clear = result[:,:,:,:,1] 60 | 61 | new_strengths = gstate.edge_strengths*(1-should_clear) + (1-gstate.edge_strengths)*should_set 62 | 63 | new_gstate = gstate.with_updates(edge_strengths=new_strengths) 64 | if append_masks: 65 | return new_gstate, dropout_masks 66 | else: 67 | return new_gstate 68 | 69 | -------------------------------------------------------------------------------- /transformation_modules/input_sequence_direct.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | from base_gru import BaseGRULayer 7 | 8 | 9 | class InputSequenceDirectTransformation( object ): 10 | """ 11 | Transforms an input sequence into a representation vector 12 | """ 13 | def __init__(self, num_words, num_node_ids, word_node_mapping, output_width): 14 | """ 15 | num_words: Number of words in the input sequence 16 | word_node_mapping: Mapping of word idx to node idx for direct mapping 17 | """ 18 | self._num_words = num_words 19 | self._num_node_ids = num_node_ids 20 | self._word_node_mapping = word_node_mapping 21 | self._output_width = output_width 22 | 23 | self._word_node_matrix = np.zeros([num_words, num_node_ids], np.float32) 24 | for word,node in word_node_mapping.items(): 25 | self._word_node_matrix[word,node] = 1.0 26 | 27 | self._gru = BaseGRULayer(num_words, output_width, name="input_sequence") 28 | 29 | @property 30 | def params(self): 31 | return self._gru.params 32 | 33 | def process(self, inputs): 34 | """ 35 | Process a set of inputs and return the final state 36 | 37 | Params: 38 | input_words: List of input indices. 
Should be an int tensor of shape (n_batch, input_len) 39 | 40 | Returns: repr_vect, node_vects 41 | repr_vect: The final representation vector, of shape (n_batch, output_width) 42 | node_vects: Direct-access vects for each node id, of shape (n_batch, num_node_ids, output_width) 43 | """ 44 | n_batch, input_len = inputs.shape 45 | valseq = inputs.dimshuffle([1,0]) 46 | one_hot_vals = T.extra_ops.to_one_hot(inputs.flatten(), self._num_words)\ 47 | .reshape([n_batch, input_len, self._num_words]) 48 | one_hot_valseq = one_hot_vals.dimshuffle([1,0,2]) 49 | 50 | def scan_fn(idx_ipt, onehot_ipt, last_accum, last_state): 51 | # last_accum stores accumulated outputs per word type 52 | # and is of shape (n_batch, word_idx, output_width) 53 | gru_state = self._gru.step(onehot_ipt, last_state) 54 | new_accum = T.inc_subtensor(last_accum[T.arange(n_batch), idx_ipt, :], gru_state) 55 | return new_accum, gru_state 56 | 57 | outputs_info = [T.zeros([n_batch, self._num_words, self._output_width]), self._gru.initial_state(n_batch)] 58 | (all_accum, all_out), _ = theano.scan(scan_fn, sequences=[valseq, one_hot_valseq], outputs_info=outputs_info) 59 | 60 | # all_out is of shape (input_len, n_batch, self.output_width). We want last timestep 61 | repr_vect = all_out[-1,:,:] 62 | 63 | final_accum = all_accum[-1,:,:,:] 64 | # Now we also want to extract and accumulate the outputs that directly map to each word 65 | # We can do this by multiplying the final accum's second dimension (word_idx) through by 66 | # the word_node_matrix 67 | resh_flat_final_accum = final_accum.dimshuffle([0,2,1]).reshape([-1, self._num_words]) 68 | resh_flat_node_mat = T.dot(resh_flat_final_accum, self._word_node_matrix) 69 | node_vects = resh_flat_node_mat.reshape([n_batch, self._output_width, self._num_node_ids]).dimshuffle([0,2,1]) 70 | 71 | return repr_vect, node_vects 72 | -------------------------------------------------------------------------------- /transformation_modules/new_nodes_inform.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | from layer import * 7 | from graph_state import GraphState, GraphStateSpec 8 | from base_gru import BaseGRULayer 9 | from .aggregate_representation import AggregateRepresentationTransformation 10 | from .aggregate_representation_softmax import AggregateRepresentationTransformationSoftmax 11 | 12 | class NewNodesInformTransformation( object ): 13 | """ 14 | Transforms a graph state by adding nodes, conditioned on an input vector 15 | """ 16 | def __init__(self, input_width, inform_width, proposal_width, graph_spec, use_old_aggregate=False, dropout_keep=1): 17 | """ 18 | Params: 19 | input_width: Integer giving size of input 20 | inform_width: Size of internal aggregate 21 | proposal_width: Size of internal proposal 22 | graph_spec: Instance of GraphStateSpec giving graph spec 23 | use_old_aggregate: Use the old aggregation mode 24 | """ 25 | self._input_width = input_width 26 | self._graph_spec = graph_spec 27 | self._proposal_width = proposal_width 28 | self._inform_width = inform_width 29 | 30 | aggregate_type = AggregateRepresentationTransformationSoftmax \ 31 | if use_old_aggregate \ 32 | else AggregateRepresentationTransformation 33 | 34 | self._inform_aggregate = aggregate_type(inform_width, graph_spec, dropout_keep, dropout_output=True) 35 | self._proposer_gru = BaseGRULayer(input_width+inform_width, proposal_width, name="newnodes_proposer", 
dropout_keep=dropout_keep, dropout_input=False, dropout_output=True) 36 | self._proposer_stack = LayerStack(proposal_width, 1+graph_spec.num_node_ids, [proposal_width], bias_shift=3.0, name="newnodes_proposer_post", dropout_keep=dropout_keep, dropout_input=False) 37 | 38 | @property 39 | def params(self): 40 | return self._proposer_gru.params + self._proposer_stack.params + self._inform_aggregate.params 41 | 42 | def dropout_masks(self, srng): 43 | return self._inform_aggregate.dropout_masks(srng) + self._proposer_gru.dropout_masks(srng) + self._proposer_stack.dropout_masks(srng) 44 | 45 | def get_candidates(self, gstate, input_vector, max_candidates, dropout_masks=Ellipsis): 46 | """ 47 | Get the current candidate new nodes. This is accomplished as follows: 48 | 1. Using the aggregate transformation, we gather information from nodes (which should have performed 49 | a state update already) 50 | 2. The proposer network, conditioned on the input and info, proposes multiple candidate nodes, 51 | along with a confidence 52 | 3. A new node is created for each candidate node, with an existence strength given by 53 | confidence, and an initial id as proposed 54 | This method directly returns these new nodes for comparison 55 | 56 | Params: 57 | gstate: A GraphState giving the current state 58 | input_vector: A tensor of the form (n_batch, input_width) 59 | max_candidates: Integer, limit on the number of candidates to produce 60 | 61 | Returns: 62 | new_strengths: A tensor of the form (n_batch, new_node_idx) 63 | new_ids: A tensor of the form (n_batch, new_node_idx, num_node_ids) 64 | """ 65 | if dropout_masks is Ellipsis: 66 | dropout_masks = None 67 | append_masks = False 68 | else: 69 | append_masks = True 70 | 71 | n_batch = gstate.n_batch 72 | n_nodes = gstate.n_nodes 73 | 74 | aggregated_repr, dropout_masks = self._inform_aggregate.process(gstate, dropout_masks) 75 | # aggregated_repr is of shape (n_batch, inform_width) 76 | 77 | full_input = T.concatenate([input_vector, aggregated_repr],1) 78 | 79 | outputs_info = [self._proposer_gru.initial_state(n_batch)] 80 | gru_dropout_masks, dropout_masks = self._proposer_gru.split_dropout_masks(dropout_masks) 81 | proposer_step = lambda st,ipt,*dm: self._proposer_gru.step(ipt, st, dm if dropout_masks is not None else None)[0] 82 | raw_proposal_acts, _ = theano.scan(proposer_step, n_steps=max_candidates, non_sequences=[full_input]+gru_dropout_masks, outputs_info=outputs_info) 83 | 84 | # raw_proposal_acts is of shape (candidate, n_batch, blah) 85 | flat_raw_acts = raw_proposal_acts.reshape([-1, self._proposal_width]) 86 | flat_processed_acts, dropout_masks = self._proposer_stack.process(flat_raw_acts, dropout_masks) 87 | candidate_strengths = T.nnet.sigmoid(flat_processed_acts[:,0]).reshape([max_candidates, n_batch]) 88 | candidate_ids = T.nnet.softmax(flat_processed_acts[:,1:]).reshape([max_candidates, n_batch, self._graph_spec.num_node_ids]) 89 | 90 | new_strengths = candidate_strengths.dimshuffle([1,0]) 91 | new_ids = candidate_ids.dimshuffle([1,0,2]) 92 | if append_masks: 93 | return new_strengths, new_ids, dropout_masks 94 | else: 95 | return new_strengths, new_ids 96 | 97 | def process(self, gstate, input_vector, max_candidates, dropout_masks=Ellipsis): 98 | """ 99 | Process an input vector and update the state accordingly. 
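All candidates returned by get_candidates are added to the graph, each with its proposed id distribution and confidence-based existence strength.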
100 | """ 101 | if dropout_masks is Ellipsis: 102 | dropout_masks = None 103 | append_masks = False 104 | else: 105 | append_masks = True 106 | new_strengths, new_ids, dropout_masks = self.get_candidates(gstate, input_vector, max_candidates, dropout_masks) 107 | new_gstate = gstate.with_additional_nodes(new_strengths, new_ids) 108 | if append_masks: 109 | return new_gstate, dropout_masks 110 | else: 111 | return new_gstate 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /transformation_modules/new_nodes_vote.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | from layer import * 7 | from graph_state import GraphState, GraphStateSpec 8 | from base_gru import BaseGRULayer 9 | 10 | class NewNodesVoteTransformation( object ): 11 | """ 12 | Transforms a graph state by adding nodes, conditioned on an input vector 13 | """ 14 | def __init__(self, input_width, proposal_width, graph_spec): 15 | """ 16 | Params: 17 | input_width: Integer giving size of input 18 | proposal_width: Size of internal proposal 19 | graph_spec: Instance of GraphStateSpec giving graph spec 20 | """ 21 | self._input_width = input_width 22 | self._graph_spec = graph_spec 23 | self._proposal_width = proposal_width 24 | 25 | self._proposer_gru = BaseGRULayer(input_width, proposal_width, name="newnodes_proposer") 26 | 27 | self._proposer_stack = LayerStack(proposal_width, 1+graph_spec.num_node_ids, [proposal_width], bias_shift=3.0, name="newnodes_proposer_post") 28 | isize = 2*graph_spec.num_node_ids + graph_spec.node_state_size 29 | self._vote_stack = LayerStack(isize, 1, [isize], activation=T.nnet.sigmoid, bias_shift=-3.0, name="newnodes_vote") 30 | 31 | @property 32 | def params(self): 33 | return self._proposer_gru.params + self._proposer_stack.params + self._vote_stack.params 34 | 35 | @property 36 | def num_dropout_masks(self): 37 | return self._proposer_gru.num_dropout_masks 38 | 39 | def get_dropout_masks(self, srng, keep_frac): 40 | return self._proposer_gru.get_dropout_masks(srng, keep_frac) 41 | 42 | def get_candidates(self, gstate, input_vector, max_candidates, dropout_masks=None): 43 | """ 44 | Get the current candidate new nodes. This is accomplished as follows: 45 | 1. The proposer network, conditioned on the input vector, proposes multiple candidate nodes, 46 | along with a confidence 47 | 2. Every existing node, conditioned on its own state and the candidate, votes on whether or not 48 | to accept this node 49 | 3. 
A new node is created for each candidate node, with an existence strength given by 50 | confidence * [product of all votes], and an initial id as proposed 51 | This method directly returns these new nodes for comparison 52 | 53 | Params: 54 | gstate: A GraphState giving the current state 55 | input_vector: A tensor of the form (n_batch, input_width) 56 | max_candidates: Integer, limit on the number of candidates to produce 57 | 58 | Returns: 59 | new_strengths: A tensor of the form (n_batch, new_node_idx) 60 | new_ids: A tensor of the form (n_batch, new_node_idx, num_node_ids) 61 | """ 62 | n_batch = gstate.n_batch 63 | n_nodes = gstate.n_nodes 64 | outputs_info = [self._proposer_gru.initial_state(n_batch)] 65 | proposer_step = lambda st,ipt,*dm: self._proposer_gru.step(ipt,st,dm if dropout_masks is not None else None) 66 | raw_proposal_acts, _ = theano.scan(proposer_step, n_steps=max_candidates, non_sequences=[input_vector]+(dropout_masks if dropout_masks is not None else []), outputs_info=outputs_info) 67 | 68 | # raw_proposal_acts is of shape (candidate, n_batch, blah) 69 | flat_raw_acts = raw_proposal_acts.reshape([-1, self._proposal_width]) 70 | flat_processed_acts = self._proposer_stack.process(flat_raw_acts) 71 | candidate_strengths = T.nnet.sigmoid(flat_processed_acts[:,0]).reshape([max_candidates, n_batch]) 72 | candidate_ids = T.nnet.softmax(flat_processed_acts[:,1:]).reshape([max_candidates, n_batch, self._graph_spec.num_node_ids]) 73 | 74 | # Votes will be of shape (candidate, n_batch, n_nodes) 75 | # To generate this we want to assemble (candidate, n_batch, n_nodes, input_stuff), 76 | # squash to (parallel, input_stuff), do voting op, then unsquash 77 | candidate_id_part = T.shape_padaxis(candidate_ids, 2) 78 | node_id_part = T.shape_padaxis(gstate.node_ids, 0) 79 | node_state_part = T.shape_padaxis(gstate.node_states, 0) 80 | full_vote_input = broadcast_concat([node_id_part, node_state_part, candidate_id_part], 3) 81 | flat_vote_input = full_vote_input.reshape([-1, full_vote_input.shape[-1]]) 82 | vote_result = self._vote_stack.process(flat_vote_input) 83 | final_votes_no = vote_result.reshape([max_candidates, n_batch, n_nodes]) 84 | weighted_votes_yes = 1 - final_votes_no * T.shape_padleft(gstate.node_strengths) 85 | # Add in the strength vote 86 | all_votes = T.concatenate([T.shape_padright(candidate_strengths), weighted_votes_yes], 2) 87 | # Take the product -> (candidate, n_batch) 88 | chosen_strengths = T.prod(all_votes, 2) 89 | 90 | new_strengths = chosen_strengths.dimshuffle([1,0]) 91 | new_ids = candidate_ids.dimshuffle([1,0,2]) 92 | return new_strengths, new_ids 93 | 94 | def process(self, gstate, input_vector, max_candidates, dropout_masks=None): 95 | """ 96 | Process an input vector and update the state accordingly. 
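Candidate nodes and their vote-weighted strengths come from get_candidates, and all of them are then added to the graph.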
97 | """ 98 | new_strengths, new_ids = self.get_candidates(gstate, input_vector, max_candidates, dropout_masks) 99 | new_gstate = gstate.with_additional_nodes(new_strengths, new_ids) 100 | return new_gstate 101 | 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /transformation_modules/node_state_update.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | from graph_state import GraphState, GraphStateSpec 7 | from base_gru import BaseGRULayer 8 | 9 | class NodeStateUpdateTransformation( object ): 10 | """ 11 | Transforms a graph state by updating note states, conditioned on an input vector 12 | """ 13 | def __init__(self, input_width, graph_spec, dropout_keep=1): 14 | """ 15 | Params: 16 | input_width: Integer giving size of input 17 | graph_spec: Instance of GraphStateSpec giving graph spec 18 | """ 19 | self._input_width = input_width 20 | self._graph_spec = graph_spec 21 | 22 | self._update_gru = BaseGRULayer(input_width + graph_spec.num_node_ids, graph_spec.node_state_size, name="nodestateupdate", dropout_keep=dropout_keep, dropout_input=False, dropout_output=True) 23 | 24 | @property 25 | def params(self): 26 | return self._update_gru.params 27 | 28 | def dropout_masks(self, srng, state_mask=None): 29 | return self._update_gru.dropout_masks(srng, use_output=state_mask) 30 | 31 | def process(self, gstate, input_vector, dropout_masks=Ellipsis): 32 | """ 33 | Process an input vector and update the state accordingly. Each node runs a GRU step 34 | with previous state from the node state and input from the vector. 35 | 36 | Params: 37 | gstate: A GraphState giving the current state 38 | input_vector: A tensor of the form (n_batch, input_width) 39 | """ 40 | 41 | # gstate.node_states is of shape (n_batch, n_nodes, node_state_width) 42 | # input_vector should be broadcasted to match this 43 | if dropout_masks is Ellipsis: 44 | dropout_masks = None 45 | append_masks = False 46 | else: 47 | append_masks = True 48 | prepped_input_vector = T.tile(T.shape_padaxis(input_vector, 1), [1, gstate.n_nodes, 1]) 49 | full_input = T.concatenate([gstate.node_ids, prepped_input_vector], 2) 50 | 51 | # we flatten to apply GRU 52 | flat_input = full_input.reshape([-1, self._input_width + self._graph_spec.num_node_ids]) 53 | flat_state = gstate.node_states.reshape([-1, self._graph_spec.node_state_size]) 54 | new_flat_state, dropout_masks = self._update_gru.step(flat_input, flat_state, dropout_masks) 55 | 56 | new_node_states = new_flat_state.reshape(gstate.node_states.shape) 57 | 58 | new_gstate = gstate.with_updates(node_states=new_node_states) 59 | if append_masks: 60 | return new_gstate, dropout_masks 61 | else: 62 | return new_gstate 63 | 64 | 65 | 66 | -------------------------------------------------------------------------------- /transformation_modules/output_category.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | from layer import * 7 | from graph_state import GraphState, GraphStateSpec 8 | 9 | class OutputCategoryTransformation( object ): 10 | """ 11 | Transforms a representation vector into a single categorical output 12 | """ 13 | def __init__(self, input_width, num_categories): 14 | self._input_width = input_width 15 | self._num_categories = num_categories 16 | 17 | 
self._transform_stack = LayerStack(input_width, num_categories, activation=T.nnet.softmax, name="output_category") 18 | 19 | @property 20 | def params(self): 21 | return self._transform_stack.params 22 | 23 | def process(self, input_vector): 24 | """ 25 | Convert an input vector into a categorical distribution across num_categories categories 26 | 27 | Params: 28 | input_vector: Vector of shape (n_batch, input_width) 29 | 30 | Returns: Categorical distribution of shape (n_batch, 1, num_categories), such that it sums to 1 across 31 | all categories for each instance in the batch 32 | """ 33 | transformed = self._transform_stack.process(input_vector) 34 | return T.shape_padaxis(transformed,1) 35 | 36 | def snap_to_best(self, answer): 37 | """ 38 | Convert output of process to the "best" answer, i.e. the answer with highest probability. 39 | """ 40 | return categorical_best(answer) 41 | -------------------------------------------------------------------------------- /transformation_modules/output_sequence.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | from layer import * 7 | from graph_state import GraphState, GraphStateSpec 8 | from base_gru import BaseGRULayer 9 | 10 | 11 | class OutputSequenceTransformation( object ): 12 | """ 13 | Transforms a representation vector into a sequence of outputs 14 | """ 15 | def __init__(self, input_width, state_size, num_words): 16 | self._input_width = input_width 17 | self._state_size = state_size 18 | self._num_words = num_words 19 | 20 | self._seq_gru = BaseGRULayer(input_width, state_size, name="output_seq_gru") 21 | self._transform_stack = LayerStack(state_size, num_words, activation=T.nnet.softmax, name="output_seq_transf") 22 | 23 | @property 24 | def params(self): 25 | return self._seq_gru.params + self._transform_stack.params 26 | 27 | def process(self, input_vector, seq_len): 28 | """ 29 | Convert an input vector into a sequence of categorical distributions 30 | 31 | Params: 32 | input_vector: Vector of shape (n_batch, input_width) 33 | seq_len: How many outputs to produce 34 | 35 | Returns: Sequence distribution of shape (n_batch, seq_len, num_words) 36 | """ 37 | n_batch = input_vector.shape[0] 38 | outputs_info = [self._seq_gru.initial_state(n_batch)] 39 | scan_step = lambda state, ipt: self._seq_gru.step(ipt, state) 40 | all_out, _ = theano.scan(scan_step, non_sequences=[input_vector], n_steps=seq_len, outputs_info=outputs_info) 41 | 42 | # all_out is of shape (seq_len, n_batch, state_size). Squash and apply layer 43 | flat_out = all_out.reshape([-1, self._state_size]) 44 | flat_final = self._transform_stack.process(flat_out) 45 | final = flat_final.reshape([seq_len, n_batch, self._num_words]).dimshuffle([1,0,2]) 46 | 47 | return final 48 | 49 | def snap_to_best(self, answer): 50 | """ 51 | Convert output of process to the "best" answer, i.e. the answer with highest probability. 
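For sequence outputs, this selects the most probable word independently at each timestep.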
52 | """ 53 | return categorical_best(answer) 54 | -------------------------------------------------------------------------------- /transformation_modules/output_set.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | from layer import * 7 | from graph_state import GraphState, GraphStateSpec 8 | 9 | class OutputSetTransformation( object ): 10 | """ 11 | Transforms a representation vector into an independent set output 12 | """ 13 | def __init__(self, input_width, num_categories): 14 | self._input_width = input_width 15 | self._num_categories = num_categories 16 | 17 | self._transform_stack = LayerStack(input_width, num_categories, activation=T.nnet.sigmoid, name="output_set") 18 | 19 | @property 20 | def params(self): 21 | return self._transform_stack.params 22 | 23 | def process(self, input_vector): 24 | """ 25 | Convert an input vector into a probabilistic set, i.e. a list of probabilities of item i being in 26 | the output set. 27 | 28 | Params: 29 | input_vector: Vector of shape (n_batch, input_width) 30 | 31 | Returns: Set distribution of shape (n_batch, 1, num_categories), where each value is independent from 32 | the others. 33 | """ 34 | transformed = self._transform_stack.process(input_vector) 35 | return T.shape_padaxis(transformed,1) 36 | 37 | def snap_to_best(self, answer): 38 | """ 39 | Convert output of process to the "best" answer, i.e. the answer with highest probability. 40 | """ 41 | return independent_best(answer) 42 | -------------------------------------------------------------------------------- /transformation_modules/propagation.py: -------------------------------------------------------------------------------- 1 | import theano 2 | import theano.tensor as T 3 | import numpy as np 4 | 5 | from util import * 6 | from layer import * 7 | from graph_state import GraphState, GraphStateSpec 8 | from base_gru import BaseGRULayer 9 | 10 | class PropagationTransformation( object ): 11 | """ 12 | Transforms a graph state by propagating info across the graph 13 | """ 14 | def __init__(self, transfer_size, graph_spec, transfer_activation=identity, dropout_keep=1): 15 | """ 16 | Params: 17 | transfer_size: Integer, how much to transfer 18 | graph_spec: Instance of GraphStateSpec giving graph spec 19 | transfer_activation: Activation function to use during transfer 20 | """ 21 | self._transfer_size = transfer_size 22 | self._transfer_activation = transfer_activation 23 | self._graph_spec = graph_spec 24 | self._process_input_size = graph_spec.num_node_ids + graph_spec.node_state_size 25 | 26 | self._transfer_stack = LayerStack(self._process_input_size, 2 * graph_spec.num_edge_types * transfer_size, activation=self._transfer_activation, name="propagation_transfer", dropout_keep=dropout_keep, dropout_input=False, dropout_output=True) 27 | self._propagation_gru = BaseGRULayer(graph_spec.num_node_ids + self._transfer_size, graph_spec.node_state_size, name="propagation", dropout_keep=dropout_keep, dropout_input=False, dropout_output=True) 28 | 29 | @property 30 | def params(self): 31 | return self._propagation_gru.params + self._transfer_stack.params 32 | 33 | def dropout_masks(self, srng, state_mask=None): 34 | return self._transfer_stack.dropout_masks(srng) + self._propagation_gru.dropout_masks(srng, use_output=state_mask) 35 | 36 | def split_dropout_masks(self, dropout_masks): 37 | transfer_used, dropout_masks = 
self._transfer_stack.split_dropout_masks(dropout_masks) 38 | gru_used, dropout_masks = self._propagation_gru.split_dropout_masks(dropout_masks) 39 | return (transfer_used+gru_used), dropout_masks 40 | 41 | def process(self, gstate, dropout_masks=Ellipsis): 42 | """ 43 | Process a graph state. 44 | 1. Data is transfered from each node to each other node along both forward and backward edges. 45 | This data is processed with a Wx+b style update, and an optional transformation is applied 46 | 2. Nodes sum the transfered data, weighted by the existence of the other node and the edge. 47 | 3. Nodes perform a GRU update with this input 48 | 49 | Params: 50 | gstate: A GraphState giving the current state 51 | """ 52 | if dropout_masks is Ellipsis: 53 | dropout_masks = None 54 | append_masks = False 55 | else: 56 | append_masks = True 57 | 58 | node_obs = T.concatenate([gstate.node_ids, gstate.node_states],2) 59 | flat_node_obs = node_obs.reshape([-1, self._process_input_size]) 60 | transformed, dropout_masks = self._transfer_stack.process(flat_node_obs,dropout_masks) 61 | transformed = transformed.reshape([gstate.n_batch, gstate.n_nodes, 2*self._graph_spec.num_edge_types, self._transfer_size]) 62 | scaled_transformed = transformed * T.shape_padright(T.shape_padright(gstate.node_strengths)) 63 | # scaled_transformed is of shape (n_batch, n_nodes, 2*num_edge_types, transfer_size) 64 | # We want to multiply through by edge strengths, which are of shape 65 | # (n_batch, n_nodes, n_nodes, num_edge_types), both fwd and backward 66 | edge_strength_scale = T.concatenate([gstate.edge_strengths, gstate.edge_strengths.swapaxes(1,2)], 3) 67 | # edge_strength_scale is of (n_batch, n_nodes, n_nodes, 2*num_edge_types) 68 | intermed = T.shape_padaxis(scaled_transformed, 2) * T.shape_padright(edge_strength_scale) 69 | # intermed is of shape (n_batch, n_nodes "source", n_nodes "dest", 2*num_edge_types, transfer_size) 70 | # now reduce along the "source" and "edge_types" dimensions to get dest activations 71 | # of shape (n_batch, n_nodes, transfer_size) 72 | reduced_result = T.sum(T.sum(intermed, 3), 1) 73 | 74 | # now add information fom current node id 75 | full_input = T.concatenate([gstate.node_ids, reduced_result], 2) 76 | 77 | # we flatten to apply GRU 78 | flat_input = full_input.reshape([-1, self._graph_spec.num_node_ids + self._transfer_size]) 79 | flat_state = gstate.node_states.reshape([-1, self._graph_spec.node_state_size]) 80 | new_flat_state, dropout_masks = self._propagation_gru.step(flat_input, flat_state, dropout_masks) 81 | 82 | new_node_states = new_flat_state.reshape(gstate.node_states.shape) 83 | 84 | new_gstate = gstate.with_updates(node_states=new_node_states) 85 | if append_masks: 86 | return new_gstate, dropout_masks 87 | else: 88 | return new_gstate 89 | 90 | def process_multiple(self, gstate, iterations, dropout_masks=Ellipsis): 91 | """ 92 | Run multiple propagagtion steps. 93 | 94 | Params: 95 | gstate: A GraphState giving the current state 96 | iterations: An integer. 
--------------------------------------------------------------------------------
/transformation_modules/sequence_aggregate_summary.py:
--------------------------------------------------------------------------------
1 | import theano
2 | import theano.tensor as T
3 | import numpy as np
4 | 
5 | from util import *
6 | from layer import *
7 | from graph_state import GraphState, GraphStateSpec
8 | from base_gru import BaseGRULayer
9 | 
10 | class SequenceAggregateSummaryTransformation( object ):
11 |     """
12 |     Transforms a sequence of aggregate representation vectors into a summary vector
13 |     """
14 |     def __init__(self, input_representation_width, output_representation_width, dropout_keep=1):
15 |         self._input_representation_width = input_representation_width
16 |         self._output_representation_width = output_representation_width
17 | 
18 |         self._seq_gru = BaseGRULayer(input_representation_width, output_representation_width, dropout_keep=dropout_keep, name="summary_seq_gru")
19 | 
20 |     @property
21 |     def params(self):
22 |         return self._seq_gru.params
23 | 
24 |     def dropout_masks(self, srng):
25 |         return self._seq_gru.dropout_masks(srng)
26 | 
27 |     def process(self, input_sequence, dropout_masks=Ellipsis):
28 |         """
29 |         Convert the sequence of vectors to a summary vector
30 |         Params:
31 |             input_sequence: A tensor of shape (n_batch, time, input_representation_width)
32 | 
33 |         Returns: A representation vector of shape (n_batch, output_representation_width)
34 |         """
35 |         if dropout_masks is Ellipsis:
36 |             dropout_masks = None
37 |             append_masks = False
38 |         else:
39 |             append_masks = True
40 | 
41 |         n_batch = input_sequence.shape[0]
42 |         swapped_input = input_sequence.swapaxes(0, 1)
43 |         outputs_info = [self._seq_gru.initial_state(n_batch)]
44 |         scan_step = lambda ipt, state, *dmasks: self._seq_gru.step(ipt, state, None if dropout_masks is None else dmasks)[0]
45 |         all_out, _ = theano.scan(scan_step, sequences=[swapped_input], non_sequences=dropout_masks, outputs_info=outputs_info)
46 | 
47 |         result = all_out[-1]  # the final GRU state summarizes the whole sequence
48 | 
49 |         if append_masks:
50 |             return result, dropout_masks
51 |         else:
52 |             return result
53 | 
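A minimal usage sketch follows (widths invented; it assumes, as the default `dropout_keep=1` suggests, that the module can be driven without explicit dropout masks):

```python
import theano
import theano.tensor as T
from transformation_modules.sequence_aggregate_summary import SequenceAggregateSummaryTransformation

summarizer = SequenceAggregateSummaryTransformation(
    input_representation_width=20, output_representation_width=30)

seq = T.tensor3("seq")             # (n_batch, time, 20)
summary = summarizer.process(seq)  # (n_batch, 30): the final GRU state
f = theano.function([seq], summary)
```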
--------------------------------------------------------------------------------
/update_cache_compatibility.py:
--------------------------------------------------------------------------------
1 | import model
2 | import util
3 | import pickle
4 | import os
5 | import argparse
6 | import sys
7 | 
8 | def main(cache_dir):
9 |     files_list = list(os.listdir(cache_dir))
10 |     for file in files_list:
11 |         full_filename = os.path.join(cache_dir, file)
12 |         if os.path.isfile(full_filename):
13 |             print("Processing {}".format(full_filename))
14 |             m, stored_kwargs = pickle.load(open(full_filename, 'rb'))
15 |             updated_kwargs = util.get_compatible_kwargs(model.Model, stored_kwargs)
16 | 
17 |             model_hash = util.object_hash(updated_kwargs)
18 |             print("New hash -> " + model_hash)
19 |             model_filename = os.path.join(cache_dir, "model_{}.p".format(model_hash))
20 |             sys.setrecursionlimit(100000)
21 |             pickle.dump((m, updated_kwargs), open(model_filename, 'wb'), protocol=pickle.HIGHEST_PROTOCOL)
22 | 
23 |             os.remove(full_filename)
24 | 
25 | parser = argparse.ArgumentParser(description='Update a model cache directory.')
26 | parser.add_argument('cache_dir', help="Directory of cached models to update")
27 | 
28 | if __name__ == '__main__':
29 |     args = vars(parser.parse_args())
30 |     main(**args)
31 | 
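Note that this script rewrites the cache in place: each model pickle is re-saved under the hash of its updated kwargs, and the original file is then deleted. A typical invocation (the cache path here is hypothetical):

```
python3 update_cache_compatibility.py ./model_cache
```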
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import theano
2 | import theano.tensor as T
3 | import numpy as np
4 | import pickle
5 | import hashlib
6 | import json
7 | import enum
8 | import inspect
9 | import os
10 | 
11 | import itertools
12 | import collections
13 | 
14 | EPSILON = np.array(1e-8, np.float32)
15 | 
16 | def identity(x):
17 |     return x
18 | 
19 | def init_params(shape, stddev=0.1, shift=0.0):
20 |     """Get an initial value for a parameter"""
21 |     return np.float32(np.random.normal(shift, stddev, shape))
22 | 
23 | def do_layer(activation, ipt, weights, biases):
24 |     """
25 |     Perform a layer operation, i.e. out = activ( xW + b )
26 |         activation: An activation function
27 |         ipt: Tensor of shape (n_batch, X)
28 |         weights: Tensor of shape (X, Y)
29 |         biases: Tensor of shape (Y)
30 | 
31 |     Returns: Tensor of shape (n_batch, Y)
32 |     """
33 |     xW = T.dot(ipt, weights)
34 |     b = T.shape_padleft(biases)
35 |     return activation( xW + b )
36 | 
37 | def broadcast_concat(tensors, axis):
38 |     """
39 |     Broadcast tensors together, then concatenate along axis
40 |     """
41 |     ndim = tensors[0].ndim
42 |     assert all(t.ndim == ndim for t in tensors), "ndims don't match for broadcast_concat: {}".format(tensors)
43 |     broadcast_shapes = []
44 |     for i in range(ndim):
45 |         if i == axis:
46 |             broadcast_shapes.append(1)
47 |         else:
48 |             dim_size = next((t.shape[i] for t in tensors if not t.broadcastable[i]), 1)
49 |             broadcast_shapes.append(dim_size)
50 |     broadcasted_tensors = []
51 |     for t in tensors:
52 |         tile_reps = [bshape if t.broadcastable[i] else 1 for i, bshape in enumerate(broadcast_shapes)]
53 |         if all(rep == 1 for rep in tile_reps):
54 |             # Don't need to broadcast this tensor
55 |             broadcasted_tensors.append(t)
56 |         else:
57 |             broadcasted_tensors.append(T.tile(t, tile_reps))
58 |     return T.concatenate(broadcasted_tensors, axis)
59 | 
60 | def pad_to(tensor, shape):
61 |     """
62 |     Pads tensor to shape with zeros
63 |     """
64 |     current = tensor
65 |     for i in range(len(shape)):
66 |         padding = T.zeros([(fs-ts if i==j else fs if j
128 | def independent_best(tensor):
129 |     """
130 |     tensor should be a tensor of probabilities. Return a new tensor with 1 where the value is >0.5,
131 |         else 0
132 |     """
133 |     return T.cast(T.ge(tensor, 0.5), 'floatX')
134 | 
135 | def categorical_best(tensor):
136 |     """
137 |     tensor should be a tensor of shape (..., categories)
138 |     Return a new tensor of the same shape but one-hot at position of best category
139 |     """
140 |     flat_tensor = tensor.reshape([-1, tensor.shape[-1]])
141 |     argmax_posns = T.argmax(flat_tensor, 1)
142 |     flat_snapped = T.zeros_like(flat_tensor)
143 |     flat_snapped = T.set_subtensor(flat_snapped[T.arange(flat_tensor.shape[0]), argmax_posns], 1.0)
144 |     snapped = flat_snapped.reshape(tensor.shape)
145 |     return snapped
146 | 
147 | def make_dropout_mask(shape, keep_frac, srng):
148 |     return T.shape_padleft(T.cast(srng.binomial(shape, p=keep_frac), 'float32') / keep_frac)
149 | 
150 | def apply_dropout(ipt, dropout):
151 |     return ipt * dropout
152 | 
153 | def object_hash(thing):
154 |     class EnumEncoder(json.JSONEncoder):
155 |         def default(self, obj):
156 |             if isinstance(obj, enum.Enum):
157 |                 return obj.name
158 |             return super().default(obj)
159 |     strform = json.dumps(thing, sort_keys=True, cls=EnumEncoder)
160 |     h = hashlib.sha1()
161 |     h.update(strform.encode('utf-8'))
162 |     return h.hexdigest()
163 | 
164 | def get_compatible_kwargs(function, kwargs):
165 |     kwargs = dict(kwargs)
166 |     sig = inspect.signature(function)
167 |     for param in sig.parameters.values():
168 |         if param.name not in kwargs:
169 |             if param.default is inspect.Parameter.empty:
170 |                 raise TypeError("kwargs missing required argument '{}'".format(param.name))
171 |             else:
172 |                 kwargs[param.name] = param.default
173 |     return kwargs
174 | 
175 | def find_recent_params(outputdir):
176 |     files_list = list(os.listdir(outputdir))
177 |     numbers = [int(x[6:-2]) for x in files_list if x[:6]=="params" and x[-2:]==".p"]
178 |     if len(numbers) == 0:
179 |         return None
180 |     most_recent = max(numbers)
181 |     return most_recent, os.path.join(outputdir, "params{}.p".format(most_recent))
182 | 
--------------------------------------------------------------------------------
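`get_compatible_kwargs` is what makes the cache update script above work: it fills in defaults for any parameters that a stored model predates, and fails loudly on genuinely missing required arguments. A quick illustration (the `build` function and its signature are hypothetical):

```python
from util import get_compatible_kwargs

def build(num_nodes, state_size=100, dropout_keep=1):  # hypothetical signature
    pass

get_compatible_kwargs(build, {"num_nodes": 5})
# -> {'num_nodes': 5, 'state_size': 100, 'dropout_keep': 1}

get_compatible_kwargs(build, {})
# -> TypeError: kwargs missing required argument 'num_nodes'
```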