├── .gitignore
├── .gitmodules
├── .travis.yml
├── LICENSE.txt
├── README.md
├── checkpoints
│   ├── _init.sh
│   ├── cbow.ckpt_best
│   ├── cbow.log
│   ├── cbow.sh
│   ├── rnn.ckpt_best
│   ├── rnn.log
│   ├── rnn.sh
│   ├── spinn.ckpt_best
│   ├── spinn.log
│   ├── spinn.sh
│   ├── spinn_gru.sh
│   ├── spinn_pi.ckpt_best
│   ├── spinn_pi.log
│   ├── spinn_pi.sh
│   ├── spinn_pi_nt.ckpt_best
│   ├── spinn_pi_nt.log
│   └── spinn_pi_nt.sh
├── cpp
│   ├── Makefile
│   ├── bin
│   │   ├── rnntest.cc
│   │   └── stacktest.cc
│   ├── kernels.cu
│   ├── kernels.cuh
│   ├── params
│   │   ├── compose_W_l.txt
│   │   ├── compose_W_r.txt
│   │   └── compose_b.txt
│   ├── rnn.cc
│   ├── rnn.h
│   ├── sequence-model.h
│   ├── test.cc
│   ├── thin-stack.cc
│   ├── thin-stack.h
│   ├── util.cc
│   └── util.h
├── evalb_rembed.prm
├── python
│   ├── requirements.txt
│   └── spinn
│       ├── __init__.py
│       ├── afs_safe_logger.py
│       ├── cbow.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── boolean
│       │   │   ├── __init__.py
│       │   │   ├── generate_bl_data.py
│       │   │   └── load_boolean_data.py
│       │   ├── snli
│       │   │   ├── __init__.py
│       │   │   └── load_snli_data.py
│       │   └── sst
│       │       ├── __init__.py
│       │       └── load_sst_data.py
│       ├── fat_stack.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── classifier.py
│       │   └── fat_classifier.py
│       ├── plain_rnn.py
│       ├── recurrences.py
│       ├── stack.py
│       ├── tests
│       │   ├── __init__.py
│       │   ├── test_cuda_util.py
│       │   ├── test_embedding_matrix.5d.txt
│       │   ├── test_plain_rnn.py
│       │   ├── test_stack.py
│       │   └── test_util.py
│       └── util
│           ├── __init__.py
│           ├── blocks.py
│           ├── cuda.py
│           ├── data.py
│           ├── theano_internal.py
│           └── variable_store.py
├── scripts
│   ├── 12_5_sst_2s.sh
│   ├── analyze_log.py
│   ├── make_snli_sweep.py
│   ├── make_sst_sweep.py
│   ├── make_theano_patch.sh
│   ├── pick_gpu.py
│   ├── snli_test-snli_1.0_dev.jsonl-parse.lbl
│   ├── train_spinn_classifier.sh
│   └── train_spinn_fat_classifier.sh
└── writing
    ├── MLSemantics.bib
    ├── acl2016.sty
    ├── acl_natbib.bst
    ├── gist
    │   └── gist.tex
    ├── hard_stack_paper
    │   ├── batching_fig.tex
    │   ├── bowman2016spinn.bib
    │   ├── paper.tex
    │   └── runtime.tsv
    ├── model0_fig.tex
    ├── model1_fig.tex
    ├── titlepage.pdf
    ├── titlepage.tex
    └── tree_attn_fig.tex

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
*-data/*
*~
.env
.DS_Store
*.bbl
*.blg
*.log
*.out
*.aux
*.synctex.gz
*.zip
*.ckpt
./snli-data
*.parse_log
*.ckpt_best
./writing/gist/gist.pdf
*.fls
*.fdb_latexmk
!checkpoints/*

writing/gist/gist.pdf

*.fls

*.fdb_latexmk

writing/gist/gist.fls

writing/gist/gist.fdb_latexmk

writing/gist/gist.fls

writing/hard_stack_paper/paper.pdf
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
[submodule "cpp/lib/googletest"]
    path = cpp/lib/googletest
    url = https://github.com/google/googletest
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
language: python
dist: trusty
python: "2.7"
sudo: true

env: CUDA=7.5-18 THEANO_FLAGS=device=gpu,floatX=float32

before_install:
  # Install CUDA toolkit
  - echo "Installing CUDA library"
  - travis_retry wget http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1404/x86_64/cuda-repo-ubuntu1404_${CUDA}_amd64.deb
  - travis_retry sudo dpkg -i cuda-repo-ubuntu1404_${CUDA}_amd64.deb
  - travis_retry sudo apt-get update -qq
  - export CUDA_APT=${CUDA%-*}
  - export CUDA_APT=${CUDA_APT/./-}
  # - travis_retry sudo apt-get install -y cuda-${CUDA_APT}
  - travis_retry sudo apt-get install -y nvidia-settings cuda-drivers cuda-core-${CUDA_APT} cuda-cudart-dev-${CUDA_APT}
  - travis_retry sudo apt-get clean
  - export CUDA_HOME=/usr/local/cuda-${CUDA%%-*}
  - export LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
  - export PATH=${CUDA_HOME}/bin:${PATH}

install:
  - nvcc --version
  - pip install -r python/requirements.txt

script: "nosetests -a \\!slow python/spinn/tests"
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
Copyright 2018, Stanford University

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
*NOTE:* This codebase is under active development. To exactly reproduce the experiments published in ACL 2016, use [this release][7]. For the most recent version, see the [NYU fork](https://github.com/nyu-mll/spinn).

# Stack-augmented Parser-Interpreter Neural Network

This repository contains the source code described in our paper [A Fast Unified Model for Sentence Parsing and Understanding][1]. For a more informal introduction to the ideas behind the model, see this [Stanford NLP blog post][8].

There are three separate implementations available:

- A **Python/Theano** implementation of SPINN using a naïve stack representation (named `fat-stack`)
- A **Python/Theano** implementation of SPINN using the `thin-stack` representation described in our paper
- A **C++/CUDA** implementation of the SPINN feedforward, used for performance testing

## Python code

The Python code lives, quite intuitively, in the `python` folder. We used this code to train and test the SPINN models before publication.

There is one enormous difference between the `fat-` and `thin-stack` implementations: `fat-stack` uses Theano's automatically generated symbolic backpropagation graphs, while `thin-stack` generates its own optimal backpropagation graph. This makes `thin-stack` oodles faster than its brother, but we have not yet implemented all SPINN variants to support this custom backpropagation.
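
As a rough illustration of the distinction, here is a minimal sketch in stock Theano (for exposition only; it is not code from this repository, and the variable names are invented):

    import theano
    import theano.tensor as T

    x = T.vector('x')
    W = T.matrix('W')
    cost = T.sum(T.nnet.relu(T.dot(W, x)) ** 2)

    # fat-stack style: ask Theano to derive the backward graph symbolically.
    auto_grad = theano.grad(cost, wrt=W)

    # thin-stack style: build the gradient expression by hand, so that its
    # structure (and hence its memory usage and speed) is under our control.
    h = T.dot(W, x)
    manual_grad = T.outer(2 * T.nnet.relu(h) * (h > 0), x)

    f = theano.function([W, x], [auto_grad, manual_grad])

Both expressions compute the same gradient; the payoff of the hand-built variant is that intermediate results can be saved and reused exactly as the thin-stack algorithm requires.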

### Installation

Requirements:

- Python 2.7
- CUDA >= 7.0
- CuDNN == v4 (v5 is not compatible with our Theano fork)

Install all required Python dependencies using the command below. (**WARNING:** This installs our custom Theano fork. We recommend installing in a virtual environment in order to avoid overwriting any stock Theano install you already have.)

    pip install -r python/requirements.txt

We use [a modified version of Theano][3] in order to support fast forward- and backward-prop in `thin-stack`. While it isn't absolutely necessary to use this hacked Theano, it greatly improves `thin-stack` performance.

Alternatively, you can use a custom Docker image that we've prepared, as discussed in this [CodaLab worksheet](https://worksheets.codalab.org/worksheets/0xa85b2da5365f423d952f800370ebb9b5/).

### Running the code

The easiest way to launch a train/test run is to use one of the scripts in the [`checkpoints` directory](https://github.com/stanfordnlp/spinn/tree/master/checkpoints).
The Bash scripts in this directory will download the necessary data and launch train/test runs of all models reported in our paper. You can run any of the following scripts:

    ./checkpoints/spinn.sh
    ./checkpoints/spinn_pi.sh
    ./checkpoints/spinn_pi_nt.sh
    ./checkpoints/rnn.sh

All of the above scripts will by default launch a training run beginning with the recorded parameters of our best models. You can change their behavior using the arguments below:

    $ ./checkpoints/spinn.sh -h
    spinn.sh [-h] [-e] [-t] [-s] -- run a train or test run of a SPINN model

    where:
        -h  show this help text
        -e  run in eval-only mode (evaluates on dev set by default)
        -t  evaluate on test set
        -s  skip the checkpoint loading; run with a randomly initialized model

To evaluate our best SPINN-PI-NT model on the test set, for example, run

    $ ./checkpoints/spinn_pi_nt.sh -e -t
    Running command:
    python -m spinn.models.fat_classifier --data_type snli --embedding_data_path ../glove/glove.840B.300d.txt --log_path ../logs --training_data_path ../snli_1.0/snli_1.0_train.jsonl --experiment_name spinn_pi_nt --expanded_eval_only --eval_data_path ../snli_1.0/snli_1.0_test.jsonl --ckpt_path spinn_pi_nt.ckpt_best --batch_size 32 --embedding_keep_rate 0.828528124124 --eval_seq_length 50 --init_range 0.005 --l2_lambda 3.45058959758e-06 --learning_rate 0.000297682444894 --model_dim 600 --model_type Model0 --noconnect_tracking_comp --num_sentence_pair_combination_layers 2 --semantic_classifier_keep_rate 0.9437038157 --seq_length 50 --tracking_lstm_hidden_dim 57 --use_tracking_lstm --word_embedding_dim 300
    ...
    [1] Checkpointed model was trained for 156500 steps.
    [1] Building forward pass.
    [1] Writing eval output for ../snli_1.0/snli_1.0_test.jsonl.
    [1] Written gold parses in spinn_pi_nt-snli_1.0_test.jsonl-parse.gld
    [1] Written predicted parses in spinn_pi_nt-snli_1.0_test.jsonl-parse.tst
    [1] Step: 156500 Eval acc: 0.808734 0.000000 ../snli_1.0/snli_1.0_test.jsonl

#### Custom model configurations

The main executable for the SNLI experiments in the paper is [fat_classifier.py](https://github.com/stanfordnlp/spinn/blob/master/python/spinn/models/fat_classifier.py), whose flags specify the hyperparameters of the model.
You may also need to set Theano flags through the THEANO_FLAGS environment variable. The relevant flags are `optimizer`, which selects the compilation mode (set it to `fast_compile` during development, and remove it to use the fully optimizing default for longer runs); `device`, which can be set to `cpu` or `gpu#`; and `cuda.root`, which specifies the location of CUDA when running on GPU. `floatX` should always be set to `float32`.

Here's a sample command that runs a fast, low-dimensional CPU training run, training and testing only on the dev set. It assumes that you have a copy of [SNLI](http://nlp.stanford.edu/projects/snli/) available locally.

    PYTHONPATH=spinn/python \
    THEANO_FLAGS=optimizer=fast_compile,device=cpu,floatX=float32 \
    python2.7 -m spinn.models.fat_classifier --data_type snli \
    --training_data_path snli_1.0/snli_1.0_dev.jsonl \
    --eval_data_path snli_1.0/snli_1.0_dev.jsonl \
    --embedding_data_path spinn/python/spinn/tests/test_embedding_matrix.5d.txt \
    --word_embedding_dim 5 --model_dim 10

For full runs, you'll also need a copy of the 840B word 300D [GloVe word vectors](http://nlp.stanford.edu/projects/glove/).

## C++ code

The C++ code lives in the `cpp` folder. This code implements a basic SPINN feedforward. (This implementation corresponds to the bare SPINN-PI-NT, "parsed input / no tracking" model, described in the paper.) It has been verified to produce the exact same output as a recursive neural network with the same weights and inputs. (We used a simplified version of Ozan Irsoy's [`deep-recursive` project][5] as a comparison.)

The main binary, `stacktest`, simply generates random input data and runs a feedforward. It outputs the total feedforward time elapsed and the numerical result of the feedforward.

### Dependencies

The only external dependency of the C++ code is CUDA >=7.0. The tests depend on the [`googletest` library][4], included in this repository as a Git submodule.

### Installation

First install CUDA >=7.0 and ensure that `nvcc` is on your `PATH`. Then:

    # From project root
    cd cpp

    # Pull down Git submodules (libraries)
    git submodule update --init

    # Compile C++ code
    make stacktest
    make rnntest

This should generate the binaries `cpp/bin/stacktest` and `cpp/bin/rnntest`.

### Running

The binary `cpp/bin/stacktest` runs on random input data. You can time the feedforward yourself by running the following commands:

    # From project root
    cd cpp

    BATCH_SIZE=512 ./bin/stacktest

You can of course set `BATCH_SIZE` to whatever integer you desire. The other model architecture parameters are fixed in the code, but you can easily change them as well [on this line][6] if needed.

#### Baseline RNN

The binary `cpp/bin/rnntest` runs a vanilla RNN (ReLU activations) with random input data.
You can run this performance test as follows:

    # From project root
    cd cpp

    BATCH_SIZE=512 ./bin/rnntest

## License

Copyright 2018, Stanford University

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


[1]: http://arxiv.org/abs/1603.06021
[2]: https://github.com/stanfordnlp/spinn/blob/master/requirements.txt
[3]: https://github.com/hans/theano-hacked/tree/8964f10e44bcd7f21ae74ea7cdc3682cc7d3258e
[4]: https://github.com/google/googletest
[5]: https://github.com/oir/deep-recursive
[6]: https://github.com/stanfordnlp/spinn/blob/5d4257f4cd15cf7213d2ff87f6f3d7f6716e2ea1/cpp/bin/stacktest.cc#L33
[7]: https://github.com/stanfordnlp/spinn/releases/tag/ACL2016
[8]: http://nlp.stanford.edu/blog/hybrid-tree-sequence-neural-networks-with-spinn/
--------------------------------------------------------------------------------
/checkpoints/_init.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Header script to prepare to perform a SPINN SNLI run.

function show_help {
  echo "$(basename "$0") [-h] [-e] [-t] [-s] -- run a train or test run of a SPINN model

where:
    -h  show this help text
    -e  run in eval-only mode (evaluates on dev set by default)
    -t  evaluate on test set
    -s  skip the checkpoint loading; run with a randomly initialized model"
  exit
}

# Parse arguments
eval_only=false
test_set=false
skip_ckpt=false
while [[ $# > 0 ]]; do
  case "$1" in
    -e)
      eval_only=true
      ;;
    -t)
      test_set=true
      ;;
    -s)
      skip_ckpt=true
      ;;
    -h|--help)
      show_help
      ;;
    *)
      ;;
  esac
  shift
done

# The directory where logs should be stored.
export LOG_DIR=../logs
mkdir -p $LOG_DIR

# The path to pretrained embeddings.
export EMBEDDING_PATH=../glove/glove.840B.300d.txt
if [ ! -e "$EMBEDDING_PATH" ]; then
  echo "Could not find GloVe embeddings at $EMBEDDING_PATH."
  read -p "Should we download them? (~2G download) (y/n) " yn
  if echo "$yn" | grep -iq "[^yY]"; then
    exit 1
  fi

  mkdir -p `dirname "$EMBEDDING_PATH"`
  wget http://nlp.stanford.edu/data/glove.840B.300d.zip -O glove.zip \
    || (echo "Failed to download GloVe embeddings." >&2 && exit 1)

  tar -xvC `dirname "$EMBEDDING_PATH"` -f glove.zip \
    || (echo "Failed to extract GloVe embeddings." >&2 && exit 1)
  rm glove.zip
fi

# Prepare SNLI data.
export SNLI_DIR=../snli_1.0
export SNLI_TRAIN_JSONL=$SNLI_DIR/snli_1.0_train.jsonl
export SNLI_DEV_JSONL=$SNLI_DIR/snli_1.0_dev.jsonl
export SNLI_TEST_JSONL=$SNLI_DIR/snli_1.0_test.jsonl
if [ ! -d "$SNLI_DIR" ]; then
  echo "Downloading SNLI data." >&2
  wget http://nlp.stanford.edu/projects/snli/snli_1.0.zip -O snli_1.0.zip \
    || (echo "Failed to download SNLI data." >&2 && exit 1)
  unzip -d .. snli_1.0.zip && rm snli_1.0.zip
fi

export PYTHONPATH=../python
export THEANO_FLAGS="allow_gc=False,cuda.root=/usr/bin/cuda,warn_float64=warn,device=gpu,floatX=float32,$THEANO_FLAGS"
echo "THEANO_FLAGS: $THEANO_FLAGS"

flags="--data_type snli --embedding_data_path $EMBEDDING_PATH --log_path $LOG_DIR --training_data_path $SNLI_TRAIN_JSONL --experiment_name $MODEL"
if [ "$eval_only" = true ]; then
  flags="$flags --expanded_eval_only"
fi
if [ "$test_set" = true ]; then
  flags="$flags --eval_data_path $SNLI_TEST_JSONL"
else
  flags="$flags --eval_data_path $SNLI_DEV_JSONL"
fi
if [ ! "$skip_ckpt" = "true" ]; then
  flags="$flags --ckpt_path ${MODEL}.ckpt_best"
fi
export BASE_FLAGS="$flags"
--------------------------------------------------------------------------------
/checkpoints/cbow.ckpt_best:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/checkpoints/cbow.ckpt_best
--------------------------------------------------------------------------------
/checkpoints/cbow.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

if [ -e README.md ]; then
  cd checkpoints
fi

export MODEL="cbow"
source _init.sh

# CBOW
export REMBED_FLAGS=" --batch_size 32 --eval_seq_length 25 --init_range 0.005 --l2_lambda 1.24280631663e-07 --learning_rate 0.00829688998827 --model_dim 300 --model_type CBOW --num_sentence_pair_combination_layers 2 --semantic_classifier_keep_rate 0.88010692672 --seq_length 25 --tracking_lstm_hidden_dim 33 --word_embedding_dim 300"

echo "Running command:"
echo "python -m spinn.models.fat_classifier $BASE_FLAGS $REMBED_FLAGS"
python -m spinn.models.fat_classifier $BASE_FLAGS $REMBED_FLAGS

--------------------------------------------------------------------------------
/checkpoints/rnn.ckpt_best:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/checkpoints/rnn.ckpt_best
--------------------------------------------------------------------------------
/checkpoints/rnn.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

if [ -e README.md ]; then
  cd checkpoints
fi

export MODEL="rnn"
source _init.sh

# RNN
export REMBED_FLAGS=" --batch_size 32 --embedding_keep_rate 0.852564448733 --eval_seq_length 25 --init_range 0.005 --l2_lambda 4.42556134893e-06 --learning_rate 0.00464868093302 --model_dim 600 --model_type RNN --num_sentence_pair_combination_layers 2 --semantic_classifier_keep_rate 0.883392584372 --seq_length 25
--tracking_lstm_hidden_dim 33 --word_embedding_dim 300" 12 | 13 | echo "Running command:" 14 | echo "python -m spinn.models.fat_classifier $BASE_FLAGS $REMBED_FLAGS" 15 | python -m spinn.models.fat_classifier $BASE_FLAGS $REMBED_FLAGS 16 | 17 | -------------------------------------------------------------------------------- /checkpoints/spinn.ckpt_best: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/checkpoints/spinn.ckpt_best -------------------------------------------------------------------------------- /checkpoints/spinn.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if [ -e README.md ]; then 4 | cd checkpoints 5 | fi 6 | 7 | export MODEL="spinn" 8 | source _init.sh 9 | 10 | # SPINN 11 | export REMBED_FLAGS=" --batch_size 32 --connect_tracking_comp --embedding_keep_rate 0.938514416034 --eval_seq_length 50 --init_range 0.005 --l2_lambda 2.76018187539e-05 --learning_rate 0.00103428201391 --model_dim 600 --model_type Model1 --num_sentence_pair_combination_layers 1 --semantic_classifier_keep_rate 0.949455648614 --seq_length 50 --tracking_lstm_hidden_dim 44 --transition_cost_scale 0.605159568546 --use_tracking_lstm --word_embedding_dim 300 --predict_use_cell" 12 | 13 | echo "Running command:" 14 | echo "python -m spinn.models.fat_classifier $BASE_FLAGS $REMBED_FLAGS" 15 | python -m spinn.models.fat_classifier $BASE_FLAGS $REMBED_FLAGS 16 | 17 | -------------------------------------------------------------------------------- /checkpoints/spinn_gru.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if [ -e README.md ]; then 4 | cd checkpoints 5 | fi 6 | 7 | export MODEL="spinn_gru" 8 | source _init.sh 9 | 10 | # SPINN 11 | export REMBED_FLAGS=" --use_gru --batch_size 32 --connect_tracking_comp --embedding_keep_rate 0.938514416034 --eval_seq_length 50 --init_range 0.005 --l2_lambda 2.76018187539e-05 --learning_rate 0.00103428201391 --model_dim 600 --model_type Model1 --num_sentence_pair_combination_layers 1 --semantic_classifier_keep_rate 0.949455648614 --seq_length 50 --tracking_lstm_hidden_dim 0 --transition_cost_scale 0.605159568546 --nouse_tracking_lstm --word_embedding_dim 300 --nopredict_use_cell --noconnect_tracking_comp" 12 | 13 | echo "Running command:" 14 | echo "python -m spinn.models.fat_classifier $BASE_FLAGS $REMBED_FLAGS" 15 | python -m spinn.models.fat_classifier $BASE_FLAGS $REMBED_FLAGS 16 | 17 | -------------------------------------------------------------------------------- /checkpoints/spinn_pi.ckpt_best: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/checkpoints/spinn_pi.ckpt_best -------------------------------------------------------------------------------- /checkpoints/spinn_pi.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if [ -e README.md ]; then 4 | cd checkpoints 5 | fi 6 | 7 | export MODEL="spinn_pi" 8 | source _init.sh 9 | 10 | # The invocation below will load the pretrained models and continue training 11 | # by default. Use the flag --expanded_eval_only mode to do eval-only runs, and delete the flag --ckpt_path ... to train from scratch. 
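# (These map onto the wrapper flags parsed in _init.sh: pass -e for an
# eval-only run, -t to evaluate on the test set, and -s to skip checkpoint
# loading and start from a randomly initialized model.)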
12 | 13 | # SPINN-PI 14 | export REMBED_FLAGS=" --batch_size 32 --connect_tracking_comp --embedding_keep_rate 0.917969380132 --eval_seq_length 50 --init_range 0.005 --l2_lambda 2.00098223698e-05 --learning_rate 0.00701855401337 --model_dim 600 --model_type Model0 --num_sentence_pair_combination_layers 2 --semantic_classifier_keep_rate 0.934741728838 --seq_length 50 --tracking_lstm_hidden_dim 61 --use_tracking_lstm --word_embedding_dim 300" 15 | 16 | echo "Running command:" 17 | echo "python -m spinn.models.fat_classifier $BASE_FLAGS $REMBED_FLAGS" 18 | python -m spinn.models.fat_classifier $BASE_FLAGS $REMBED_FLAGS 19 | 20 | 21 | -------------------------------------------------------------------------------- /checkpoints/spinn_pi_nt.ckpt_best: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/checkpoints/spinn_pi_nt.ckpt_best -------------------------------------------------------------------------------- /checkpoints/spinn_pi_nt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if [ -e README.md ]; then 4 | cd checkpoints 5 | fi 6 | 7 | export MODEL=spinn_pi_nt 8 | source _init.sh 9 | 10 | # SPINN-PI-NT 11 | export REMBED_FLAGS=" --batch_size 32 --embedding_keep_rate 0.828528124124 --eval_seq_length 50 --init_range 0.005 --l2_lambda 3.45058959758e-06 --learning_rate 0.000297682444894 --model_dim 600 --model_type Model0 --noconnect_tracking_comp --num_sentence_pair_combination_layers 2 --semantic_classifier_keep_rate 0.9437038157 --seq_length 50 --tracking_lstm_hidden_dim 57 --use_tracking_lstm --word_embedding_dim 300 " 12 | 13 | echo "Running command:" 14 | echo "python -m spinn.models.fat_classifier $BASE_FLAGS $REMBED_FLAGS" 15 | python -m spinn.models.fat_classifier $BASE_FLAGS $REMBED_FLAGS 16 | 17 | -------------------------------------------------------------------------------- /cpp/Makefile: -------------------------------------------------------------------------------- 1 | GTEST_DIR = lib/googletest/googletest 2 | GMOCK_DIR = lib/googletest/googlemock 3 | CPPFLAGS += -isystem $(GTEST_DIR)/include -isystem $(GMOCK_DIR)/include 4 | CXXFLAGS += -g -Wall -Wextra -pthread 5 | 6 | ###### 7 | # gtest directives 8 | GTEST_SRCS_ = $(GTEST_DIR)/src/*.cc $(GTEST_DIR)/src/*.h $(GTEST_HEADERS) 9 | gtest-all.o : $(GTEST_SRCS_) 10 | $(CXX) $(CPPFLAGS) -I$(GTEST_DIR) $(CXXFLAGS) -c \ 11 | $(GTEST_DIR)/src/gtest-all.cc 12 | 13 | gtest_main.o : $(GTEST_SRCS_) 14 | $(CXX) $(CPPFLAGS) -I$(GTEST_DIR) $(CXXFLAGS) -c \ 15 | $(GTEST_DIR)/src/gtest_main.cc 16 | 17 | gtest.a : gtest-all.o 18 | $(AR) $(ARFLAGS) $@ $^ 19 | 20 | gtest_main.a : gtest-all.o gtest_main.o 21 | $(AR) $(ARFLAGS) $@ $^ 22 | 23 | GMOCK_SRCS_ = $(GMOCK_DIR)/src/*.cc $(GMOCK_HEADERS) 24 | gmock-all.o : $(GMOCK_SRCS_) 25 | $(CXX) $(CPPFLAGS) -I$(GTEST_DIR) -I$(GMOCK_DIR) $(CXXFLAGS) \ 26 | -c $(GMOCK_DIR)/src/gmock-all.cc 27 | 28 | gmock_main.o : $(GMOCK_SRCS_) 29 | $(CXX) $(CPPFLAGS) -I$(GTEST_DIR) -I$(GMOCK_DIR) $(CXXFLAGS) \ 30 | -c $(GMOCK_DIR)/src/gmock_main.cc 31 | 32 | gmock.a : gmock-all.o gtest-all.o 33 | $(AR) $(ARFLAGS) $@ $^ 34 | 35 | gmock_main.a : gmock-all.o gtest-all.o gmock_main.o 36 | $(AR) $(ARFLAGS) $@ $^ 37 | ###### 38 | 39 | SOURCES = thin-stack.cc rnn.cc \ 40 | kernels.cu util.cc 41 | 42 | test: gtest_main.a gmock_main.a test.cc $(SOURCES) 43 | rm -f bin/unittest 44 | nvcc -g -O0 -lcublas -lcurand \ 45 | -isystem 
$(GTEST_DIR)/include -isystem $(GMOCK_DIR)/include \ 46 | -std=c++11 -o bin/unittest \ 47 | $^ 48 | 49 | stacktest: bin/stacktest.cc $(SOURCES) 50 | nvcc -g -O0 -lcublas -lcurand \ 51 | -std=c++11 -I. -o bin/stacktest \ 52 | $^ 53 | 54 | rnntest: bin/rnntest.cc $(SOURCES) 55 | nvcc -g -O0 -lcublas -lcurand \ 56 | -std=c++11 -I. -o bin/rnntest \ 57 | $^ 58 | 59 | clean: 60 | rm -f bin/unittest bin/stacktest bin/rnntest 61 | rm -f *.a *.o *.cu_o 62 | -------------------------------------------------------------------------------- /cpp/bin/rnntest.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include "cublas_v2.h" 6 | 7 | #include "rnn.h" 8 | #include "util.h" 9 | 10 | 11 | RNNParameters load_params(ModelSpec spec) { 12 | float *U = make_rand_matrix(spec.model_dim, spec.model_dim); 13 | float *W = make_rand_matrix(spec.model_dim, spec.word_embedding_dim); 14 | float *b = make_rand_matrix(spec.model_dim, 1); 15 | 16 | RNNParameters ret = {U, W, b}; 17 | return ret; 18 | } 19 | 20 | void destroy_params(RNNParameters params) { 21 | cudaFree(params.U); 22 | cudaFree(params.W); 23 | cudaFree(params.b); 24 | } 25 | 26 | int main() { 27 | ModelSpec spec = {300, 300, (size_t) atoi(getenv("BATCH_SIZE")), 10, 30, 300}; 28 | RNNParameters params = load_params(spec); 29 | 30 | cublasHandle_t handle; 31 | cublasStatus_t stat = cublasCreate(&handle); 32 | if (stat != CUBLAS_STATUS_SUCCESS) { 33 | cout << "CUBLAS initialization failed" << endl; 34 | return 1; 35 | } 36 | 37 | RNN rnn(spec, params, handle); 38 | 39 | // Set model inputs. 40 | cout << "X:" << endl; 41 | fill_rand_matrix(rnn.X, spec.model_dim, spec.batch_size * spec.seq_length); 42 | 43 | auto time_elapsed = chrono::microseconds::zero(); 44 | int n_batches = 50; 45 | for (int t = 0; t < n_batches; t++) { 46 | auto start = chrono::high_resolution_clock::now(); 47 | rnn.forward(); 48 | auto end = chrono::high_resolution_clock::now(); 49 | time_elapsed += chrono::duration_cast(end - start); 50 | } 51 | 52 | // Print the final representation. 53 | cout << "Output:" << endl; 54 | print_device_matrix(rnn.output, spec.model_dim, spec.batch_size); 55 | 56 | cout << "Total time elapsed: " << time_elapsed.count() << endl; 57 | 58 | destroy_params(params); 59 | } 60 | -------------------------------------------------------------------------------- /cpp/bin/stacktest.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * This binary runs a basic speed test of the C++ thin-stack implementation 3 | * of SPINN in this package. 4 | * 5 | * The SPINN model and the thin-stack algorithm are described in our paper: 6 | * 7 | * A Fast Unified Model for Sentence Parsing and Understanding. 8 | * Samuel R. Bowman, Jon Gauthier, Abhinav Rastogi, Raghav Gupta, 9 | * Christopher D. Manning, and Christopher Potts. arXiv March 2016. 10 | * http://arxiv.org/abs/1603.06021 11 | * 12 | * This script loads fixed parameters from the local `params` directory, 13 | * generates random input data, and performs a SPINN feedforward as 14 | * described in the paper. It outputs timing information and the final 15 | * values at the top of the stack after the feedforward. 16 | * 17 | * For runtime instructions see the README in the root of this project 18 | * directory. 
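 * (The batch size is read from the BATCH_SIZE environment variable; see
 * main() below.)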
19 | */ 20 | 21 | #include 22 | #include 23 | 24 | #include 25 | #include "cublas_v2.h" 26 | 27 | #include "thin-stack.h" 28 | #include "util.h" 29 | 30 | 31 | ThinStackParameters load_params(ModelSpec spec) { 32 | float *compose_W_l = load_weights_cuda("params/compose_W_l.txt", 33 | spec.model_dim * spec.model_dim); 34 | float *compose_W_r = load_weights_cuda("params/compose_W_r.txt", 35 | spec.model_dim * spec.model_dim); 36 | float *compose_b = load_weights_cuda("params/compose_b.txt", spec.model_dim); 37 | 38 | ThinStackParameters ret = { 39 | NULL, NULL, NULL, // tracking 40 | compose_W_l, compose_W_r, NULL, compose_b, // composition 41 | }; 42 | 43 | return ret; 44 | } 45 | 46 | void destroy_params(ThinStackParameters params) { 47 | cudaFree(params.compose_W_l); 48 | cudaFree(params.compose_W_r); 49 | cudaFree(params.compose_b); 50 | } 51 | 52 | int main() { 53 | ModelSpec spec = { 54 | 300, // Dimension of stack values / constituent node values 55 | 300, // Word embedding dimension / tree leaf embedding dimension 56 | (size_t) atoi(getenv("BATCH_SIZE")), // Batch size 57 | 10, // Vocabulary size 58 | 59, // Transition sequence length (== 2 * (sentence length) - 1) 59 | 300 // Ignore -- unused. 60 | }; 61 | 62 | ThinStackParameters params = load_params(spec); 63 | 64 | cublasHandle_t handle; 65 | cublasStatus_t stat = cublasCreate(&handle); 66 | if (stat != CUBLAS_STATUS_SUCCESS) { 67 | cout << "CUBLAS initialization failed" << endl; 68 | return 1; 69 | } 70 | 71 | ThinStack ts(spec, params, handle); 72 | 73 | // Set model inputs. 74 | cout << "X:" << endl; 75 | int num_tokens = (spec.seq_length + 1) / 2; 76 | fill_rand_matrix(ts.X, spec.model_dim, spec.batch_size * num_tokens); 77 | 78 | cout << "transitions:" << endl; 79 | float *h_transitions = (float *) malloc(spec.seq_length * spec.batch_size * sizeof(float)); 80 | // Build a batch of random transition sequences which somewhat resemble 81 | // realistic transition sequences. (These are perhaps a bit more difficult 82 | // than realistic transition sequence batches in that there will be less 83 | // overlap w.r.t. merge locations within the sequence.) 84 | for (int i = 0; i < spec.seq_length * spec.batch_size; i++) { 85 | float val; 86 | if (i < spec.batch_size * 2) { 87 | val = 0.0f; 88 | } else if (i >= spec.batch_size * 2 && i < spec.batch_size * 3) { 89 | val = 1.0f; 90 | } else { 91 | val = rand() % 2 == 0 ? 1.0f : 0.0f; 92 | } 93 | h_transitions[i] = val; 94 | } 95 | cudaMemcpy(ts.transitions, h_transitions, 96 | spec.seq_length * spec.batch_size * sizeof(float), 97 | cudaMemcpyHostToDevice); 98 | free(h_transitions); 99 | #if DEBUG 100 | print_device_matrix(ts.transitions, 1, spec.batch_size * spec.seq_length); 101 | #endif 102 | 103 | auto time_elapsed = chrono::microseconds::zero(); 104 | int n_batches = 50; 105 | for (int t = 0; t < n_batches; t++) { 106 | auto start = chrono::high_resolution_clock::now(); 107 | ts.forward(); 108 | auto end = chrono::high_resolution_clock::now(); 109 | time_elapsed += chrono::duration_cast(end - start); 110 | } 111 | 112 | // Print the top of the stack. 
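  // (The stack is laid out as seq_length consecutive slabs of size
  // model_dim * batch_size, one slab per transition step, so the final
  // step's slab holds the values at the top of the stack.)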
113 | cout << "Stack top:" << endl; 114 | print_device_matrix( 115 | &ts.stack[(spec.seq_length - 1) * spec.model_dim * spec.batch_size], 116 | spec.model_dim, spec.batch_size); 117 | 118 | cout << "Total time elapsed: " << time_elapsed.count() << endl; 119 | 120 | destroy_params(params); 121 | } 122 | -------------------------------------------------------------------------------- /cpp/kernels.cu: -------------------------------------------------------------------------------- 1 | #include "kernels.cuh" 2 | #include 3 | 4 | namespace kernels { 5 | 6 | void muli_vs(float *v, float s, int N) { 7 | int num_threads = min(N, MAX_THREADS_PER_BLOCK); 8 | int num_blocks = (N + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; 9 | k_muli_vs<<>>(v, s, N); 10 | } 11 | 12 | __global__ void k_muli_vs(float *v, float s, int N) { 13 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 14 | if (idx >= N) return; 15 | 16 | v[idx] *= s; 17 | } 18 | 19 | 20 | void addi_vv(cublasHandle_t handle, float *v1, const float *v2, 21 | float v1_coeff, float v2_coeff, int N) { 22 | if (v1_coeff != 1.0) 23 | muli_vs(v1, v1_coeff, N); 24 | 25 | cublasSaxpy(handle, N, &v2_coeff, v2, 1, v1, 1); 26 | } 27 | 28 | 29 | void addi_mv(float *m, const float *v, float coeff, int M, int N) { 30 | // NB: a bias here: that we're working with small M, large N 31 | int num_threads = min(M, MAX_THREADS_PER_BLOCK); 32 | int num_blocks = min(N, MAX_BLOCKS); 33 | k_addi_mv<<>>(m, v, coeff, M, N); 34 | } 35 | 36 | __global__ void k_addi_mv(float *m, const float *v, float coeff, int M, int N) { 37 | for (int i0 = blockIdx.x; i0 < N; i0 += gridDim.x) { 38 | for (int i1 = threadIdx.x; i1 < M; i1 += blockDim.x) { 39 | m[i0 * M + i1] += coeff * v[i1]; 40 | } 41 | } 42 | } 43 | 44 | 45 | void relu(float *m, int M, int N) { 46 | // NB: a bias here: that we're working with small M, large N 47 | int num_threads = min(M, MAX_THREADS_PER_BLOCK); 48 | int num_blocks = min(N, MAX_BLOCKS); 49 | k_relu<<>>(m, M, N); 50 | } 51 | 52 | __global__ void k_relu(float *m, int M, int N) { 53 | for (int i0 = blockIdx.x; i0 < N; i0 += gridDim.x) { 54 | for (int i1 = threadIdx.x; i1 < M; i1 += blockDim.x) { 55 | m[i0 * M + i1] = max(0.0f, m[i0 * M + i1]); 56 | } 57 | } 58 | } 59 | 60 | 61 | 62 | void subtensor1(float *dst, const float *src, const float *idxs, int src_N, 63 | int N, int D, float idx_scal_shift, float idx_scal_mul, 64 | float idx_vec_shift_coeff, float *idx_vec_shift) { 65 | int num_threads = min(D, MAX_THREADS_PER_BLOCK); 66 | int num_blocks = min(N, MAX_BLOCKS); 67 | k_subtensor1<<>>(dst, src, idxs, src_N, N, D, 68 | idx_scal_shift, idx_scal_mul, idx_vec_shift_coeff, idx_vec_shift); 69 | } 70 | 71 | __global__ void k_subtensor1(float *dst, const float *src, const float *idxs, 72 | int src_N, int N, int D, float idx_scal_shift, float idx_scal_mul, 73 | float idx_vec_shift_coeff, float *idx_vec_shift) { 74 | for (int i0 = blockIdx.x; i0 < N; i0 += gridDim.x) { 75 | float fsrc_idx = idxs[i0] * idx_scal_mul + idx_scal_shift; 76 | float shift = idx_vec_shift == NULL 77 | ? 0.0f : idx_vec_shift_coeff * idx_vec_shift[i0]; 78 | 79 | int src_idx = (int) (fsrc_idx + shift); 80 | if (src_idx < 0) { 81 | // Negative index. Read from other end of the source matrix. 
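      // (This mirrors Python-style negative indexing: an index of -1
      // selects the last row of the source matrix.)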
82 | src_idx += src_N; 83 | } 84 | 85 | //printf("%d %5f %5f %5f %5f %5f %d\n", i0, idxs[i0], shift, idx_scal_mul, idx_scal_shift, fsrc_idx, src_idx); 86 | 87 | int src_offset = src_idx * D; 88 | int dst_offset = i0 * D; 89 | for (int i1 = threadIdx.x; i1 < D; i1 += blockDim.x) 90 | dst[dst_offset + i1] = src[src_offset + i1]; 91 | } 92 | } 93 | 94 | 95 | void set_subtensor1i_s(float *dst, float src, const float *idxs, int N, 96 | float idx_scal_shift, float idx_vec_shift_coeff, float *idx_vec_shift) { 97 | int num_threads = min(N, MAX_THREADS_PER_BLOCK); 98 | int num_blocks = (N + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK; 99 | k_set_subtensor1i_s<<>>( 100 | dst, src, idxs, N, idx_scal_shift, idx_vec_shift_coeff, idx_vec_shift); 101 | } 102 | 103 | __global__ void k_set_subtensor1i_s(float *dst, float src, const float *idxs, int N, 104 | float idx_scal_shift, float idx_vec_shift_coeff, float *idx_vec_shift) { 105 | int k_idx = blockIdx.x * blockDim.x + threadIdx.x; 106 | if (k_idx >= N) return; 107 | 108 | float fidx = idxs[k_idx] + idx_scal_shift; 109 | fidx += idx_vec_shift_coeff * idx_vec_shift[k_idx]; 110 | int idx = (int) fidx; 111 | 112 | dst[idx] = src; 113 | } 114 | 115 | 116 | void switch_m(float *dst, const float *mask, const float *ift, const float *iff, 117 | int N, int D) { 118 | int num_threads = min(D, MAX_THREADS_PER_BLOCK); 119 | int num_blocks = min(N, MAX_BLOCKS); 120 | k_switch_m<<>>(dst, mask, ift, iff, N, D); 121 | } 122 | 123 | __global__ void k_switch_m(float *dst, const float *mask, const float *ift, 124 | const float *iff, int N, int D) { 125 | for (int i0 = blockIdx.x; i0 < N; i0 += gridDim.x) { 126 | const float *src = (int) mask[i0] ? ift : iff; 127 | int offset = i0 * D; 128 | for (int i1 = threadIdx.x; i1 < D; i1 += blockDim.x) 129 | dst[offset + i1] = src[offset + i1]; 130 | } 131 | } 132 | 133 | } 134 | -------------------------------------------------------------------------------- /cpp/kernels.cuh: -------------------------------------------------------------------------------- 1 | #ifndef _kernels_ 2 | #define _kernels_ 3 | 4 | #include 5 | 6 | #include 7 | #include "cublas_v2.h" 8 | 9 | 10 | // Grid dimension constraints for jagupard machines. 11 | #define MAX_BLOCKS 65535 12 | #define MAX_THREADS_PER_BLOCK 1024 13 | 14 | 15 | namespace kernels { 16 | 17 | // v *= s 18 | void muli_vs(float *v, float s, int N); 19 | __global__ void k_muli_vs(float *v, float s, int N); 20 | 21 | /** 22 | * Add two vectors inplace (writing to the first). 23 | * 24 | * v1 = v1 * v1_coeff + v2 * v2_coeff 25 | */ 26 | void addi_vv(cublasHandle_t handle, float *v1, const float *v2, 27 | float v1_coeff, float v2_coeff, int N); 28 | 29 | /** 30 | * Broadcast-add an `N`-dim column vector onto an `M * N` matrix. 31 | * 32 | * m += coeff * v 33 | */ 34 | void addi_mv(float *m, const float *v, float coeff, int M, int N); 35 | __global__ void k_addi_mv(float *m, const float *v, float coeff, int M, 36 | int N); 37 | 38 | void relu(float *m, int M, int N); 39 | __global__ void k_relu(float *m, int M, int N); 40 | 41 | /** 42 | * Retrieve a subset of `N` rows from the contiguous `src_N * D` matrix `src` 43 | * and write them to `dst` (`M >= N`). `dst` should be large enough to hold 44 | * the `N * D` float result. `idxs` should be a length-`N` int array. 
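 * (The constraint `M >= N` above appears to refer to the row count of the
 * source matrix, i.e. `src_N`.)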
45 | * 46 | * In broadcasting Python code, this function is equivalent to the following: 47 | * 48 | * idxs_ = idxs + idx_scal_shift 49 | * idxs_ += idx_vec_shift_coeff * idx_vec_shift 50 | * dst = src[idxs_] 51 | */ 52 | void subtensor1(float *dst, const float *src, const float *idxs, int src_N, 53 | int N, int D, float idx_scal_shift, float idx_scal_mul, 54 | float idx_vec_shift_coeff, float *idx_vec_shift); 55 | __global__ void k_subtensor1(float *dst, const float *src, const float *idxs, 56 | int src_N, int N, int D, float idx_scal_shift, float idx_scal_mul, 57 | float idx_vec_shift_coeff, float *idx_vec_shift); 58 | 59 | /** 60 | * Write an int scalar into a subtensor range. 61 | * 62 | * dst[idxs + idx_scal_shift + idx_vec_shift_coeff * idx_vec_shift] = src 63 | */ 64 | void set_subtensor1i_s(float *dst, float src, const float *idxs, int N, 65 | float idx_scal_shift, float idx_vec_shift_coeff, float *idx_vec_shift); 66 | __global__ void k_set_subtensor1i_s(float *dst, float src, const float *idxs, 67 | int N, float idx_scal_shift, float idx_vec_shift_coeff, 68 | float *idx_vec_shift); 69 | 70 | /** 71 | * Switch over the rows of two matrices using a mask. 72 | * 73 | * dst = T.switch(mask, ift, iff) 74 | * 75 | * where `ift`, `iff` are `N * D` matrices, and `mask` is an `N`-dimensional 76 | * vector. 77 | */ 78 | void switch_m(float *dst, const float *mask, const float *ift, 79 | const float *iff, int N, int D); 80 | __global__ void k_switch_m(float *dst, const float *mask, const float *ift, 81 | const float *iff, int N, int D); 82 | 83 | } 84 | 85 | #endif 86 | -------------------------------------------------------------------------------- /cpp/params/compose_b.txt: -------------------------------------------------------------------------------- 1 | -0.00444719 0.118844 -0.0975603 0.0544677 0.0529373 0.0401551 -0.0489362 -0.00217848 -0.00650212 -0.0289487 -0.0479762 -0.0142745 -0.0133776 0.118265 -0.0106076 0.137174 -0.0188899 -0.0860789 -0.0566507 -0.0893182 -0.138098 -0.0215581 -0.124176 0.117908 -0.0638663 -0.136597 -0.107616 -0.0368244 0.0561491 0.0623542 -0.103 -0.0897195 0.0397765 -0.0591391 0.10617 -0.0487076 0.122437 -0.0841879 0.0905353 -0.0254862 0.0282847 -0.0988623 0.101661 -0.126514 -0.122019 -0.0503683 -0.130761 0.000512687 0.00497416 -0.0459905 2 | -------------------------------------------------------------------------------- /cpp/rnn.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * Implements a basic RNN feedforward (ReLU activation) for speed comparison. 3 | * 4 | * You can run this code with the script in `bin/rnntest.cc`. See the README in 5 | * the root directory of this project for full usage instructions. 6 | */ 7 | 8 | 9 | #include "rnn.h" 10 | using namespace std; 11 | namespace k = kernels; 12 | 13 | 14 | RNN::RNN(ModelSpec spec, RNNParameters params, cublasHandle_t handle) 15 | : SequenceModel(spec), params(params), handle(handle) { 16 | 17 | // Pre-allocate inputs. 18 | cudaMalloc(&X_indices, spec.batch_size * spec.seq_length * sizeof(float)); 19 | cudaMalloc(&X, spec.batch_size * spec.seq_length * spec.model_dim * sizeof(float)); 20 | 21 | // Pre-allocate temporary containers. 22 | cudaMalloc(&odd_output, spec.batch_size * spec.model_dim * sizeof(float)); 23 | cudaMalloc(&even_output, spec.batch_size * spec.model_dim * sizeof(float)); 24 | 25 | output = spec.seq_length % 2 == 0 ? even_output : odd_output; 26 | 27 | } 28 | 29 | 30 | RNN::~RNN() { 31 | 32 | cout << "!!!!!!!!!!" 
<< endl; 33 | cout << "RNN dying!" << endl; 34 | cout << "!!!!!!!!!!" << endl; 35 | 36 | cudaFree(X_indices); 37 | cudaFree(X); 38 | 39 | } 40 | 41 | 42 | void RNN::forward() { 43 | 44 | // First timestep will read from odd output slot. Make sure it sees an empty 45 | // (zero) state. 46 | cudaMemset(odd_output, 0, spec.model_dim * spec.batch_size * sizeof(float)); 47 | 48 | for (int t = 0; t < spec.seq_length; t++) { 49 | step(t); 50 | #if DEBUG 51 | cout << endl << "======================" << endl << endl; 52 | #endif 53 | } 54 | 55 | // TODO: Don't need to sync here. Could have the client establish a lock on 56 | // results and simultaneously begin the next batch + copy out results 57 | cudaDeviceSynchronize(); 58 | 59 | } 60 | 61 | 62 | void RNN::step(int t) { 63 | 64 | const float *X_t = &X[t * spec.word_embedding_dim * spec.batch_size]; 65 | #if DEBUG 66 | cout << "X_t " << t << endl; 67 | print_device_matrix(X_t, spec.model_dim, spec.batch_size); 68 | #endif 69 | 70 | const float *state; 71 | float *output; 72 | if (t % 2 == 0) { 73 | // t is even -- read from output at previous odd timestep, and write into 74 | // even slot 75 | state = odd_output; 76 | output = even_output; 77 | } else { 78 | // t is odd -- read from output at previous even timestep, and write into 79 | // odd slot 80 | state = even_output; 81 | output = odd_output; 82 | } 83 | 84 | recurrence(state, X_t, output); 85 | 86 | } 87 | 88 | 89 | void RNN::recurrence(const float *state, const float *input, float *output) { 90 | 91 | // out = U(state) 92 | float alpha = 1.0f; 93 | float beta = 0.0f; 94 | cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, spec.model_dim, spec.batch_size, 95 | spec.model_dim, &alpha, params.U, spec.model_dim, state, spec.model_dim, 96 | &beta, output, spec.model_dim); 97 | // out += W(input) 98 | float beta2 = 1.0f; 99 | cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, spec.model_dim, spec.batch_size, 100 | spec.model_dim, &alpha, params.W, spec.model_dim, input, spec.model_dim, 101 | &beta2, output, spec.model_dim); 102 | 103 | // out += b 104 | k::addi_mv(output, params.b, 1.0, spec.model_dim, spec.batch_size); 105 | 106 | k::relu(output, spec.model_dim, spec.batch_size); 107 | 108 | #if DEBUG 109 | cout << "state" << endl; 110 | print_device_matrix(output, spec.model_dim, spec.batch_size); 111 | #endif 112 | 113 | } 114 | -------------------------------------------------------------------------------- /cpp/rnn.h: -------------------------------------------------------------------------------- 1 | #ifndef _thin_stack_ 2 | #define _thin_stack_ 3 | 4 | #include 5 | #include "cublas_v2.h" 6 | 7 | #include "sequence-model.h" 8 | #include "util.h" 9 | 10 | #include "kernels.cuh" 11 | 12 | 13 | typedef struct RNNParameters { 14 | float *U; // hidden-to-hidden 15 | float *W; // input-to-hidden 16 | float *b; 17 | } RNNParameters; 18 | 19 | class RNN : public SequenceModel { 20 | public: 21 | RNN(ModelSpec spec, RNNParameters params, cublasHandle_t handle); 22 | ~RNN(); 23 | 24 | RNNParameters params; 25 | cublasHandle_t handle; 26 | 27 | void forward(); 28 | 29 | float *output; 30 | 31 | private: 32 | 33 | void step(int t); 34 | 35 | void recurrence(const float *state, const float *input, float *output); 36 | 37 | // Containers for temporary (per-step) data 38 | // RNN feedforward need only maintain two state cells. Just read from one 39 | // and write to the other! 
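    // Even steps read from odd_output and write to even_output; odd steps
    // do the reverse (see RNN::step).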
40 | float *odd_output, *even_output; 41 | 42 | }; 43 | 44 | #endif 45 | -------------------------------------------------------------------------------- /cpp/sequence-model.h: -------------------------------------------------------------------------------- 1 | #ifndef _sequence_model_ 2 | #define _sequence_model_ 3 | 4 | #include 5 | 6 | #include "util.h" 7 | 8 | #include "kernels.cuh" 9 | 10 | class SequenceModel { 11 | 12 | public: 13 | 14 | ModelSpec spec; 15 | SequenceModel(ModelSpec spec) : spec(spec) {}; 16 | 17 | // Embedding index inputs, of dimension `batch_size * seq_length` -- i.e., 18 | // we have `seq_length`-many concatenated vectors of embedding integers 19 | float *X_indices; 20 | // Embedding inputs, of dimension `model_dim * (batch_size * seq_length)` -- 21 | // i.e., along 2nd axis we have `seq_length`-many `model_dim * batch_size` 22 | // matrices. 23 | float *X; 24 | 25 | void lookup_embeddings(float *embedding_source) { 26 | kernels::subtensor1(X, embedding_source, X_indices, spec.vocab_size, 27 | spec.seq_length * spec.batch_size, spec.model_dim, 0.0f, 1.0f, 0.0f, 28 | NULL); 29 | } 30 | 31 | virtual void forward() = 0; 32 | 33 | }; 34 | 35 | 36 | #endif 37 | -------------------------------------------------------------------------------- /cpp/test.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include "cublas_v2.h" 3 | 4 | #include "gmock/gmock.h" 5 | #include "gtest/gtest.h" 6 | 7 | #include "kernels.cuh" 8 | #include "thin-stack.h" 9 | #include "util.h" 10 | 11 | using namespace testing; 12 | 13 | static ThinStack make_stack(ModelSpec spec) { 14 | // Make up random parameters. 15 | float *compose_W_l, *compose_W_r, *compose_b; 16 | cudaMalloc(&compose_W_l, spec.model_dim * spec.model_dim * sizeof(float)); 17 | cudaMalloc(&compose_W_r, spec.model_dim * spec.model_dim * sizeof(float)); 18 | cudaMalloc(&compose_b, spec.model_dim * sizeof(float)); 19 | fill_rand_matrix(compose_W_l, spec.model_dim, spec.model_dim); 20 | fill_rand_matrix(compose_W_r, spec.model_dim, spec.model_dim); 21 | fill_rand_matrix(compose_b, spec.model_dim, 1); 22 | 23 | cout << "compose_W_l" << endl; 24 | print_device_matrix(compose_W_l, spec.model_dim, spec.model_dim); 25 | 26 | cout << "compose_W_r" << endl; 27 | print_device_matrix(compose_W_r, spec.model_dim, spec.model_dim); 28 | 29 | cout << "compose_b" << endl; 30 | print_device_matrix(compose_b, 1, spec.model_dim); 31 | 32 | ThinStackParameters params = { 33 | NULL, NULL, NULL, // tracking 34 | compose_W_l, compose_W_r, NULL, compose_b // composition 35 | }; 36 | cublasHandle_t handle = getCublasHandle(); 37 | ThinStack ts(spec, params, handle); 38 | return ts; 39 | } 40 | 41 | static void free_stack(ThinStack s) { 42 | cudaFree(s.params.compose_W_l); 43 | cudaFree(s.params.compose_W_r); 44 | cublasDestroy(s.handle); 45 | } 46 | 47 | static inline void assert_matrices_equal(const float *m1, const float *m2, 48 | int M, int N) { 49 | cudaDeviceSynchronize(); 50 | float *h_m1 = (float *) malloc(2 * M * N * sizeof(float)); 51 | float *h_m2 = &h_m1[M * N]; 52 | 53 | cudaMemcpy(h_m1, m1, M * N * sizeof(float), cudaMemcpyDeviceToHost); 54 | cudaMemcpy(h_m2, m2, M * N * sizeof(float), cudaMemcpyDeviceToHost); 55 | cudaDeviceSynchronize(); 56 | 57 | for (int i = 0; i < M; i++) { 58 | for (int j = 0; j < N; j++) { 59 | ASSERT_THAT(h_m1[j * M + i], FloatEq(h_m2[j * M + i])); 60 | } 61 | } 62 | 63 | free(h_m1); 64 | } 65 | 66 | static float *compose(float *dst, ThinStack& ts, const float *l, 
67 | const float *r) { 68 | // W_l l 69 | float alpha = 1.0f, beta = 0.0f; 70 | cublasSgemm(ts.handle, CUBLAS_OP_N, CUBLAS_OP_N, ts.spec.model_dim, 71 | ts.spec.batch_size, ts.spec.model_dim, &alpha, ts.params.compose_W_l, 72 | ts.spec.model_dim, l, ts.spec.model_dim, &beta, dst, ts.spec.model_dim); 73 | 74 | // += W_r r 75 | float beta2 = 1.0f; 76 | cublasSgemm(ts.handle, CUBLAS_OP_N, CUBLAS_OP_N, ts.spec.model_dim, 77 | ts.spec.batch_size, ts.spec.model_dim, &alpha, ts.params.compose_W_r, 78 | ts.spec.model_dim, r, ts.spec.model_dim, &beta2, dst, ts.spec.model_dim); 79 | 80 | // += b 81 | kernels::addi_mv(dst, ts.params.compose_b, 1.0f, ts.spec.model_dim, 82 | ts.spec.batch_size); 83 | 84 | kernels::relu(dst, ts.spec.model_dim, ts.spec.batch_size); 85 | 86 | return dst; 87 | } 88 | 89 | 90 | class ThinStackTest : public ::testing::Test { 91 | 92 | public: 93 | 94 | ModelSpec spec; 95 | ThinStack ts; 96 | 97 | ThinStackTest() : 98 | spec({300, 300, 2, 10, 5, 300}), 99 | ts(make_stack(spec)) { 100 | 101 | fill_rand_matrix(ts.X, spec.model_dim, spec.seq_length * spec.batch_size); 102 | 103 | } 104 | 105 | virtual void TearDown() { 106 | free_stack(ts); 107 | } 108 | 109 | }; 110 | 111 | 112 | // Test simple shift-shift-merge feedforward with live random weights. 113 | TEST_F(ThinStackTest, ShiftShiftMerge) { 114 | 115 | float h_transitions[] = { 116 | 0.0f, 0.0f, 117 | 0.0f, 0.0f, 118 | 1.0f, 1.0f, 119 | 0.0f, 0.0f, // DUMMY 120 | 0.0f, 0.0f, // DUMMY 121 | }; 122 | cublasSetVector(spec.seq_length * spec.batch_size, sizeof(float), 123 | h_transitions, 1, ts.transitions, 1); 124 | 125 | // Do the feedforward! 126 | ts.forward(); 127 | 128 | // Now simulate the feedforward -- this should just be the composition of the 129 | // first two buffer elements. 130 | float *expected; 131 | cudaMalloc(&expected, spec.model_dim * spec.batch_size * sizeof(float)); 132 | 133 | float *left_child = &ts.X[0]; 134 | float *right_child = &ts.X[spec.model_dim * spec.batch_size]; 135 | compose(expected, ts, left_child, right_child); 136 | 137 | float *output = &ts.stack[2 * spec.model_dim * spec.batch_size]; 138 | assert_matrices_equal(output, expected, spec.model_dim, spec.batch_size); 139 | 140 | } 141 | 142 | 143 | TEST_F(ThinStackTest, ShiftShiftMergeShiftMerge) { 144 | 145 | float h_transitions[] = { 146 | 0.0f, 0.0f, 147 | 0.0f, 0.0f, 148 | 1.0f, 1.0f, 149 | 0.0f, 0.0f, 150 | 1.0f, 1.0f, 151 | }; 152 | cublasSetVector(spec.seq_length * spec.batch_size, sizeof(float), 153 | h_transitions, 1, ts.transitions, 1); 154 | 155 | // Do the feedforward! 156 | ts.forward(); 157 | 158 | // Now simulate the feedforward. 
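  // The transition sequence shift-shift-merge-shift-merge builds the tree
  // ((x0 x1) x2): c1 composes the first two buffer entries, and c2 composes
  // c1 with the third.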
159 | float *c1, *c2; 160 | cudaMalloc(&c1, spec.model_dim * spec.batch_size * sizeof(float)); 161 | cudaMalloc(&c2, spec.model_dim * spec.batch_size * sizeof(float)); 162 | 163 | // c1 164 | float *left_child = &ts.X[0]; 165 | float *right_child = &ts.X[spec.model_dim * spec.batch_size]; 166 | compose(c1, ts, left_child, right_child); 167 | 168 | // c2 169 | left_child = c1; 170 | right_child = &ts.X[2 * spec.model_dim * spec.batch_size]; 171 | compose(c2, ts, left_child, right_child); 172 | 173 | float *output = &ts.stack[4 * spec.model_dim * spec.batch_size]; 174 | assert_matrices_equal(output, c2, spec.model_dim, spec.batch_size); 175 | 176 | } 177 | 178 | 179 | TEST(Kernels, AddI_MV) { 180 | int M = 5, N = 10; 181 | float *m, *v; 182 | cudaMalloc(&m, M * N * sizeof(float)); 183 | cudaMalloc(&v, M * sizeof(float)); 184 | 185 | fill_rand_matrix(m, M, N); 186 | fill_rand_matrix(v, M, 1); 187 | 188 | cout << "matrix" << endl; 189 | print_device_matrix(m, M, N); 190 | 191 | cout << "vector" << endl; 192 | print_device_matrix(v, M, 1); 193 | 194 | kernels::addi_mv(m, v, 1.0, M, N); 195 | cudaDeviceSynchronize(); 196 | 197 | cout << "result" << endl; 198 | print_device_matrix(m, M, N); 199 | } 200 | -------------------------------------------------------------------------------- /cpp/thin-stack.cc: -------------------------------------------------------------------------------- 1 | /** 2 | * This file contains a C++/CUDA implementation of the thin-stack algorithm. 3 | * Thin-stack computes the same function as a vanilla recursive neural network 4 | * and as the SPINN-PI-NT model. 5 | * 6 | * The SPINN model and the thin-stack algorithm are described in our paper: 7 | * 8 | * A Fast Unified Model for Sentence Parsing and Understanding. 9 | * Samuel R. Bowman, Jon Gauthier, Abhinav Rastogi, Raghav Gupta, 10 | * Christopher D. Manning, and Christopher Potts. arXiv March 2016. 11 | * http://arxiv.org/abs/1603.06021 12 | * 13 | * The exact model implemented here is a recursive neural network (equivalent 14 | * to SPINN-PI-NT) with ReLU activations. It has been verified to compute the 15 | * exact same function of its inputs as a recursive neural network. 16 | * 17 | * You can execute this code using the script `bin/stacktest`. See the README 18 | * in the root directory of this project for more usage instructions. 19 | */ 20 | 21 | 22 | #include "thin-stack.h" 23 | using namespace std; 24 | namespace k = kernels; 25 | 26 | 27 | ThinStack::ThinStack(ModelSpec spec, ThinStackParameters params, 28 | cublasHandle_t handle) 29 | : SequenceModel(spec), params(params), stack_size(spec.seq_length), 30 | handle(handle) { 31 | 32 | stack_total_size = (stack_size * spec.batch_size) * spec.model_dim; 33 | buffer_total_size = spec.batch_size * spec.seq_length * spec.model_dim; 34 | queue_total_size = spec.batch_size * spec.seq_length; 35 | cursors_total_size = spec.batch_size; 36 | 37 | // Pre-allocate inputs. 38 | cudaMalloc(&X_indices, spec.batch_size * spec.seq_length * sizeof(float)); 39 | cudaMalloc(&X, spec.batch_size * spec.seq_length * spec.model_dim * sizeof(float)); 40 | cudaMalloc(&transitions, spec.batch_size * spec.seq_length * sizeof(float)); 41 | 42 | // Pre-allocate auxiliary data structures. 43 | cudaMalloc(&stack, stack_total_size * sizeof(float)); 44 | cudaMalloc(&queue, queue_total_size * sizeof(float)); 45 | cudaMalloc(&cursors, cursors_total_size * sizeof(float)); 46 | cudaMalloc(&buffer, buffer_total_size * sizeof(float)); 47 | 48 | // Pre-allocate temporary containers. 
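  // (Each temporary holds one step's worth of data: either a length
  // batch_size index vector or a model_dim x batch_size matrix slab.)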
49 | cudaMalloc(&buffer_top_idxs_t, spec.batch_size * sizeof(float)); 50 | cudaMalloc(&buffer_top_t, spec.batch_size * spec.model_dim * sizeof(float)); 51 | cudaMalloc(&stack_1_ptrs, spec.batch_size * sizeof(float)); 52 | cudaMalloc(&stack_1_t, spec.model_dim * spec.batch_size * sizeof(float)); 53 | cudaMalloc(&stack_2_ptrs, spec.batch_size * sizeof(float)); 54 | cudaMalloc(&stack_2_t, spec.model_dim * spec.batch_size * sizeof(float)); 55 | cudaMalloc(&push_output, spec.batch_size * spec.model_dim * sizeof(float)); 56 | cudaMalloc(&merge_output, spec.batch_size * spec.model_dim * sizeof(float)); 57 | 58 | // Pre-allocate accumulators. 59 | cudaMalloc(&buffer_cur_t, spec.batch_size * sizeof(float)); 60 | 61 | init_helpers(); 62 | 63 | } 64 | 65 | 66 | void ThinStack::init_helpers() { 67 | cudaMalloc(&batch_range, spec.batch_size * sizeof(float)); 68 | cudaMalloc(&batch_ones, spec.batch_size * sizeof(float)); 69 | 70 | float h_batch_ones[spec.batch_size]; 71 | float h_batch_range[spec.batch_size]; 72 | for (int i = 0; i < spec.batch_size; i++) { 73 | h_batch_ones[i] = 1.0f; 74 | h_batch_range[i] = (float) i; 75 | } 76 | 77 | cudaMemcpy(batch_range, h_batch_range, spec.batch_size * sizeof(float), 78 | cudaMemcpyHostToDevice); 79 | cudaMemcpy(batch_ones, h_batch_ones, spec.batch_size * sizeof(float), 80 | cudaMemcpyHostToDevice); 81 | cudaDeviceSynchronize(); 82 | } 83 | 84 | 85 | void ThinStack::free_helpers() { 86 | cudaFree(batch_ones); 87 | cudaFree(batch_range); 88 | } 89 | 90 | 91 | ThinStack::~ThinStack() { 92 | 93 | #if DEBUG 94 | cout << "ThinStack dying!" << endl; 95 | #endif 96 | free_helpers(); 97 | 98 | cudaFree(X_indices); 99 | cudaFree(X); 100 | cudaFree(transitions); 101 | 102 | cudaFree(stack); 103 | cudaFree(queue); 104 | cudaFree(cursors); 105 | // NB: `buffer` aliases `X` after forward() runs, so it must not be freed separately. 106 | 107 | cudaFree(buffer_top_idxs_t); 108 | cudaFree(buffer_top_t); 109 | cudaFree(stack_1_ptrs); cudaFree(stack_1_t); 110 | cudaFree(stack_2_ptrs); cudaFree(stack_2_t); 111 | cudaFree(push_output); 112 | cudaFree(merge_output); 113 | 114 | cudaFree(buffer_cur_t); 115 | 116 | } 117 | 118 | 119 | void ThinStack::forward() { 120 | 121 | // TODO embedding projection 122 | buffer = X; 123 | reset(); 124 | cudaDeviceSynchronize(); 125 | 126 | for (int t = 0; t < spec.seq_length; t++) { 127 | step(t); 128 | #if DEBUG 129 | cout << endl << "======================" << endl << endl; 130 | #endif 131 | } 132 | 133 | // TODO: Don't need to sync here.
Could have the client establish a lock on 134 | // results and simultaneously begin the next batch + copy out results 135 | cudaDeviceSynchronize(); 136 | 137 | #if DEBUG 138 | cout << "final" << endl; 139 | print_device_matrix(stack, spec.model_dim, spec.batch_size * spec.seq_length); 140 | #endif 141 | 142 | } 143 | 144 | 145 | void ThinStack::step(int t) { 146 | 147 | float *transitions_t = &transitions[t * spec.batch_size]; 148 | #if DEBUG 149 | cout << "transitions " << t << endl; 150 | print_device_matrix(transitions_t, 1, spec.batch_size); 151 | #endif 152 | 153 | // buffer_top = buffer[buffer_cur_t * batch_size + batch_range] 154 | k::subtensor1(buffer_top_t, buffer, buffer_cur_t, 155 | spec.batch_size * spec.model_dim, spec.batch_size, 156 | spec.model_dim, 0.0f, spec.batch_size, 1.0f, batch_range); 157 | #if DEBUG 158 | cout << "buffer_top after:" << endl; 159 | print_device_matrix(buffer_top_t, spec.model_dim, spec.batch_size); 160 | #endif 161 | 162 | // stack_2_ptrs = (cursors - 1) + batch_range * seq_length 163 | k::subtensor1(stack_2_ptrs, queue, cursors, spec.batch_size, 164 | spec.batch_size, 1, -1.0f, 1.0f, spec.seq_length, 165 | batch_range); 166 | #if DEBUG 167 | cout << "stack_2_ptrs #1" << endl; 168 | print_device_matrix(stack_2_ptrs, 1, spec.batch_size); 169 | #endif 170 | 171 | // stack_2_ptrs = stack_2_ptrs * batch_size + batch_range * 1 172 | k::addi_vv(handle, stack_2_ptrs, batch_range, spec.batch_size, 1, 173 | spec.batch_size); 174 | #if DEBUG 175 | cout << "stack_2_ptrs" << endl; 176 | print_device_matrix(stack_2_ptrs, 1, spec.batch_size); 177 | #endif 178 | 179 | // stack_1, stack_2 180 | // stack_1_t = stack[batch_range + (t - 1) * spec.batch_size] 181 | k::subtensor1(stack_1_t, stack, batch_range, 182 | spec.batch_size * spec.seq_length, spec.batch_size, 183 | spec.model_dim, (float) (t - 1) * spec.batch_size, 1.0f, 0.0f, 184 | NULL); 185 | 186 | k::subtensor1(stack_2_t, stack, stack_2_ptrs, 187 | spec.batch_size * spec.seq_length, spec.batch_size, 188 | spec.model_dim, 0.0f, 1.0f, 0.0f, NULL); 189 | 190 | // Run recurrence, which writes into `push_output`, `merge_output`. 191 | recurrence(stack_1_t, stack_2_t, buffer_top_t); 192 | 193 | // Write in the next stack top. 
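// switch_m writes one column per example: where transitions[i] == 1 the
// merge (reduce) output is selected, and where transitions[i] == 0 the
// shifted buffer top is selected.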
194 | mask_and_update_stack(buffer_top_t, merge_output, transitions_t, t); 195 | 196 | mask_and_update_cursors(cursors, transitions_t, t); 197 | #if DEBUG 198 | cout << "cursors after" << endl; 199 | print_device_matrix(cursors, 1, spec.batch_size); 200 | #endif 201 | 202 | 203 | // queue[cursors + 0 + batch_range * spec.seq_length] = t 204 | k::set_subtensor1i_s(queue, t, cursors, spec.batch_size, 0, spec.seq_length, 205 | batch_range); 206 | #if DEBUG 207 | cout << "queue after" << endl; 208 | print_device_matrix(queue, 1, spec.seq_length * spec.batch_size); 209 | #endif 210 | 211 | // buffer_cur += (1 - transitions) 212 | update_buffer_cur(buffer_cur_t, transitions_t, t); 213 | #if DEBUG 214 | cout << "buffer cur after" << endl; 215 | print_device_matrix(buffer_cur_t, 1, spec.batch_size); 216 | #endif 217 | 218 | } 219 | 220 | 221 | void ThinStack::recurrence(const float *stack_1_t, const float *stack_2_t, 222 | const float *buffer_top_t) { 223 | 224 | #if DEBUG 225 | cout << "left child:" << endl; 226 | print_device_matrix(stack_2_t, spec.model_dim, spec.batch_size); 227 | 228 | cout << "right child:" << endl; 229 | print_device_matrix(stack_1_t, spec.model_dim, spec.batch_size); 230 | #endif 231 | 232 | // merge_out = W_l l 233 | float alpha = 1.0f; 234 | float beta = 0.0f; 235 | cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, spec.model_dim, spec.batch_size, 236 | spec.model_dim, &alpha, params.compose_W_l, spec.model_dim, stack_2_t, 237 | spec.model_dim, &beta, merge_output, spec.model_dim); 238 | // merge_out += W_r r 239 | float beta2 = 1.0f; 240 | cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, spec.model_dim, spec.batch_size, 241 | spec.model_dim, &alpha, params.compose_W_r, spec.model_dim, stack_1_t, 242 | spec.model_dim, &beta2, merge_output, spec.model_dim); 243 | 244 | // merge_out += b 245 | k::addi_mv(merge_output, params.compose_b, 1.0, spec.model_dim, 246 | spec.batch_size); 247 | 248 | k::relu(merge_output, spec.model_dim, spec.batch_size); 249 | 250 | } 251 | 252 | 253 | void ThinStack::mask_and_update_stack(const float *push_value, 254 | const float *merge_value, const float *transitions, int t) { 255 | 256 | // Find start position of write destination (next-top corresponding to 257 | // timestep `t`). 
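// The stack is append-only: each timestep writes its own
// batch_size * model_dim slice and earlier slices are left intact, so
// merges at later timesteps can still address them through stack_2_ptrs.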
258 | int stack_offset = t * spec.batch_size * spec.model_dim; 259 | 260 | #if DEBUG 261 | cout << "merge value:" << endl; 262 | print_device_matrix(merge_value, spec.model_dim, spec.batch_size); 263 | cout << "push value:" << endl; 264 | print_device_matrix(push_value, spec.model_dim, spec.batch_size); 265 | #endif 266 | 267 | k::switch_m(&stack[stack_offset], transitions, merge_value, push_value, 268 | spec.batch_size, spec.model_dim); 269 | 270 | #if DEBUG 271 | cout << "stack top t (offset " << stack_offset << "):" << endl; 272 | print_device_matrix(&stack[stack_offset], spec.model_dim, spec.batch_size); 273 | #endif 274 | 275 | } 276 | 277 | 278 | void ThinStack::mask_and_update_cursors(float *cursors, const float *transitions, 279 | int t) { 280 | 281 | // cursors += 1 282 | float alpha1 = 1.0f; 283 | cublasSaxpy(handle, spec.batch_size, &alpha1, batch_ones, 1, cursors, 1); 284 | 285 | // cursors -= 2*transitions 286 | float alpha2 = -2.0f; 287 | cublasSaxpy(handle, spec.batch_size, &alpha2, transitions, 1, cursors, 1); 288 | 289 | } 290 | 291 | 292 | void ThinStack::update_buffer_cur(float *buffer_cur_t, float *transitions, int t) { 293 | 294 | // buffer_cur += 1 295 | float alpha1 = 1.0; 296 | cublasSaxpy(handle, spec.batch_size, &alpha1, batch_ones, 1, buffer_cur_t, 1); 297 | 298 | // buffer_cur -= transitions 299 | float alpha2 = -1.0; 300 | cublasSaxpy(handle, spec.batch_size, &alpha2, transitions, 1, buffer_cur_t, 1); 301 | 302 | } 303 | 304 | 305 | void ThinStack::reset() { 306 | // TODO: Technically these don't need to be explicitly zeroed out before 307 | // every feedforward. They just get overwritten and their bad values are 308 | // never used, provided that the feedforward uses a valid transition 309 | // sequence. 310 | cudaMemset(stack, 0, stack_total_size * sizeof(float)); 311 | cudaMemset(queue, 0, queue_total_size * sizeof(float)); 312 | 313 | float alpha = -1.0f; 314 | cudaMemset(cursors, 0, cursors_total_size * sizeof(float)); 315 | cublasSaxpy(handle, spec.batch_size, &alpha, batch_ones, 1, cursors, 1); 316 | 317 | cudaMemset(buffer_cur_t, 0, spec.batch_size * sizeof(float)); 318 | } 319 | -------------------------------------------------------------------------------- /cpp/thin-stack.h: -------------------------------------------------------------------------------- 1 | #ifndef _thin_stack_ 2 | #define _thin_stack_ 3 | 4 | #include 5 | #include "cublas_v2.h" 6 | 7 | #include "sequence-model.h" 8 | #include "util.h" 9 | 10 | #include "kernels.cuh" 11 | 12 | 13 | typedef struct ThinStackParameters { 14 | float *tracking_W_inp; 15 | float *tracking_W_hid; 16 | float *tracking_b; 17 | float *compose_W_l; 18 | float *compose_W_r; 19 | float *compose_W_ext; 20 | float *compose_b; 21 | } ThinStackParameters; 22 | 23 | class ThinStack : public SequenceModel { 24 | public: 25 | /** 26 | * Constructs a new `ThinStack`. 27 | */ 28 | ThinStack(ModelSpec spec, ThinStackParameters params, 29 | cublasHandle_t handle); 30 | 31 | ~ThinStack(); 32 | 33 | ThinStackParameters params; 34 | cublasHandle_t handle; 35 | 36 | void forward(); 37 | 38 | float *transitions; 39 | 40 | float *stack; 41 | 42 | private: 43 | 44 | void step(int t); 45 | 46 | // Reset internal storage. Must be run before beginning a sequence 47 | // feedforward. 
48 | void reset(); 49 | 50 | void recurrence(const float *stack_1_t, const float *stack_2_t, 51 | const float *buffer_top_t); 52 | void mask_and_update_stack(const float *push_value, 53 | const float *merge_value, const float *transitions, int t); 54 | void mask_and_update_cursors(float *cursors, const float *transitions, 55 | int t); 56 | void update_buffer_cur(float *buffer_cur_t, float *transitions, int t); 57 | 58 | void init_helpers(); 59 | void free_helpers(); 60 | 61 | size_t stack_size; 62 | 63 | size_t stack_total_size; 64 | size_t buffer_total_size; 65 | size_t queue_total_size; 66 | size_t cursors_total_size; 67 | 68 | // Containers for temporary (per-step) data 69 | float *buffer_top_idxs_t; 70 | float *buffer_top_t; 71 | float *stack_1_ptrs; 72 | float *stack_1_t; 73 | float *stack_2_ptrs; 74 | float *stack_2_t; 75 | float *push_output; 76 | float *merge_output; 77 | 78 | // Per-step accumulators 79 | float *buffer_cur_t; 80 | 81 | // Dumb helpers 82 | float *batch_ones; 83 | float *batch_range; 84 | 85 | // `model_dim * (batch_size * seq_length)` 86 | // `seq_length`-many `model_dim * batch_size` matrices, flattened into one. 87 | float *buffer; 88 | float *queue; 89 | float *cursors; 90 | 91 | }; 92 | 93 | #endif 94 | -------------------------------------------------------------------------------- /cpp/util.cc: -------------------------------------------------------------------------------- 1 | #include "util.h" 2 | 3 | float *load_weights(string filename, int N) { 4 | float *ret = (float *) malloc(N * sizeof(float)); 5 | cout << filename << endl; 6 | ifstream file(filename); 7 | 8 | float x; 9 | for (int i = 0; i < N; i++) { 10 | file >> x; 11 | ret[i] = x; 12 | } 13 | 14 | return ret; 15 | } 16 | 17 | float *load_weights_cuda(string filename, int N, float *target) { 18 | float *h_weights = load_weights(filename, N); 19 | 20 | if (!target) { 21 | float *d_weights; 22 | cudaMalloc(&d_weights, N * sizeof(float)); 23 | target = d_weights; 24 | } 25 | cudaMemcpy(target, h_weights, N * sizeof(float), 26 | cudaMemcpyHostToDevice); 27 | free(h_weights); 28 | return target; 29 | } 30 | 31 | 32 | cublasHandle_t getCublasHandle() { 33 | cublasHandle_t handle; 34 | cublasStatus_t stat = cublasCreate(&handle); 35 | if (stat != CUBLAS_STATUS_SUCCESS) { 36 | cout << "CUBLAS initialization failed (" << stat << ")" << endl; 37 | return NULL; 38 | } 39 | return handle; 40 | } 41 | 42 | 43 | // Print a column-major matrix stored on device. 
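// In column-major storage, element (i, j) of an M x N matrix lives at
// m[j * M + i]; the inner loop below reads it back that way.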
44 | void print_device_matrix(const float *m, int M, int N) { 45 | cudaDeviceSynchronize(); 46 | float *h_m = (float *) malloc(M * N * sizeof(float)); 47 | cudaMemcpy(h_m, m, M * N * sizeof(float), cudaMemcpyDeviceToHost); 48 | cudaDeviceSynchronize(); 49 | 50 | cout << "[[ "; 51 | for (int i = 0; i < M; i++) { 52 | if (i > 0) 53 | cout << " [ "; 54 | 55 | for (int j = 0; j < N; j++) { 56 | float val = h_m[j * M + i]; 57 | printf(" %+.05f, ", val); 58 | } 59 | 60 | cout << " ],"; 61 | if (i < M - 1) 62 | cout << endl; 63 | } 64 | cout << "]" << endl << endl; 65 | 66 | free(h_m); 67 | } 68 | 69 | 70 | float *make_rand_matrix(int M, int N) { 71 | float *m; 72 | cudaMalloc(&m, M * N * sizeof(float)); 73 | fill_rand_matrix(m, M, N); 74 | return m; 75 | } 76 | 77 | 78 | void fill_rand_matrix(float *m, int M, int N) { 79 | static curandGenerator_t prng; 80 | if (!prng) { 81 | curandCreateGenerator(&prng, CURAND_RNG_PSEUDO_DEFAULT); 82 | curandSetPseudoRandomGeneratorSeed(prng, (unsigned long long) clock()); 83 | } 84 | 85 | curandGenerateNormal(prng, m, M * N, 0.0f, 0.5f); 86 | } 87 | -------------------------------------------------------------------------------- /cpp/util.h: -------------------------------------------------------------------------------- 1 | #ifndef _util_ 2 | #define _util_ 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | #include "cublas_v2.h" 11 | 12 | using namespace std; 13 | 14 | 15 | #define DEBUG 0 16 | 17 | typedef struct ModelSpec { 18 | size_t model_dim; 19 | size_t word_embedding_dim; 20 | size_t batch_size; 21 | size_t vocab_size; 22 | size_t seq_length; 23 | size_t model_visible_dim; 24 | } ModelSpec; 25 | 26 | 27 | float *load_weights(string filename, int N); 28 | float *load_weights_cuda(string filename, int N, float *target=NULL); 29 | 30 | cublasHandle_t getCublasHandle(); 31 | 32 | void print_device_matrix(const float *m, int M, int N); 33 | float *make_rand_matrix(int M, int N); 34 | void fill_rand_matrix(float *m, int M, int N); 35 | 36 | #endif 37 | -------------------------------------------------------------------------------- /evalb_rembed.prm: -------------------------------------------------------------------------------- 1 | ##------------------------------------------## 2 | ## Debug mode ## 3 | ## print out data for individual sentence ## 4 | ##------------------------------------------## 5 | DEBUG 0 6 | 7 | ##------------------------------------------## 8 | ## MAX error ## 9 | ## Number of error to stop the process. ## 10 | ## This is useful if there could be ## 11 | ## tokanization error. ## 12 | ## The process will stop when this number## 13 | ## of errors are accumulated. 
## 14 | ##------------------------------------------## 15 | # (arastogi) Interested in all parses, set to 16 | # high value 17 | MAX_ERROR 1000000 18 | 19 | ##------------------------------------------## 20 | ## Cut-off length for statistics ## 21 | ## At the end of evaluation, the ## 22 | ## statistics for the sentences of length## 23 | ## less than or equal to this number will## 24 | ## be shown, on top of the statistics ## 25 | ## for all the sentences ## 26 | ##------------------------------------------## 27 | # (arastogi) Set cutoff to num_transitions 28 | CUTOFF_LEN 50 29 | 30 | ##------------------------------------------## 31 | ## unlabeled or labeled bracketing ## 32 | ## 0: unlabeled bracketing ## 33 | ## 1: labeled bracketing ## 34 | ##------------------------------------------## 35 | LABELED 1 36 | 37 | ##------------------------------------------## 38 | ## Delete labels ## 39 | ## list of labels to be ignored. ## 40 | ## If it is a pre-terminal label, delete ## 41 | ## the word along with the brackets. ## 42 | ## If it is a non-terminal label, just ## 43 | ## delete the brackets (don't delete ## 44 | ## children). ## 45 | ##------------------------------------------## 46 | # DELETE_LABEL TOP 47 | 48 | 49 | ##------------------------------------------## 50 | ## Delete labels for length calculation ## 51 | ## list of labels to be ignored for ## 52 | ## length calculation purpose ## 53 | ##------------------------------------------## 54 | # DELETE_LABEL_FOR_LENGTH -NONE- 55 | 56 | 57 | ##------------------------------------------## 58 | ## Equivalent labels, words ## 59 | ## the pairs are considered equivalent ## 60 | ## This is non-directional. ## 61 | ##------------------------------------------## 62 | # EQ_LABEL T TT 63 | 64 | # EQ_WORD This this 65 | -------------------------------------------------------------------------------- /python/requirements.txt: -------------------------------------------------------------------------------- 1 | nose==1.3.7 2 | numpy==1.10.4 3 | six==1.10.0 4 | wheel==0.24.0 5 | python-gflags==2.0 6 | git+git://github.com/hans/theano-hacked.git 7 | -------------------------------------------------------------------------------- /python/spinn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/python/spinn/__init__.py -------------------------------------------------------------------------------- /python/spinn/afs_safe_logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import sys 3 | import json 4 | 5 | class Logger(object): 6 | # A logging alternative that doesn't leave logs open between writes, 7 | # so as to allow AFS synchronization. 8 | 9 | # Level constants 10 | DEBUG = 0 11 | INFO = 1 12 | WARNING = 2 13 | ERROR = 3 14 | 15 | def __init__(self, log_path=None, json_log_path=None, min_print_level=0, min_file_level=0): 16 | # log_path: The full path for the log file to write. The file will be appended 17 | # to if it exists. 18 | # min_print_level: Only messages with level at or above this level will be printed to stderr. 19 | # min_file_level: Only messages with level at or above this level will be 20 | # written to disk.
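# A minimal usage sketch (the path here is hypothetical):
#   logger = Logger(log_path="checkpoints/run.log")
#   logger.Log("Starting training.", level=Logger.INFO)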
21 | self.log_path = log_path 22 | self.json_log_path = json_log_path 23 | self.min_print_level = min_print_level 24 | self.min_file_level = min_file_level 25 | 26 | def Log(self, message, level=INFO): 27 | if level >= self.min_print_level: 28 | # Write to STDERR 29 | sys.stderr.write("[%i] %s\n" % (level, message)) 30 | if self.log_path and level >= self.min_file_level: 31 | # Write to the log file then close it 32 | with open(self.log_path, 'a') as f: 33 | datetime_string = datetime.datetime.now().strftime( 34 | "%y-%m-%d %H:%M:%S") 35 | f.write("%s [%i] %s\n" % (datetime_string, level, message)) 36 | 37 | def LogJSON(self, message_obj, level=INFO): 38 | if self.json_log_path and level >= self.min_file_level: 39 | with open(self.json_log_path, 'w') as f: 40 | print >>f, json.dumps(message_obj) 41 | elif not self.json_log_path: 42 | sys.stderr.write('WARNING: No JSON log filename.\n') 43 | 44 | -------------------------------------------------------------------------------- /python/spinn/cbow.py: -------------------------------------------------------------------------------- 1 | """Theano-based sum-of-words implementations.""" 2 | 3 | import numpy as np 4 | import theano 5 | 6 | from theano import tensor as T 7 | from spinn import util 8 | 9 | 10 | class CBOW(object): 11 | """Plain sum of words encoder implementation. 12 | """ 13 | 14 | def __init__(self, model_dim, word_embedding_dim, vocab_size, _0, _1, 15 | _2, _3, _4, vs, 16 | X=None, 17 | initial_embeddings=None, 18 | make_test_fn=False, 19 | use_attention=False, 20 | **kwargs): 21 | """Construct a CBOW encoder. 22 | 23 | Args: 24 | model_dim: Dimensionality of hidden state. Must equal word_embedding_dim. 25 | vocab_size: Number of unique tokens in vocabulary. 26 | compose_network, training_mode: Unused here; CBOW simply sums word 27 | embeddings, so there is no composition function and no dropout. 28 | These (and the other placeholder arguments) are accepted only so 29 | the constructor can be called like HardStack's. 30 | vs: VariableStore instance for parameter storage 31 | X: Theano batch describing input matrix, or `None` (in which case 32 | this instance will make its own batch variable). 33 | make_test_fn: If set, create a function to run a scan for testing. 34 | kwargs, _0, _1, _2, _3, _4: Ignored. Meant to make the signature match the signature of HardStack(). 35 | """ 36 | 37 | assert model_dim == word_embedding_dim 38 | assert not use_attention or use_attention == "None" 39 | 40 | self.model_dim = model_dim 41 | self.word_embedding_dim = word_embedding_dim 42 | 43 | self.vocab_size = vocab_size 44 | 45 | self._vs = vs 46 | 47 | self.initial_embeddings = initial_embeddings 48 | 49 | self.X = X 50 | 51 | self._make_params() 52 | self._make_inputs() 53 | self._make_sum() 54 | 55 | if make_test_fn: 56 | assert False, "Not implemented." 57 | 58 | def _make_params(self): 59 | # Per-token embeddings.
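# If pretrained embeddings were supplied, load them once and freeze them:
# they are neither updated during training nor written to checkpoints.
# Otherwise, create a trainable embedding matrix.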
60 | if self.initial_embeddings is not None: 61 | def EmbeddingInitializer(shape): 62 | return self.initial_embeddings 63 | self.embeddings = self._vs.add_param( 64 | "embeddings", (self.vocab_size, self.word_embedding_dim), 65 | initializer=EmbeddingInitializer, 66 | trainable=False, 67 | savable=False) 68 | else: 69 | self.embeddings = self._vs.add_param( 70 | "embeddings", (self.vocab_size, self.word_embedding_dim)) 71 | 72 | def _make_inputs(self): 73 | self.X = self.X or T.imatrix("X") 74 | 75 | def _make_sum(self): 76 | """Build the sequential composition / scan graph.""" 77 | 78 | batch_size, seq_length = self.X.shape 79 | 80 | # Look up all of the embeddings that will be used. 81 | raw_embeddings = self.embeddings[self.X] # batch_size * seq_length * emb_dim 82 | 83 | self.final_representations = T.sum(raw_embeddings, axis=1, keepdims=True, dtype="float32", acc_dtype="float32") 84 | self.transitions_pred = T.zeros((batch_size, 0)) 85 | self.predict_transitions = False 86 | self.tracking_state_final = None 87 | 88 | -------------------------------------------------------------------------------- /python/spinn/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/python/spinn/data/__init__.py -------------------------------------------------------------------------------- /python/spinn/data/boolean/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/python/spinn/data/boolean/__init__.py -------------------------------------------------------------------------------- /python/spinn/data/boolean/generate_bl_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Recursively enumerates sentences of Boolean logic with their truth values and 4 | # writes them to disjoint training, dev, and test set files, formatted for 5 | # import_binary_bracketed_data.py. 6 | # 7 | # The maximum length of the examples and the number of generated examples are both 8 | # governed by the recursion depth. 
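# For illustration: starting from the atoms T and F, one round of expansion
# produces statements such as (1, 0, 'or'), which is serialized as
# "( T ( F or ) )" and, since (T or F) is true, written out as the line
# "T\t( T ( F or ) )".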
9 | 10 | import copy 11 | import random 12 | 13 | RECURSION_DEPTH = 3 14 | 15 | TRAIN_PORTION = 0.98 16 | DEV_PORTION = 0.01 17 | 18 | 19 | def get_value_for_tree(tree): 20 | if isinstance(tree, tuple): 21 | if tree[1] == 'not': 22 | child = get_value_for_tree(tree[0]) 23 | return not child 24 | else: 25 | left = get_value_for_tree(tree[0]) 26 | right = get_value_for_tree(tree[1]) 27 | if tree[2] == "and": 28 | return left and right 29 | elif tree[2] == "or": 30 | return left or right 31 | else: 32 | print 'syntax error', tree 33 | else: 34 | return tree 35 | 36 | 37 | def expand(statements): 38 | result = copy.copy(statements) 39 | for statement in statements: 40 | result.append((statement, 'not')) 41 | for inner_statement in statements: 42 | result.append((statement, inner_statement, 'and')) 43 | result.append((statement, inner_statement, 'or')) 44 | return result 45 | 46 | 47 | def uniq(seq, idfun=None): 48 | # order preserving 49 | if idfun is None: 50 | def idfun(x): 51 | return x 52 | seen = {} 53 | result = [] 54 | for item in seq: 55 | marker = idfun(item) 56 | # in old Python versions: 57 | # if seen.has_key(marker) 58 | # but in new ones: 59 | if marker in seen: 60 | continue 61 | seen[marker] = 1 62 | result.append(item) 63 | return result 64 | 65 | 66 | def to_string(expr): 67 | if isinstance(expr, int): 68 | return value_names[expr] 69 | if isinstance(expr, str): 70 | return expr 71 | elif len(expr) == 3: 72 | return "( " + to_string(expr[0]) + " ( " + to_string(expr[1]) + " " + to_string(expr[2]) + " ) )" 73 | else: 74 | return "( " + to_string(expr[0]) + " " + to_string(expr[1]) + " )" 75 | 76 | 77 | if __name__ == "__main__": 78 | values = [0, 1] 79 | value_names = ['F', 'T'] 80 | 81 | total = 0 82 | statements = [0, 1] 83 | for i in range(RECURSION_DEPTH): 84 | statements = expand(statements) 85 | statements = uniq(statements) 86 | 87 | outputs = [] 88 | 89 | for i, statement in enumerate(statements): 90 | tv = get_value_for_tree(statement) 91 | tv_string = value_names[tv] 92 | 93 | total += 1 94 | outputs.append(tv_string + "\t" + to_string(statement)) 95 | 96 | outputs = uniq(outputs) 97 | random.shuffle(outputs) 98 | 99 | filename = 'pbl_train.tsv' 100 | f = open(filename, 'w') 101 | for i in range(int(TRAIN_PORTION * len(outputs))): 102 | output = outputs[i] 103 | f.write(output + "\n") 104 | f.close() 105 | 106 | filename = 'pbl_dev.tsv' 107 | f = open(filename, 'w') 108 | for i in range(int(TRAIN_PORTION * len(outputs)), int((TRAIN_PORTION + DEV_PORTION) * len(outputs))): 109 | output = outputs[i] 110 | f.write(output + "\n") 111 | f.close() 112 | 113 | filename = 'pbl_test.tsv' 114 | f = open(filename, 'w') 115 | for i in range(int((TRAIN_PORTION + DEV_PORTION) * len(outputs)), len(outputs)): 116 | output = outputs[i] 117 | f.write(output + "\n") 118 | f.close() 119 | -------------------------------------------------------------------------------- /python/spinn/data/boolean/load_boolean_data.py: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env python 3 | 4 | # Loads a file where each line contains a label, followed by a tab, followed 5 | # by a sequence of words with a binary parse indicated by space-separated parentheses. 
6 | # 7 | # Example: 8 | # sentence_label ( ( word word ) ( ( word word ) word ) ) 9 | import numpy as np 10 | 11 | from spinn import util 12 | 13 | NUM_CLASSES = 2 14 | 15 | SENTENCE_PAIR_DATA = False 16 | 17 | FIXED_VOCABULARY = { 18 | util.PADDING_TOKEN: 0, 19 | "T": 1, 20 | "F": 2, 21 | "not": 3, 22 | "and": 4, 23 | "or": 5 24 | } 25 | 26 | LABEL_MAP = { 27 | "T": 0, 28 | "F": 1 29 | } 30 | 31 | 32 | def convert_binary_bracketed_data(filename): 33 | examples = [] 34 | with open(filename, 'r') as f: 35 | for line in f: 36 | example = {} 37 | line = line.strip() 38 | tab_split = line.split('\t') 39 | example["label"] = tab_split[0] 40 | example["sentence"] = tab_split[1] 41 | example["tokens"] = [] 42 | example["transitions"] = [] 43 | 44 | for word in example["sentence"].split(' '): 45 | if word != "(": 46 | if word != ")": 47 | example["tokens"].append(word) 48 | example["transitions"].append(1 if word == ")" else 0) 49 | 50 | examples.append(example) 51 | return examples 52 | 53 | 54 | def load_data(path): 55 | dataset = convert_binary_bracketed_data(path) 56 | return dataset, FIXED_VOCABULARY 57 | 58 | if __name__ == "__main__": 59 | # Demo: 60 | examples = convert_binary_bracketed_data('bl-data/bl_dev.tsv') 61 | print examples[0] 62 | -------------------------------------------------------------------------------- /python/spinn/data/snli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/python/spinn/data/snli/__init__.py -------------------------------------------------------------------------------- /python/spinn/data/snli/load_snli_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import json 4 | 5 | SENTENCE_PAIR_DATA = True 6 | 7 | LABEL_MAP = { 8 | "entailment": 0, 9 | "neutral": 1, 10 | "contradiction": 2 11 | } 12 | 13 | def convert_binary_bracketing(parse): 14 | transitions = [] 15 | tokens = [] 16 | for word in parse.split(' '): 17 | if word[0] != "(": 18 | if word == ")": 19 | transitions.append(1) 20 | else: 21 | # Downcase all words to match GloVe.
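# (For illustration: the parse "( ( the cat ) sat )" yields
# tokens ["the", "cat", "sat"] and transitions [0, 0, 1, 0, 1],
# where 0 is a shift and 1 is a reduce.)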
22 | tokens.append(word.lower()) 23 | transitions.append(0) 24 | return tokens, transitions 25 | 26 | def load_data(path): 27 | print "Loading", path 28 | examples = [] 29 | with open(path, 'r') as f: 30 | for line in f: 31 | loaded_example = json.loads(line) 32 | if loaded_example["gold_label"] not in LABEL_MAP: 33 | continue 34 | 35 | example = {} 36 | example["label"] = loaded_example["gold_label"] 37 | example["premise"] = loaded_example["sentence1"] 38 | example["hypothesis"] = loaded_example["sentence2"] 39 | (example["premise_tokens"], example["premise_transitions"]) = convert_binary_bracketing(loaded_example["sentence1_binary_parse"]) 40 | (example["hypothesis_tokens"], example["hypothesis_transitions"]) = convert_binary_bracketing(loaded_example["sentence2_binary_parse"]) 41 | examples.append(example) 42 | return examples, None 43 | 44 | 45 | if __name__ == "__main__": 46 | # Demo: 47 | examples = load_data('snli-data/snli_1.0_dev.jsonl') 48 | print examples[0] 49 | -------------------------------------------------------------------------------- /python/spinn/data/sst/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/python/spinn/data/sst/__init__.py -------------------------------------------------------------------------------- /python/spinn/data/sst/load_sst_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Loads a file where each line contains a label, followed by a tab, followed 4 | # by a sequence of words with a binary parse indicated by space-separated parentheses. 5 | # 6 | # Example: 7 | # sentence_label ( ( word word ) ( ( word word ) word ) ) 8 | 9 | import collections 10 | import numpy as np 11 | 12 | from spinn import util 13 | 14 | SENTENCE_PAIR_DATA = False 15 | 16 | LABEL_MAP = { 17 | "0": 0, 18 | "1": 1, 19 | "2": 2, 20 | "3": 3, 21 | "4": 4 22 | } 23 | 24 | 25 | def convert_unary_binary_bracketed_data(filename): 26 | # Build a binary tree out of a binary parse in which every 27 | # leaf node is wrapped as a unary constituent, as here: 28 | # (4 (2 (2 The ) (2 actors ) ) (3 (4 (2 are ) (3 fantastic ) ) (2 . ) ) ) 29 | examples = [] 30 | with open(filename, 'r') as f: 31 | for line in f: 32 | example = {} 33 | line = line.strip() 34 | if len(line) == 0: 35 | continue 36 | example["label"] = line[1] 37 | example["sentence"] = line 38 | example["tokens"] = [] 39 | example["transitions"] = [] 40 | 41 | words = example["sentence"].split(' ') 42 | for index, word in enumerate(words): 43 | if word[0] != "(": 44 | if word == ")": 45 | # Ignore unary merges 46 | if words[index - 1] == ")": 47 | example["transitions"].append(1) 48 | else: 49 | # Downcase all words to match GloVe. 
50 | example["tokens"].append(word.lower()) 51 | example["transitions"].append(0) 52 | examples.append(example) 53 | return examples 54 | 55 | 56 | def load_data(path, vocabulary=None, seq_length=None, batch_size=32, eval_mode=False, logger=None): 57 | dataset = convert_unary_binary_bracketed_data(path) 58 | return dataset, None 59 | 60 | 61 | if __name__ == "__main__": 62 | # Demo: 63 | examples = convert_unary_binary_bracketed_data('sst-data/dev.txt') 64 | print examples[0] 65 | -------------------------------------------------------------------------------- /python/spinn/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/python/spinn/models/__init__.py -------------------------------------------------------------------------------- /python/spinn/plain_rnn.py: -------------------------------------------------------------------------------- 1 | """Theano-based RNN implementations.""" 2 | 3 | import numpy as np 4 | import theano 5 | 6 | from theano import tensor as T 7 | from spinn import util 8 | 9 | 10 | class RNN(object): 11 | """Plain RNN encoder implementation. Can use any activation function. 12 | """ 13 | 14 | def __init__(self, model_dim, word_embedding_dim, vocab_size, _0, compose_network, 15 | _1, training_mode, _2, vs, 16 | train_with_predicted_transitions=False, 17 | X=None, 18 | initial_embeddings=None, 19 | make_test_fn=False, 20 | **kwargs): 21 | """Construct an RNN. 22 | 23 | Args: 24 | model_dim: Dimensionality of hidden state. 25 | vocab_size: Number of unique tokens in vocabulary. 26 | compose_network: Blocks-like function which accepts arguments 27 | `prev_hidden_state, inp, inp_dim, hidden_dim, vs, name` (see e.g. `util.LSTMLayer`). 28 | training_mode: A Theano scalar indicating whether to act as a training model 29 | with dropout (1.0) or to act as an eval model with rescaling (0.0). 30 | vs: VariableStore instance for parameter storage 31 | X: Theano batch describing input matrix, or `None` (in which case 32 | this instance will make its own batch variable). 33 | make_test_fn: If set, create a function to run a scan for testing. 34 | kwargs, _0, _1, _2: Ignored. Meant to make the signature match the signature of HardStack(). 35 | """ 36 | 37 | self.model_dim = model_dim 38 | self.word_embedding_dim = word_embedding_dim 39 | self.vocab_size = vocab_size 40 | 41 | self._compose_network = compose_network 42 | 43 | self._vs = vs 44 | 45 | self.initial_embeddings = initial_embeddings 46 | 47 | self.training_mode = training_mode 48 | 49 | self.X = X 50 | 51 | self._make_params() 52 | self._make_inputs() 53 | self._make_scan() 54 | 55 | if make_test_fn: 56 | self.scan_fn = theano.function([self.X, self.training_mode], 57 | self.final_representations, 58 | on_unused_input='warn') 59 | 60 | def _make_params(self): 61 | # Per-token embeddings.
62 | if self.initial_embeddings is not None: 63 | def EmbeddingInitializer(shape): 64 | return self.initial_embeddings 65 | self.embeddings = self._vs.add_param( 66 | "embeddings", (self.vocab_size, self.word_embedding_dim), 67 | initializer=EmbeddingInitializer, 68 | trainable=False, 69 | savable=False) 70 | else: 71 | self.embeddings = self._vs.add_param( 72 | "embeddings", (self.vocab_size, self.word_embedding_dim)) 73 | 74 | def _make_inputs(self): 75 | self.X = self.X or T.imatrix("X") 76 | 77 | def _step(self, inputs_cur_t, hidden_prev_t): 78 | hidden_state_cur_t = self._compose_network(hidden_prev_t, inputs_cur_t, 79 | self.word_embedding_dim, self.model_dim, self._vs, name="rnn") 80 | 81 | return hidden_state_cur_t 82 | 83 | def _make_scan(self): 84 | """Build the sequential composition / scan graph.""" 85 | 86 | batch_size, seq_length = self.X.shape 87 | 88 | # Look up all of the embeddings that will be used. 89 | raw_embeddings = self.embeddings[self.X] # batch_size * seq_length * emb_dim 90 | raw_embeddings = raw_embeddings.dimshuffle(1, 0, 2) 91 | 92 | # Initialize the hidden state. 93 | hidden_init = T.zeros((batch_size, self.model_dim)) 94 | 95 | self.states = theano.scan( 96 | self._step, 97 | sequences=[raw_embeddings], 98 | outputs_info=[hidden_init])[0] 99 | 100 | self.final_representations = self.states[-1] 101 | self.transitions_pred = T.zeros((batch_size, 0)) 102 | self.predict_transitions = False 103 | self.tracking_state_final = None 104 | 105 | -------------------------------------------------------------------------------- /python/spinn/recurrences.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the core recurrences for various stack models. 3 | 4 | The recurrences described here are unrolled into bona-fide stack models 5 | by `spinn.stack`. 6 | """ 7 | 8 | from functools import partial 9 | 10 | from theano import tensor as T 11 | 12 | from spinn import util 13 | 14 | 15 | class Recurrence(object): 16 | 17 | def __init__(self, spec, vs): 18 | self._spec = spec 19 | self._vs = vs 20 | 21 | # TODO(SB): Standardize terminology in comments -- 22 | # Merge/push v. push/pop v. shift/reduce... 23 | 24 | # A recurrence is expected to output 1 value in a merge op and zero 25 | # values in a push op. A recurrence may also output `N` "extra" 26 | # outputs, expected at both merge and push ops. 27 | # 28 | # This list should contain `N` shape tuples. The tuple at position `i` 29 | # specifies that extra output #`i` will have shape `extra_outputs[i]` 30 | # for a single example (i.e., not including batch axis). For example, 31 | # if a recurrence yields a 50-dimensional vector for each example at 32 | # each timestep, we would include `(50,)` here. 33 | self.extra_outputs = [] 34 | 35 | self.predicts_transitions = False 36 | self.uses_predictions = False 37 | 38 | def __call__(self, inputs, **constants): 39 | """ 40 | Computes push and merge results for a single timestep. 41 | 42 | Args: 43 | inputs: A tuple of inputs to the recurrence. At minimum this 44 | contains three elements: `(stack_1, stack_2, buffer_top)`. 45 | A recurrence which has `N` extra outputs will also receive 46 | all `N` outputs from the previous timestep concatenated to 47 | this tuple, e.g. 48 | 49 | (stack_1, stack_2, buffer_top, 50 | prev_output_1, prev_output_2, ...) 51 | 52 | Each input is a batch of values, with batch_size as the leading 53 | axis. 
54 | constants: 55 | TBD 56 | 57 | Returns: 58 | push_outputs: A list of batch outputs for the case in which the 59 | current op is a push. This list should be `len(self.extra_outputs)` 60 | long. 61 | merge_outputs: A list of batch outputs for the case in which the 62 | current op is a merge. This list should be 63 | `1 + len(self.extra_outputs)` long. (The first element should be 64 | the result of merging the stack top values, a batch tensor of 65 | shape `batch_size * self._spec.model_dim`.) 66 | actions: (Only necessary if `self.predicts_transitions` is `True`.) 67 | A batch of logits over stack actions at this timestep (of shape 68 | `batch_size * num_actions`). 69 | """ 70 | raise NotImplementedError("abstract method") 71 | 72 | 73 | class SharedRecurrenceMixin(object): 74 | """Mixin providing various shared components.""" 75 | 76 | def __init__(self): 77 | raise RuntimeError("Don't instantiate me; I'm a mixin!") 78 | 79 | def _context_sensitive_shift(self, inputs): 80 | """ 81 | Compute a buffer top representation by mixing buffer top and hidden state. 82 | 83 | NB: This hasn't been an especially effective tool so far. 84 | """ 85 | assert self.use_tracking_lstm 86 | buffer_top, tracking_hidden = inputs[2:4] 87 | 88 | # Exclude the cell value from the computation. 89 | tracking_hidden = tracking_hidden[:, :self.tracking_lstm_hidden_dim] 90 | 91 | inp = T.concatenate([tracking_hidden, buffer_top], axis=1) 92 | inp_dim = self._spec.word_embedding_dim + self.tracking_lstm_hidden_dim 93 | layer = util.ReLULayer if self.context_sensitive_use_relu else util.Linear 94 | return layer(inp, inp_dim, self._spec.model_dim, self._vs, 95 | name="context_comb_unit", use_bias=True, 96 | initializer=util.HeKaimingInitializer()) 97 | 98 | def _tracking_lstm_predict(self, inputs, network): 99 | # TODO(SB): Offer more buffer content than just the top as input. 100 | c1, c2, buffer_top, tracking_hidden = inputs[:4] 101 | 102 | h_dim = self._spec.model_dim 103 | if self._spec.model_visible_dim != self._spec.model_dim: 104 | h_dim = self._spec.model_visible_dim 105 | c1 = c1[:, :h_dim] 106 | c2 = c2[:, :h_dim] 107 | buffer_top = buffer_top[:, :h_dim] 108 | 109 | inp = (c1, c2, buffer_top) 110 | return network(tracking_hidden, inp, (h_dim,) * 3, 111 | self.tracking_lstm_hidden_dim, self._vs, 112 | name="prediction_and_tracking") 113 | 114 | def _predict(self, inputs, network): 115 | # TODO(SB): Offer more buffer content than just the top as input. 116 | c1, c2, buffer_top = tuple(inputs[:3]) 117 | 118 | h_dim = self._spec.model_dim 119 | if self._spec.model_visible_dim != self._spec.model_dim: 120 | h_dim = self._spec.model_visible_dim 121 | c1 = c1[:, :h_dim] 122 | c2 = c2[:, :h_dim] 123 | buffer_top = buffer_top[:, :h_dim] 124 | 125 | inp = (c1, c2, buffer_top) 126 | return network(inp, (h_dim,) * 3, util.NUM_TRANSITION_TYPES, self._vs, 127 | name="prediction_and_tracking") 128 | 129 | def _merge(self, inputs, network): 130 | merge_items = tuple(inputs[:2]) 131 | if self.use_tracking_lstm: 132 | # NB: Unlike in the previous implementation, context-sensitive 133 | # composition (aka the tracking--composition connection) is not 134 | # optional here. It helps performance, so this shouldn't be a 135 | # big problem.
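# The tracking state is the concatenation [h; c]; only the hidden half h
# is exposed to the composition function.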
136 | tracking_h_t = inputs[3][:, :self.tracking_lstm_hidden_dim] 137 | return network(merge_items, tracking_h_t, self._spec.model_dim, 138 | self._vs, name="compose", 139 | external_state_dim=self.tracking_lstm_hidden_dim) 140 | else: 141 | return network(merge_items, (self._spec.model_dim,) * 2, 142 | self._spec.model_dim, self._vs, name="compose") 143 | 144 | 145 | class Model0(Recurrence, SharedRecurrenceMixin): 146 | 147 | def __init__(self, spec, vs, compose_network, 148 | use_tracking_lstm=False, 149 | tracking_lstm_hidden_dim=8, 150 | use_context_sensitive_shift=False, 151 | context_sensitive_use_relu=False): 152 | super(Model0, self).__init__(spec, vs) 153 | self.extra_outputs = [] 154 | if use_tracking_lstm: 155 | self.extra_outputs.append((tracking_lstm_hidden_dim * 2,)) 156 | self.predicts_transitions = False 157 | 158 | self._compose_network = compose_network 159 | self.use_tracking_lstm = use_tracking_lstm 160 | self.tracking_lstm_hidden_dim = tracking_lstm_hidden_dim 161 | self.use_context_sensitive_shift = use_context_sensitive_shift 162 | self.context_sensitive_use_relu = context_sensitive_use_relu 163 | 164 | if use_tracking_lstm: 165 | self._prediction_and_tracking_network = partial(util.TrackingUnit, 166 | make_logits=False) 167 | 168 | def __call__(self, inputs, **constants): 169 | c1, c2, buffer_top = inputs[:3] 170 | if self.use_tracking_lstm: 171 | tracking_hidden = inputs[3] 172 | 173 | # Unlike in the previous implementation, we update the tracking LSTM 174 | # before using its output to update the inputs. 175 | if self.use_tracking_lstm: 176 | tracking_hidden, _ = self._tracking_lstm_predict( 177 | inputs, self._prediction_and_tracking_network) 178 | inputs = [c1, c2, buffer_top, tracking_hidden] 179 | 180 | if self.use_context_sensitive_shift: 181 | buffer_top = self._context_sensitive_shift(inputs) 182 | 183 | merge_value = self._merge(inputs, self._compose_network) 184 | 185 | if self.use_tracking_lstm: 186 | return [tracking_hidden], [merge_value, tracking_hidden] 187 | else: 188 | return [], [merge_value] 189 | 190 | 191 | class Model1(Recurrence, SharedRecurrenceMixin): 192 | 193 | def __init__(self, spec, vs, compose_network, 194 | use_tracking_lstm=False, 195 | tracking_lstm_hidden_dim=8, 196 | use_context_sensitive_shift=False, 197 | context_sensitive_use_relu=False): 198 | super(Model1, self).__init__(spec, vs) 199 | if use_tracking_lstm: 200 | self.extra_outputs.append((tracking_lstm_hidden_dim * 2,)) 201 | self.predicts_transitions = True 202 | self.uses_predictions = False 203 | 204 | self._compose_network = compose_network 205 | self.use_tracking_lstm = use_tracking_lstm 206 | self.tracking_lstm_hidden_dim = tracking_lstm_hidden_dim 207 | self.use_context_sensitive_shift = use_context_sensitive_shift 208 | self.context_sensitive_use_relu = context_sensitive_use_relu 209 | 210 | if use_tracking_lstm: 211 | self._prediction_and_tracking_network = partial(util.TrackingUnit, 212 | make_logits=False) 213 | else: 214 | self._prediction_and_tracking_network = util.Linear 215 | 216 | def __call__(self, inputs, **constants): 217 | c1, c2, buffer_top = inputs[:3] 218 | if self.use_tracking_lstm: 219 | tracking_hidden = inputs[3] 220 | 221 | # Predict transitions. 222 | 223 | # Unlike in the previous implementation, we update the tracking LSTM 224 | # before using its output to update the inputs. 
225 | if self.use_tracking_lstm: 226 | tracking_hidden, actions_t = self._tracking_lstm_predict( 227 | inputs, self._prediction_and_tracking_network) 228 | inputs = [c1, c2, buffer_top, tracking_hidden] 229 | else: 230 | actions_t = self._predict( 231 | inputs, self._prediction_and_tracking_network) 232 | 233 | if self.use_context_sensitive_shift: 234 | buffer_top = self._context_sensitive_shift(inputs) 235 | 236 | merge_value = self._merge(inputs, self._compose_network) 237 | 238 | if self.use_tracking_lstm: 239 | return [tracking_hidden], [merge_value, tracking_hidden], actions_t 240 | else: 241 | return [], [merge_value], actions_t 242 | 243 | 244 | class Model2(Model1, SharedRecurrenceMixin): 245 | """Core implementation of Model 2. Supports scheduled sampling.""" 246 | 247 | def __init__(self, spec, vs, compose_network, **kwargs): 248 | super(Model2, self).__init__(spec, vs, compose_network, **kwargs) 249 | self.uses_predictions = True 250 | -------------------------------------------------------------------------------- /python/spinn/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/python/spinn/tests/__init__.py -------------------------------------------------------------------------------- /python/spinn/tests/test_cuda_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import theano 3 | import theano.tensor as T 4 | 5 | # Import cuda util in order to register the optimization 6 | from spinn.util import cuda 7 | 8 | 9 | def _test_gpu_rowwise_switch_inner(f, A, B, mask, expected): 10 | ret = f(A, B, mask) 11 | print A 12 | print B 13 | print mask 14 | print ret 15 | print expected 16 | np.testing.assert_array_almost_equal(ret, expected) 17 | 18 | 19 | def test_gpu_rowwise_switch(): 20 | assert theano.config.device.startswith("gpu"), "Need to test on GPU!" 
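# Each case below is (A, B, mask, expected): row i of the result should be
# taken from A where mask[i] == 1 and from B where mask[i] == 0.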
21 | 22 | data = [ 23 | # 4 x 2 24 | (np.array([[ 0.22323515, 0.36703175], 25 | [ 0.82260513, 0.3461504 ], 26 | [ 0.82362652, 0.81626087], 27 | [ 0.95270008, 0.2226797 ]]), 28 | np.array([[ 0.36341551, 0.20102882], 29 | [ 0.24144639, 0.45237923], 30 | [ 0.39951822, 0.7348066 ], 31 | [ 0.16649647, 0.60306537]]), 32 | np.array([1, 0, 1, 1]), 33 | np.array([[ 0.22323515, 0.36703175], 34 | [ 0.24144639, 0.45237923], 35 | [ 0.82362652, 0.81626087], 36 | [ 0.95270008, 0.2226797 ]])), 37 | 38 | # 2 x 3 x 4 39 | (np.array([[[ 0.48769062, 0.82649632, 0.2047115 , 0.41437615], 40 | [ 0.25290664, 0.87164914, 0.80968588, 0.49295084], 41 | [ 0.71438099, 0.97913502, 0.37598001, 0.76958707]], 42 | 43 | [[ 0.37605973, 0.538358 , 0.74304674, 0.84346291], 44 | [ 0.95310617, 0.61540292, 0.49881143, 0.1028554 ], 45 | [ 0.83481996, 0.90969569, 0.40410424, 0.34419989]]]), 46 | np.array([[[ 0.7289117 , 0.97323253, 0.19070121, 0.64164653], 47 | [ 0.26816493, 0.76093069, 0.95284825, 0.77350426], 48 | [ 0.55415519, 0.39431256, 0.86588665, 0.50031027]], 49 | 50 | [[ 0.1980869 , 0.7753601 , 0.26810868, 0.3628802 ], 51 | [ 0.2488143 , 0.21278388, 0.09724567, 0.58457886], 52 | [ 0.12295105, 0.75321368, 0.37258797, 0.27756972]]]), 53 | np.array([1, 0]), 54 | np.array([[[ 0.48769062, 0.82649632, 0.2047115 , 0.41437615], 55 | [ 0.25290664, 0.87164914, 0.80968588, 0.49295084], 56 | [ 0.71438099, 0.97913502, 0.37598001, 0.76958707]], 57 | 58 | [[ 0.1980869 , 0.7753601 , 0.26810868, 0.3628802 ], 59 | [ 0.2488143 , 0.21278388, 0.09724567, 0.58457886], 60 | [ 0.12295105, 0.75321368, 0.37258797, 0.27756972]]])) 61 | 62 | ] 63 | 64 | A2, B2 = T.matrices("AB") 65 | A3, B3 = T.tensor3("A"), T.tensor3("B") 66 | mask = T.ivector("mask") 67 | 68 | switch2 = T.switch(mask.dimshuffle(0, "x"), A2, B2) 69 | switch3 = T.switch(mask.dimshuffle(0, "x", "x"), A3, B3) 70 | 71 | f2 = theano.function([A2, B2, mask], switch2) 72 | f3 = theano.function([A3, B3, mask], switch3) 73 | 74 | print "Graph of 2dim switch:" 75 | theano.printing.debugprint(f2.maker.fgraph.outputs[0]) 76 | print "Graph of 3dim switch:" 77 | theano.printing.debugprint(f3.maker.fgraph.outputs[0]) 78 | 79 | for instance in data: 80 | # Retrieve appropriate function 81 | func = f2 if instance[0].ndim == 2 else f3 82 | 83 | # Cast to float-friendly types 84 | instance = [x.astype(np.float32) if x.dtype.kind == 'f' 85 | else x.astype(np.int32) for x in instance] 86 | 87 | yield tuple([_test_gpu_rowwise_switch_inner, func] + instance) 88 | 89 | 90 | def _test_masked_careduce_inner(f, X, Y, mask, expected): 91 | print X 92 | print Y 93 | print mask 94 | 95 | ret = f(X, Y, mask) 96 | print ret 97 | print expected 98 | np.testing.assert_array_almost_equal(ret, expected) 99 | 100 | 101 | def test_masked_careduce(): 102 | assert theano.config.device.startswith("gpu"), "Need to test on GPU!" 
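# Each case below is (X, Y, mask, expected): slices are taken from X where
# mask == 1 and from Y where mask == 0, then summed over the leading axis.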
103 | 104 | data = [ 105 | (np.array([[[ 3., 1.], 106 | [ 4., 8.]], 107 | 108 | [[ 9., 4.], 109 | [ 3., 6.]], 110 | 111 | [[ 5., 2.], 112 | [ 6., 2.]]]), 113 | 114 | np.array([[[ 10., 7.], 115 | [ 7, 5.]], 116 | 117 | [[ 2., 1.], 118 | [ 5., 9.]], 119 | 120 | [[ 0., 6.], 121 | [ 3., 4.]]]), 122 | 123 | np.array([0, 1, 0]), 124 | 125 | np.array([[ 19., 17.], 126 | [ 13., 15.]])) 127 | ] 128 | 129 | X, Y = T.tensor3("X"), T.tensor3("Y") 130 | mask = T.fvector("mask") 131 | mask_ = mask.dimshuffle(0, "x", "x") 132 | 133 | switch = T.switch(mask_, X, Y) 134 | out = switch.sum(axis=0) 135 | 136 | f = theano.function([X, Y, mask], out) 137 | 138 | print "Graph of switch+sum:" 139 | theano.printing.debugprint(f.maker.fgraph.outputs[0]) 140 | 141 | for instance in data: 142 | # Cast to float-friendly types 143 | instance = [x.astype(np.float32) for x in instance] 144 | yield tuple([_test_masked_careduce_inner, f] + instance) 145 | 146 | -------------------------------------------------------------------------------- /python/spinn/tests/test_embedding_matrix.5d.txt: -------------------------------------------------------------------------------- 1 | the 0.418 0.24968 -0.41242 0.1217 0.34527 2 | , 0.013441 0.23682 -0.16899 0.40951 0.63812 3 | . 0.15164 0.30177 -0.16763 0.17684 0.31719 4 | of 0.70853 0.57088 -0.4716 0.18048 0.54449 5 | to 0.68047 -0.039263 0.30186 -0.17792 0.42962 6 | and 0.26818 0.14346 -0.27877 0.016257 0.11384 7 | in 0.33042 0.24995 -0.60874 0.10923 0.036372 8 | a 0.21705 0.46515 -0.46757 0.10082 1.0135 9 | " 0.25769 0.45629 -0.76974 -0.37679 0.59272 10 | _ 0.23727 0.40478 -0.20547 0.58805 0.65533 11 | -------------------------------------------------------------------------------- /python/spinn/tests/test_plain_rnn.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import theano 5 | from theano import tensor as T 6 | 7 | from spinn.plain_rnn import RNN 8 | from spinn.util import VariableStore, CropAndPad, IdentityLayer 9 | 10 | 11 | class RNNTestCase(unittest.TestCase): 12 | 13 | """Basic functional tests for RNN with dummy data.""" 14 | 15 | def _make_rnn(self, seq_length=4): 16 | self.embedding_dim = embedding_dim = 3 17 | self.vocab_size = vocab_size = 10 18 | self.seq_length = seq_length 19 | 20 | def compose_network(h_prev, inp, embedding_dim, model_dim, vs, name="compose"): 21 | # Just add the two embeddings! 
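# Stacking two identity matrices makes the product a plain sum:
# i.dot(W) == h_prev + inp.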
22 | W = T.concatenate([T.eye(model_dim), T.eye(model_dim)], axis=0) 23 | i = T.concatenate([h_prev, inp], axis=1) 24 | return i.dot(W) 25 | 26 | X = T.imatrix("X") 27 | training_mode = T.scalar("training_mode") 28 | vs = VariableStore() 29 | embeddings = np.arange(vocab_size).reshape( 30 | (vocab_size, 1)).repeat(embedding_dim, axis=1) 31 | self.model = RNN( 32 | embedding_dim, embedding_dim, vocab_size, seq_length, compose_network, 33 | IdentityLayer, training_mode, None, vs, 34 | X=X, make_test_fn=True, initial_embeddings=embeddings) 35 | 36 | def test_basic_ff(self): 37 | self._make_rnn(4) 38 | 39 | X = np.array([ 40 | [3, 1, 2, 0], 41 | [3, 2, 4, 5] 42 | ], dtype=np.int32) 43 | 44 | expected = np.array([[6, 6, 6], 45 | [14, 14, 14]]) 46 | 47 | ret = self.model.scan_fn(X, 1.0) 48 | np.testing.assert_almost_equal(ret, expected) 49 | 50 | 51 | if __name__ == '__main__': 52 | unittest.main() 53 | -------------------------------------------------------------------------------- /python/spinn/tests/test_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from nose.tools import assert_equal 4 | 5 | from spinn import util 6 | 7 | 8 | TEST_EMBEDDING_MATRIX = "spinn/tests/test_embedding_matrix.5d.txt" 9 | 10 | 11 | def test_build_vocabulary_for_ascii_embedding_file(): 12 | types_in_data = ["and", "the", "_", "strange_and_exotic_word"] 13 | core_vocabulary = {"*PADDING*":0} 14 | vocabulary = util.BuildVocabularyForASCIIEmbeddingFile(TEST_EMBEDDING_MATRIX, types_in_data, core_vocabulary) 15 | 16 | expected = { 17 | "*PADDING*" : 0, 18 | "the" : 1, 19 | "and" : 2, 20 | "_" : 3, 21 | } 22 | 23 | assert_equal(vocabulary, expected) 24 | 25 | 26 | def test_load_embeddings_from_ascii(): 27 | vocabulary = {"strange_and_exotic_word" : 0, "the" : 1, "." : 2} 28 | loaded_matrix = util.LoadEmbeddingsFromASCII(vocabulary, 5, TEST_EMBEDDING_MATRIX) 29 | expected = np.asarray( 30 | [[0, 0, 0, 0, 0], 31 | [0.418, 0.24968, -0.41242, 0.1217, 0.34527], 32 | [0.15164, 0.30177, -0.16763, 0.17684, 0.31719]], dtype=np.float32) 33 | 34 | np.testing.assert_array_equal(loaded_matrix, expected) 35 | 36 | def test_crop_and_pad_example(): 37 | def _run_asserts(seq, tgt_length, expected): 38 | example = {"seq": seq} 39 | left_padding = tgt_length - len(seq) 40 | util.CropAndPadExample(example, left_padding, tgt_length, "seq") 41 | assert_equal(example["seq"], expected) 42 | 43 | seqs = [ 44 | ([1, 1, 1], 4, [0, 1, 1, 1]), 45 | ([1, 2, 3], 2, [2, 3]) 46 | ] 47 | 48 | for seq, tgt_length, expected in seqs: 49 | yield _run_asserts, seq, tgt_length, expected 50 | 51 | 52 | def test_crop_and_pad(): 53 | dataset = [ 54 | { 55 | # Transitions too long -- will need to crop both 56 | "tokens": [1, 2, 4, 3, 6, 2], 57 | "transitions": [0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1] 58 | }, 59 | { 60 | # Transitions too short -- will need to pad transition seq and 61 | # adjust tokens with dummy elements accordingly 62 | "tokens": [6, 1], 63 | "transitions": [0, 0, 1] 64 | }, 65 | { 66 | # Transitions too long; lots of pushes 67 | "tokens": [6, 1, 2, 3, 5, 1], 68 | "transitions": [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1] 69 | } 70 | ] 71 | 72 | length = 5 73 | 74 | # Expectations: 75 | # - When the transition sequence is too short, it will be left padded with 76 | # zeros, and the corresponding token sequence will be left padded with the 77 | # same number of zeros. 
78 | # - When the transition sequence is too long, it will be cropped from the left, 79 | # since this ensures that there will be a unique stack top element at the 80 | # final step. If this is not the case, than most of the model's effort 81 | # will go into building embeddings that stay higher in the stack, and is thus 82 | # wasted. The corresponding token sequence will be cropped by removing 83 | # as many elements on the left side as there were zeros removed from the 84 | # transition sequence. 85 | # - num_transitions reports the number of transitions in the original sequence. 86 | expected = [ 87 | { 88 | "tokens": [6, 2, 0, 0, 0], 89 | "transitions": [1, 0, 1, 0, 1], 90 | "num_transitions": 11 91 | }, 92 | { 93 | "tokens": [0, 0, 6, 1, 0], 94 | "transitions": [0, 0, 0, 0, 1], 95 | "num_transitions": 3 96 | }, 97 | { 98 | "tokens": [0, 0, 0, 0, 0], 99 | "transitions": [1, 1, 1, 1, 1], 100 | "num_transitions": 11 101 | } 102 | ] 103 | 104 | dataset = util.CropAndPad(dataset, length) 105 | assert_equal(dataset, expected) 106 | 107 | if __name__ == '__main__': 108 | test_build_vocabulary_for_ascii_embedding_file() 109 | test_load_embeddings_from_ascii() 110 | test_crop_and_pad() 111 | test_crop_and_pad_example() 112 | -------------------------------------------------------------------------------- /python/spinn/util/__init__.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | from theano.sandbox.cuda import cuda_available 4 | 5 | # Only import custom CUDA ops if we on a CUDA-enabled host. 6 | if cuda_available: 7 | from spinn.util.cuda import * 8 | 9 | from spinn.util.theano_internal import * 10 | from spinn.util.data import * 11 | from spinn.util.blocks import * 12 | from spinn.util.variable_store import VariableStore 13 | 14 | 15 | ModelSpec_ = namedtuple("ModelSpec", ["model_dim", "word_embedding_dim", 16 | "batch_size", "vocab_size", "seq_length", 17 | "model_visible_dim"]) 18 | def ModelSpec(*args, **kwargs): 19 | args = dict(zip(ModelSpec_._fields, args)) 20 | args.update(kwargs) 21 | 22 | # Defaults 23 | if "model_visible_dim" not in args: 24 | args["model_visible_dim"] = args["model_dim"] 25 | 26 | return ModelSpec_(**args) 27 | 28 | 29 | -------------------------------------------------------------------------------- /python/spinn/util/data.py: -------------------------------------------------------------------------------- 1 | """Dataset handling and related yuck.""" 2 | 3 | import random 4 | import itertools 5 | 6 | import numpy as np 7 | import theano 8 | 9 | 10 | # With loaded embedding matrix, the padding vector will be initialized to zero 11 | # and will not be trained. Hopefully this isn't a problem. It seems better than 12 | # random initialization... 13 | PADDING_TOKEN = "*PADDING*" 14 | 15 | # Temporary hack: Map UNK to "_" when loading pretrained embedding matrices: 16 | # it's a common token that is pretrained, but shouldn't look like any content words. 
17 | UNK_TOKEN = "_" 18 | 19 | CORE_VOCABULARY = {PADDING_TOKEN: 0, 20 | UNK_TOKEN: 1} 21 | 22 | # Allowed number of transition types : currently PUSH : 0 and MERGE : 1 23 | NUM_TRANSITION_TYPES = 2 24 | 25 | 26 | def TrimDataset(dataset, seq_length, eval_mode=False, sentence_pair_data=False): 27 | """Avoid using excessively long training examples.""" 28 | if eval_mode: 29 | return dataset 30 | else: 31 | if sentence_pair_data: 32 | new_dataset = [example for example in dataset if 33 | len(example["premise_transitions"]) <= seq_length and 34 | len(example["hypothesis_transitions"]) <= seq_length] 35 | else: 36 | new_dataset = [example for example in dataset if len( 37 | example["transitions"]) <= seq_length] 38 | return new_dataset 39 | 40 | 41 | def TokensToIDs(vocabulary, dataset, sentence_pair_data=False): 42 | """Replace strings in original boolean dataset with token IDs.""" 43 | if sentence_pair_data: 44 | keys = ["premise_tokens", "hypothesis_tokens"] 45 | else: 46 | keys = ["tokens"] 47 | 48 | for key in keys: 49 | if UNK_TOKEN in vocabulary: 50 | unk_id = vocabulary[UNK_TOKEN] 51 | for example in dataset: 52 | example[key] = [vocabulary.get(token, unk_id) 53 | for token in example[key]] 54 | else: 55 | for example in dataset: 56 | example[key] = [vocabulary[token] 57 | for token in example[key]] 58 | return dataset 59 | 60 | 61 | def CropAndPadExample(example, left_padding, target_length, key, logger=None): 62 | """ 63 | Crop/pad a sequence value of the given dict `example`. 64 | """ 65 | if left_padding < 0: 66 | # Crop, then pad normally. 67 | # TODO: Track how many sentences are cropped, but don't log a message 68 | # for every single one. 69 | example[key] = example[key][-left_padding:] 70 | left_padding = 0 71 | right_padding = target_length - (left_padding + len(example[key])) 72 | example[key] = ([0] * left_padding) + \ 73 | example[key] + ([0] * right_padding) 74 | 75 | 76 | def CropAndPad(dataset, length, logger=None, sentence_pair_data=False): 77 | # NOTE: This can probably be done faster in NumPy if it winds up making a 78 | # difference. 79 | # Always make sure that the transitions are aligned at the left edge, so 80 | # the final stack top is the root of the tree. If cropping is used, it should 81 | # just introduce empty nodes into the tree. 82 | if sentence_pair_data: 83 | keys = [("premise_transitions", "num_premise_transitions", "premise_tokens"), 84 | ("hypothesis_transitions", "num_hypothesis_transitions", "hypothesis_tokens")] 85 | else: 86 | keys = [("transitions", "num_transitions", "tokens")] 87 | 88 | for example in dataset: 89 | for (transitions_key, num_transitions_key, tokens_key) in keys: 90 | example[num_transitions_key] = len(example[transitions_key]) 91 | transitions_left_padding = length - example[num_transitions_key] 92 | shifts_before_crop_and_pad = example[transitions_key].count(0) 93 | CropAndPadExample( 94 | example, transitions_left_padding, length, transitions_key, logger=logger) 95 | shifts_after_crop_and_pad = example[transitions_key].count(0) 96 | tokens_left_padding = shifts_after_crop_and_pad - \ 97 | shifts_before_crop_and_pad 98 | CropAndPadExample( 99 | example, tokens_left_padding, length, tokens_key, logger=logger) 100 | return dataset 101 | 102 | def CropAndPadForRNN(dataset, length, logger=None, sentence_pair_data=False): 103 | # NOTE: This can probably be done faster in NumPy if it winds up making a 104 | # difference. 
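    # Unlike CropAndPad above, only the token sequences are cropped/padded
    # here: RNN inputs carry no transition sequences that must stay aligned.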
105 | if sentence_pair_data: 106 | keys = ["premise_tokens", 107 | "hypothesis_tokens"] 108 | else: 109 | keys = ["tokens"] 110 | 111 | for example in dataset: 112 | for tokens_key in keys: 113 | num_tokens = len(example[tokens_key]) 114 | tokens_left_padding = length - num_tokens 115 | CropAndPadExample( 116 | example, tokens_left_padding, length, tokens_key, logger=logger) 117 | return dataset 118 | 119 | 120 | def MakeTrainingIterator(sources, batch_size): 121 | # Make an iterator that exposes a dataset as random minibatches. 122 | 123 | def data_iter(): 124 | dataset_size = len(sources[0]) 125 | start = -1 * batch_size 126 | order = range(dataset_size) 127 | random.shuffle(order) 128 | 129 | while True: 130 | start += batch_size 131 | if start > dataset_size - batch_size: 132 | # Start another epoch. 133 | start = 0 134 | random.shuffle(order) 135 | batch_indices = order[start:start + batch_size] 136 | yield tuple(source[batch_indices] for source in sources) 137 | return data_iter() 138 | 139 | 140 | def MakeEvalIterator(sources, batch_size): 141 | # Make a list of minibatches from a dataset to use as an iterator. 142 | # TODO(SB): Pad out the last few examples in the eval set if they don't 143 | # form a batch. 144 | 145 | print "WARNING: May be discarding eval examples." 146 | 147 | dataset_size = len(sources[0]) 148 | data_iter = [] 149 | start = -batch_size 150 | while True: 151 | start += batch_size 152 | 153 | if start >= dataset_size: 154 | break 155 | 156 | candidate_batch = tuple(source[start:start + batch_size] 157 | for source in sources) 158 | 159 | if len(candidate_batch[0]) == batch_size: 160 | data_iter.append(candidate_batch) 161 | else: 162 | print "Skipping " + str(len(candidate_batch[0])) + " examples." 163 | return data_iter 164 | 165 | 166 | def PreprocessDataset(dataset, vocabulary, seq_length, data_manager, eval_mode=False, logger=None, 167 | sentence_pair_data=False, for_rnn=False): 168 | # TODO(SB): Simpler version for plain RNN. 169 | dataset = TrimDataset(dataset, seq_length, eval_mode=eval_mode, sentence_pair_data=sentence_pair_data) 170 | dataset = TokensToIDs(vocabulary, dataset, sentence_pair_data=sentence_pair_data) 171 | if for_rnn: 172 | dataset = CropAndPadForRNN(dataset, seq_length, logger=logger, sentence_pair_data=sentence_pair_data) 173 | else: 174 | dataset = CropAndPad(dataset, seq_length, logger=logger, sentence_pair_data=sentence_pair_data) 175 | 176 | if sentence_pair_data: 177 | X = np.transpose(np.array([[example["premise_tokens"] for example in dataset], 178 | [example["hypothesis_tokens"] for example in dataset]], 179 | dtype=np.int32), (1, 2, 0)) 180 | if for_rnn: 181 | # TODO(SB): Extend this clause to the non-pair case. 
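            # The RNN consumes tokens only, so emit empty dummy transition
            # arrays with the expected leading dimensions.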
182 | transitions = np.zeros((len(dataset), 2, 0)) 183 | num_transitions = np.zeros((len(dataset), 2)) 184 | else: 185 | transitions = np.transpose(np.array([[example["premise_transitions"] for example in dataset], 186 | [example["hypothesis_transitions"] for example in dataset]], 187 | dtype=np.int32), (1, 2, 0)) 188 | num_transitions = np.transpose(np.array( 189 | [[example["num_premise_transitions"] for example in dataset], 190 | [example["num_hypothesis_transitions"] for example in dataset]], 191 | dtype=np.int32), (1, 0)) 192 | else: 193 | X = np.array([example["tokens"] for example in dataset], 194 | dtype=np.int32) 195 | transitions = np.array([example["transitions"] for example in dataset], 196 | dtype=np.int32) 197 | num_transitions = np.array( 198 | [example["num_transitions"] for example in dataset], 199 | dtype=np.int32) 200 | y = np.array( 201 | [data_manager.LABEL_MAP[example["label"]] for example in dataset], 202 | dtype=np.int32) 203 | 204 | return X, transitions, y, num_transitions 205 | 206 | 207 | def BuildVocabulary(raw_training_data, raw_eval_sets, embedding_path, logger=None, sentence_pair_data=False): 208 | # Find the set of words that occur in the data. 209 | logger.Log("Constructing vocabulary...") 210 | types_in_data = set() 211 | for dataset in [raw_training_data] + [eval_dataset[1] for eval_dataset in raw_eval_sets]: 212 | if sentence_pair_data: 213 | types_in_data.update(itertools.chain.from_iterable([example["premise_tokens"] 214 | for example in dataset])) 215 | types_in_data.update(itertools.chain.from_iterable([example["hypothesis_tokens"] 216 | for example in dataset])) 217 | else: 218 | types_in_data.update(itertools.chain.from_iterable([example["tokens"] 219 | for example in dataset])) 220 | logger.Log("Found " + str(len(types_in_data)) + " word types.") 221 | 222 | if embedding_path == None: 223 | logger.Log( 224 | "Warning: Open-vocabulary models require pretrained vectors. Running with empty vocabulary.") 225 | vocabulary = CORE_VOCABULARY 226 | else: 227 | # Build a vocabulary of words in the data for which we have an 228 | # embedding. 229 | vocabulary = BuildVocabularyForASCIIEmbeddingFile( 230 | embedding_path, types_in_data, CORE_VOCABULARY) 231 | 232 | return vocabulary 233 | 234 | 235 | def BuildVocabularyForASCIIEmbeddingFile(path, types_in_data, core_vocabulary): 236 | """Quickly iterates through a GloVe-formatted ASCII vector file to 237 | extract a working vocabulary of words that occur both in the data and 238 | in the vector file.""" 239 | 240 | # TODO(SB): Report on *which* words are skipped. See if any are common. 241 | 242 | vocabulary = {} 243 | vocabulary.update(core_vocabulary) 244 | next_index = len(vocabulary) 245 | with open(path, 'r') as f: 246 | for line in f: 247 | spl = line.split(" ", 1) 248 | word = spl[0] 249 | if word in types_in_data: 250 | vocabulary[word] = next_index 251 | next_index += 1 252 | return vocabulary 253 | 254 | 255 | def LoadEmbeddingsFromASCII(vocabulary, embedding_dim, path): 256 | """Prepopulates a numpy embedding matrix indexed by vocabulary with 257 | values from a GloVe - format ASCII vector file. 
258 | 259 | For now, values not found in the file will be set to zero.""" 260 | emb = np.zeros( 261 | (len(vocabulary), embedding_dim), dtype=theano.config.floatX) 262 | with open(path, 'r') as f: 263 | for line in f: 264 | spl = line.split(" ") 265 | word = spl[0] 266 | if word in vocabulary: 267 | emb[vocabulary[word], :] = [float(e) for e in spl[1:]] 268 | return emb 269 | 270 | 271 | def TransitionsToParse(transitions, words): 272 | if transitions is not None: 273 | stack = ["(P *ZEROS*)"] * (len(transitions) + 1) 274 | buffer_ptr = 0 275 | for transition in transitions: 276 | if transition == 0: 277 | stack.append("(P " + words[buffer_ptr] +")") 278 | buffer_ptr += 1 279 | elif transition == 1: 280 | r = stack.pop() 281 | l = stack.pop() 282 | stack.append("(M " + l + " " + r + ")") 283 | return stack.pop() 284 | else: 285 | return " ".join(words) 286 | -------------------------------------------------------------------------------- /python/spinn/util/theano_internal.py: -------------------------------------------------------------------------------- 1 | """Low-level Theano utilities.""" 2 | 3 | from collections import OrderedDict 4 | from functools import wraps 5 | 6 | import theano 7 | from theano import ifelse 8 | from theano import tensor as T 9 | from theano.compile.sharedvalue import SharedVariable 10 | from theano.sandbox.cuda import cuda_available 11 | 12 | if cuda_available: 13 | from theano.sandbox.cuda import HostFromGpu 14 | from spinn.util import cuda 15 | 16 | 17 | def tensorx(name, ndim, dtype=theano.config.floatX): 18 | return T.TensorType(dtype, (False,) * ndim)(name) 19 | 20 | 21 | def zeros_nobroadcast(shape, dtype=theano.config.floatX): 22 | zeros = T.zeros(shape, dtype=dtype) 23 | zeros = T.unbroadcast(zeros, *range(len(shape))) 24 | return zeros 25 | 26 | 27 | def merge_update_lists(xs, ys): 28 | """ 29 | Merge two update lists: 30 | 31 | - adding where `xs[i] is not None and ys[i] is not None` 32 | - copying `xs[i]` if `xs[i] is not None` 33 | - copying `ys[i]` otherwise 34 | """ 35 | 36 | assert len(xs) == len(ys), "%i %i" % (len(xs), len(ys)) 37 | ret = [] 38 | 39 | for x, y in zip(xs, ys): 40 | if y is None: 41 | ret.append(x) 42 | elif x is None: 43 | ret.append(y) 44 | else: 45 | # Merge. 46 | ret.append(x + y) 47 | 48 | return ret 49 | 50 | 51 | def merge_updates(*updates_dicts): 52 | all_updates = OrderedDict() 53 | for updates_dict in updates_dicts: 54 | for k, v in updates_dict.iteritems(): 55 | if k in all_updates: 56 | all_updates[k] += v 57 | else: 58 | all_updates[k] = v 59 | 60 | return all_updates 61 | 62 | 63 | def batch_subgraph_gradients(g_in, wrt, f_g_out, 64 | wrt_jacobian=True, 65 | name="batch_subgraph_grad"): 66 | """ 67 | Build gradients w.r.t. some cost on a subgraph of a larger graph. 68 | 69 | Let G be a feedforward subgraph for which we want to compute gradients. 70 | G has N_I inputs and N_O outputs. 71 | 72 | This function will compute batch gradients on the subgraph G. 73 | 74 | It optionally supports computing Jacobians (batch-element-wise cost 75 | gradients) as well, though this is experimental and relies on some naughty 76 | Theano hacks. 77 | 78 | Args: 79 | g_in: List of N_I inputs to G. Each element may be either a 80 | symbolic Theano input variable or an integer (signifying the number 81 | of dimensions of the input). 82 | wrt: Any variables inside G for which we should also collect gradients. 83 | f_g_out: Function which accepts N_I Theano vars and returns N_O Theano 84 | vars. 
85 | 86 | Returns: 87 | A function which accepts two arguments, `b_in` and `b_grad`. 88 | 89 | `b_in` must be a list of N_I Theano batch variables representing inputs 90 | to the subgraph G. (Each element of `b_in` has a leading batch axis and 91 | is thus one dimension larger than its corresponding element of `g_in`). 92 | 93 | `b_grad` must be a list of N_O Theano batch variables representing 94 | cost gradients w.r.t. each of the graph outputs. Again, each element of 95 | the list has a leading batch axis and is thus one dimension larger than 96 | its corresponding output from `f_g_out`. 97 | 98 | The function returns `(d_in, d_wrt)`, where 99 | 100 | - `d_in` is a list of batch cost gradients with respect to each of the 101 | corresponding elements of `g_in`. Each element of `d_in` has a 102 | leading batch axis, and is thus one dimension larger than its 103 | corresponding `g_in` element. 104 | - `d_wrt` is a list of batch cost gradients with respect to each of the 105 | corresponding elements of `wrt`. Each element of `d_wrt` has a 106 | leading batch axis, and is thus one dimension larger than its 107 | corresponding `wrt` element. 108 | """ 109 | 110 | wrt = tuple(wrt) 111 | 112 | def deltas(b_inps, b_grads): 113 | b_inps = tuple(b_inps) 114 | assert len(g_in) == len(b_inps), "%i %i" % (len(g_in), len(b_inps)) 115 | 116 | # Build feedforward graph. 117 | b_out = f_g_out(*b_inps) 118 | # Make sure it's a list of outputs. 119 | b_out = [b_out] if not isinstance(b_out, (list, tuple)) else b_out 120 | 121 | def dot_grad_override(op, inp, grads): 122 | x, y = inp 123 | xdim, ydim = x.type.ndim, y.type.ndim 124 | 125 | # HACK: Get super grads 126 | gz, = grads 127 | xgrad, ygrad = op.grad(inp, grads) 128 | 129 | if xdim == ydim == 2: 130 | # HACK: Compute the Jacobian of this `dot` op. We will yield a 131 | # rank-3 tensor rather than a gradient matrix. 132 | ygrad = T.batched_dot(x.dimshuffle(0, 1, "x"), 133 | gz.dimshuffle(0, "x", 1)) 134 | 135 | # TODO patternbroadcast? 136 | 137 | return xgrad, ygrad 138 | 139 | # Overrides which turn our "grad" call into a "jacobian" call! 140 | overrides = None 141 | if wrt_jacobian: 142 | overrides = {T.Dot: dot_grad_override} 143 | 144 | # Compute gradients of subgraph beginning at `g_in` and ending at `g_out`, 145 | # where the cost gradient w.r.t. each `g_out` is given by the corresponding 146 | # entry in `grads_above`. 147 | known_grads = dict(zip(b_out, b_grads)) 148 | d_all = T.grad(cost=None, wrt=b_inps + wrt, 149 | known_grads=known_grads, 150 | consider_constant=b_inps, 151 | disconnected_inputs="ignore", 152 | return_disconnected="None", 153 | use_overrides=set(wrt), 154 | grad_overrides=overrides) 155 | d_in, d_wrt = d_all[:len(b_inps)], d_all[len(b_inps):] 156 | 157 | # Strip any GPU<->host transfers that might have crept into this 158 | # automatically constructed graph. 159 | d_wrt = map(cuda.strip_transfer, d_wrt) 160 | d_in = map(cuda.strip_transfer, d_in) 161 | if d_wrt: 162 | for i in range(len(d_wrt)): 163 | if d_wrt[i] is None: 164 | continue 165 | # HACK: Strip off DimShuffle(Elemwise(DimShuffle(Sum))). This is what 166 | # comes out for bias gradients.. don't ask me why.
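            # Peel the wrappers off one node at a time; taking the Sum's input
            # recovers the per-example gradient before the batch summation.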
167 | if isinstance(d_wrt[i].owner.op, T.DimShuffle): 168 | base = d_wrt[i].owner 169 | if isinstance(base.inputs[0].owner.op, T.Elemwise): 170 | base = base.inputs[0].owner 171 | if isinstance(base.inputs[0].owner.op, T.DimShuffle): 172 | base = base.inputs[0].owner 173 | if isinstance(base.inputs[0].owner.op, T.Sum): 174 | base = base.inputs[0].owner 175 | d_wrt[i] = base.inputs[0] 176 | 177 | return d_in, d_wrt 178 | 179 | return deltas 180 | 181 | 182 | def ensure_2d_arguments(f, squeeze_ret=True): 183 | """Decorator which ensures all of its function's arguments are 2D.""" 184 | @wraps(f) 185 | def wrapped(*args, **kwargs): 186 | new_args = [] 187 | for arg in args: 188 | if isinstance(arg, T.TensorVariable): 189 | if arg.ndim == 1: 190 | arg = arg.dimshuffle("x", 0) 191 | elif arg.ndim > 2: 192 | raise RuntimeError("ensure_2d_arguments wrapped a function" 193 | " which received an %i-d argument. " 194 | "Don't know what to do." % arg.ndim) 195 | new_args.append(arg) 196 | 197 | ret = f(*new_args, **kwargs) 198 | if squeeze_ret: 199 | if isinstance(ret, (list, tuple)): 200 | ret = [ret_i.squeeze() for ret_i in ret] 201 | elif isinstance(ret, T.TensorVariable): 202 | ret = ret.squeeze() 203 | return ret 204 | return wrapped 205 | 206 | 207 | def prepare_updates_dict(updates): 208 | """ 209 | Prepare a Theano `updates` dictionary. 210 | 211 | Ensure that both keys and values are valid entries. 212 | NB, this function is heavily coupled with its clients, and not intended for 213 | general use. 214 | """ 215 | 216 | def prepare_key(key, val): 217 | if not isinstance(key, SharedVariable): 218 | if isinstance(key.owner.inputs[0], SharedVariable): 219 | # Extract shared from Update(shared) 220 | return key.owner.inputs[0] 221 | elif key.owner.inputs[0].owner.op.__class__ is HostFromGpu: 222 | if isinstance(key.owner.inputs[0].owner.inputs[0], SharedVariable): 223 | # Extract shared from Update(HostFromGpu(shared)) 224 | return key.owner.inputs[0].owner.inputs[0] 225 | elif key.owner.op.__class__ is ifelse.IfElse: 226 | # Assume that 'true' condition of ifelse involves the intended 227 | # shared variable.
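                # Recurse, since the surviving branch may itself be wrapped in
                # Update(...) or HostFromGpu(...).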
228 | return prepare_key(key.owner.inputs[1], val) 229 | 230 | raise ValueError("Invalid updates dict key/value: %s / %s" 231 | % (key, val)) 232 | return key 233 | 234 | return {prepare_key(key, val): val for key, val in updates.iteritems()} 235 | -------------------------------------------------------------------------------- /python/spinn/util/variable_store.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import cPickle 4 | import theano 5 | 6 | from spinn.util.blocks import HeKaimingInitializer 7 | 8 | 9 | class VariableStore(object): 10 | 11 | def __init__(self, prefix="vs", default_initializer=HeKaimingInitializer(), logger=None): 12 | self.prefix = prefix 13 | self.default_initializer = default_initializer 14 | self.vars = OrderedDict() # Order is used in saving and loading 15 | self.savable_vars = OrderedDict() 16 | self.trainable_vars = OrderedDict() 17 | self.logger = logger 18 | self.nongradient_updates = OrderedDict() 19 | 20 | def add_param(self, name, shape, initializer=None, savable=True, trainable=True): 21 | if not initializer: 22 | initializer = self.default_initializer 23 | 24 | if name not in self.vars: 25 | full_name = "%s/%s" % (self.prefix, name) 26 | if self.logger: 27 | self.logger.Log( 28 | "Created variable " + full_name + " shape: " + str(shape), level=self.logger.DEBUG) 29 | init_value = initializer(shape).astype(theano.config.floatX) 30 | self.vars[name] = theano.shared(init_value, 31 | name=full_name) 32 | if savable: 33 | self.savable_vars[name] = self.vars[name] 34 | if trainable: 35 | self.trainable_vars[name] = self.vars[name] 36 | 37 | return self.vars[name] 38 | 39 | def save_checkpoint(self, filename="vs_ckpt", keys=None, extra_vars=[]): 40 | if not keys: 41 | keys = self.savable_vars 42 | save_file = open(filename, 'w') # this will overwrite current contents 43 | for key in keys: 44 | cPickle.dump(self.vars[key].get_value(borrow=True), save_file, -1) # the -1 is for HIGHEST_PROTOCOL 45 | for var in extra_vars: 46 | cPickle.dump(var, save_file, -1) 47 | save_file.close() 48 | 49 | def load_checkpoint(self, filename="vs_ckpt", keys=None, num_extra_vars=0, skip_saved_unsavables=False): 50 | if skip_saved_unsavables: 51 | keys = self.vars 52 | else: 53 | if not keys: 54 | keys = self.savable_vars 55 | save_file = open(filename) 56 | for key in keys: 57 | if skip_saved_unsavables and key not in self.savable_vars: 58 | if self.logger: 59 | full_name = "%s/%s" % (self.prefix, key) 60 | self.logger.Log( 61 | "Not restoring variable " + full_name, level=self.logger.DEBUG) 62 | _ = cPickle.load(save_file) # Discard 63 | else: 64 | if self.logger: 65 | full_name = "%s/%s" % (self.prefix, key) 66 | self.logger.Log( 67 | "Restoring variable " + full_name, level=self.logger.DEBUG) 68 | self.vars[key].set_value(cPickle.load(save_file), borrow=True) 69 | 70 | extra_vars = [] 71 | for _ in range(num_extra_vars): 72 | extra_vars.append(cPickle.load(save_file)) 73 | return extra_vars 74 | 75 | def add_nongradient_update(self, variable, new_value): 76 | # Track an update that should be applied during training but that isn't a gradient. 77 | # self.nongradient_updates should be fed as an update to theano.function().
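        # (Hypothetical example: an exponential-moving-average statistic that
        # is refreshed at every step but receives no gradient.)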
78 | self.nongradient_updates[variable] = new_value 79 | 80 | -------------------------------------------------------------------------------- /scripts/12_5_sst_2s.sh: -------------------------------------------------------------------------------- 1 | # NAME: 12_5_sweep_sst_Model2S 2 | # NUM RUNS: 6 3 | # SWEEP PARAMETERS: {'semantic_classifier_keep_rate': ('LIN', 0.4, 0.75), 'embedding_keep_rate': ('LIN', 0.4, 1.0), 'learning_rate': ('EXP', 0.0001, 0.0003), 'double_identity_init_range': ('EXP', 0.0005, 0.005), 'init_range': ('EXP', 0.001, 0.004), 'scheduled_sampling_exponent_base': ('SS_BASE', 2e-06, 0.0002), 'l2_lambda': ('EXP', 2e-07, 2e-05)} 4 | # FIXED_PARAMETERS: {'eval_seq_length': '150', 'clipping_max_value': '5.0', 'data_type': 'sst', 'training_data_path': 'sst-data/train_expanded.txt', 'batch_size': '32', 'embedding_data_path': '/scr/nlp/data/glove_vecs/glove.840B.300d.txt', 'model_dim': '50', 'seq_length': '100', 'ckpt_root': '/afs/cs.stanford.edu/u/sbowman/scr/', 'word_embedding_dim': '300', 'model_type': 'Model2S', 'lstm_composition': '', 'eval_data_path': 'sst-data/dev.txt:sst-data/train_sample.txt'} 5 | 6 | export SPINN_FLAGS=" --semantic_classifier_keep_rate 0.607345803836 --eval_seq_length 150 --clipping_max_value 5.0 --data_type sst --model_dim 50 --learning_rate 0.000229287344209 --training_data_path sst-data/train_expanded.txt --word_embedding_dim 300 --batch_size 32 --double_identity_init_range 0.00479811207292 --ckpt_root /afs/cs.stanford.edu/u/sbowman/scr/ --init_range 0.00173528232303 --embedding_keep_rate 0.626722888293 --seq_length 100 --embedding_data_path /scr/nlp/data/glove_vecs/glove.840B.300d.txt --model_type Model2S --scheduled_sampling_exponent_base 0.999840115881 --l2_lambda 3.27967310602e-06 --lstm_composition --eval_data_path sst-data/dev.txt:sst-data/train_sample.txt --experiment_name 12_5_sweep_sst_Model2S_0-semantic_classifier_keep_rate0.61-learning_rate0.00023-double_identity_init_range0.0048-init_range0.0017-embedding_keep_rate0.63-scheduled_sampling_exponent_base1-l2_lambda3.3e-06"; export DEVICE=gpu5; qsub -v SPINN_FLAGS,DEVICE train_spinn_classifier.sh -q jag -l host=jagupardX 7 | 8 | export SPINN_FLAGS=" --semantic_classifier_keep_rate 0.655673561552 --eval_seq_length 150 --clipping_max_value 5.0 --data_type sst --model_dim 50 --learning_rate 0.000149040525731 --training_data_path sst-data/train_expanded.txt --word_embedding_dim 300 --batch_size 32 --double_identity_init_range 0.00184318062176 --ckpt_root /afs/cs.stanford.edu/u/sbowman/scr/ --init_range 0.00274993243904 --embedding_keep_rate 0.780228802804 --seq_length 100 --embedding_data_path /scr/nlp/data/glove_vecs/glove.840B.300d.txt --model_type Model2S --scheduled_sampling_exponent_base 0.999859308571 --l2_lambda 2.58814037254e-06 --lstm_composition --eval_data_path sst-data/dev.txt:sst-data/train_sample.txt --experiment_name 12_5_sweep_sst_Model2S_1-semantic_classifier_keep_rate0.66-learning_rate0.00015-double_identity_init_range0.0018-init_range0.0027-embedding_keep_rate0.78-scheduled_sampling_exponent_base1-l2_lambda2.6e-06"; export DEVICE=gpu5; qsub -v SPINN_FLAGS,DEVICE train_spinn_classifier.sh -q jag -l host=jagupardX 9 | 10 | export SPINN_FLAGS=" --semantic_classifier_keep_rate 0.603547877227 --eval_seq_length 150 --clipping_max_value 5.0 --data_type sst --model_dim 50 --learning_rate 0.000156624519335 --training_data_path sst-data/train_expanded.txt --word_embedding_dim 300 --batch_size 32 --double_identity_init_range 0.00369239449626 --ckpt_root 
/afs/cs.stanford.edu/u/sbowman/scr/ --init_range 0.00270690567322 --embedding_keep_rate 0.409169167178 --seq_length 100 --embedding_data_path /scr/nlp/data/glove_vecs/glove.840B.300d.txt --model_type Model2S --scheduled_sampling_exponent_base 0.999988121904 --l2_lambda 5.07823011003e-06 --lstm_composition --eval_data_path sst-data/dev.txt:sst-data/train_sample.txt --experiment_name 12_5_sweep_sst_Model2S_2-semantic_classifier_keep_rate0.6-learning_rate0.00016-double_identity_init_range0.0037-init_range0.0027-embedding_keep_rate0.41-scheduled_sampling_exponent_base1-l2_lambda5.1e-06"; export DEVICE=gpu7; qsub -v SPINN_FLAGS,DEVICE train_spinn_classifier.sh -q jag -l host=jagupardX 11 | 12 | export SPINN_FLAGS=" --semantic_classifier_keep_rate 0.427596967339 --eval_seq_length 150 --clipping_max_value 5.0 --data_type sst --model_dim 50 --learning_rate 0.000114421988307 --training_data_path sst-data/train_expanded.txt --word_embedding_dim 300 --batch_size 32 --double_identity_init_range 0.00120151554549 --ckpt_root /afs/cs.stanford.edu/u/sbowman/scr/ --init_range 0.00134755000609 --embedding_keep_rate 0.925482590687 --seq_length 100 --embedding_data_path /scr/nlp/data/glove_vecs/glove.840B.300d.txt --model_type Model2S --scheduled_sampling_exponent_base 0.999926919173 --l2_lambda 4.12749323597e-06 --lstm_composition --eval_data_path sst-data/dev.txt:sst-data/train_sample.txt --experiment_name 12_5_sweep_sst_Model2S_3-semantic_classifier_keep_rate0.43-learning_rate0.00011-double_identity_init_range0.0012-init_range0.0013-embedding_keep_rate0.93-scheduled_sampling_exponent_base1-l2_lambda4.1e-06"; export DEVICE=gpu7; qsub -v SPINN_FLAGS,DEVICE train_spinn_classifier.sh -q jag -l host=jagupardX 13 | 14 | 15 | -------------------------------------------------------------------------------- /scripts/analyze_log.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility for plotting the various costs and accuracies vs training iteration no. Reads these values from 3 | a log file. Can also be used to compare multiple logs by supplying multiple paths. 
4 | """ 5 | 6 | import gflags 7 | import matplotlib.pyplot as plt 8 | import sys 9 | 10 | FLAGS = gflags.FLAGS 11 | 12 | class TrainLine(object): 13 | def __init__(self, line): 14 | tokens = line.split() 15 | self.step = int(tokens[4]) 16 | self.pred_acc = float(tokens[6]) 17 | self.parse_acc = float(tokens[7]) 18 | self.total_cost = float(tokens[9]) 19 | self.xent_cost = float(tokens[10]) 20 | self.action_cost = float(tokens[11]) 21 | self.l2_cost = float(tokens[12]) 22 | 23 | class EvalLine(object): 24 | def __init__(self, line): 25 | tokens = line.split() 26 | self.step = int(tokens[4]) 27 | self.pred_acc = float(tokens[7]) 28 | self.parse_acc = float(tokens[8]) 29 | 30 | class Log(object): 31 | def __init__(self, path): 32 | self.evals = [] 33 | with open(path) as f: 34 | lines = filter(lambda l : 'Step' in l, f.readlines()) 35 | # count number of eval sets 36 | num_logs = 1 37 | for i in xrange(len(lines)): 38 | if 'Acc' not in lines[i]: 39 | num_logs += 1 40 | self.evals.append(lines[i].strip().split()[-1]) 41 | else: 42 | break 43 | self.corpus = [[] for _ in xrange(num_logs)] 44 | # read the costs and accuracies 45 | for i, line in enumerate(lines): 46 | if 'Acc' in line: 47 | ind = 0 48 | else: 49 | ind = self.evals.index(lines[i].strip().split()[-1]) + 1 50 | # ind = i % num_logs 51 | if ind == 0: 52 | self.corpus[0].append(TrainLine(line)) 53 | elif 'Eval' in line: 54 | self.corpus[ind].append(EvalLine(line)) 55 | 56 | def ShowPlots(subplot=False): 57 | for log_ind, path in enumerate(FLAGS.path.split(":")): 58 | log = Log(path) 59 | if subplot: 60 | plt.subplot(len(FLAGS.path.split(":")), 1, log_ind + 1) 61 | for index in FLAGS.index.split(","): 62 | index = int(index) 63 | for attr in ["pred_acc", "parse_acc", "total_cost", "xent_cost", "l2_cost", "action_cost"]: 64 | if getattr(FLAGS, attr): 65 | if "cost" in attr: 66 | assert index == 0, "costs only associated with training log" 67 | steps, val = zip(*[(l.step, getattr(l, attr)) for l in log.corpus[index] if l.step < FLAGS.iters]) 68 | dct = {} 69 | for k, v in zip(steps, val): 70 | dct[k] = max(v, dct[k]) if k in dct else v 71 | steps, val = zip(*sorted(dct.iteritems())) 72 | plt.plot(steps, val, label="Log%d:%s-%d" % (log_ind, attr, index)) 73 | 74 | plt.xlabel("No. of training iteration") 75 | plt.ylabel(FLAGS.ylabel) 76 | if FLAGS.legend: 77 | plt.legend() 78 | plt.show() 79 | 80 | 81 | if __name__ == '__main__': 82 | 83 | gflags.DEFINE_string("path", None, "Path to log file") 84 | gflags.DEFINE_string("index", "1", "csv list of corpus indices. 
0 for train, 1 for eval set 1 etc.") 85 | gflags.DEFINE_boolean("pred_acc", True, "Prediction accuracy") 86 | gflags.DEFINE_boolean("parse_acc", False, "Parsing accuracy") 87 | gflags.DEFINE_boolean("total_cost", False, "Total cost, valid only if index == 0") 88 | gflags.DEFINE_boolean("xent_cost", False, "Cross entropy cost, valid only if index == 0") 89 | gflags.DEFINE_boolean("l2_cost", False, "L2 regularization cost, valid only if index == 0") 90 | gflags.DEFINE_boolean("action_cost", False, "Action cost, valid only if index == 0") 91 | gflags.DEFINE_boolean("legend", False, "Show legend in plot") 92 | gflags.DEFINE_boolean("subplot", False, "Separate plots for each log") 93 | gflags.DEFINE_string("ylabel", "", "Label for y-axis of plot") 94 | gflags.DEFINE_integer("iters", 10000, "Iters to limit plot to") 95 | 96 | FLAGS(sys.argv) 97 | assert FLAGS.path is not None, "Must provide a log path" 98 | ShowPlots(FLAGS.subplot) 99 | 100 | -------------------------------------------------------------------------------- /scripts/make_snli_sweep.py: -------------------------------------------------------------------------------- 1 | # Create a script to run a random hyperparameter search. 2 | 3 | import copy 4 | import getpass 5 | import os 6 | import random 7 | import numpy as np 8 | 9 | LIN = "LIN" 10 | EXP = "EXP" 11 | SS_BASE = "SS_BASE" 12 | 13 | # Instructions: Configure the variables in this block, then run 14 | # the following on a machine with qsub access: 15 | # python make_sweep.py > my_sweep.sh 16 | # bash my_sweep.sh 17 | 18 | # - # 19 | 20 | # Non-tunable flags that must be passed in. 21 | 22 | FIXED_PARAMETERS = { 23 | "data_type": "snli", 24 | "model_type": "Model0", 25 | "training_data_path": "/scr/nlp/data/snli_1.0/snli_1.0_train.jsonl", 26 | "eval_data_path": "/scr/nlp/data/snli_1.0/snli_1.0_dev.jsonl", 27 | "embedding_data_path": "/scr/nlp/data/glove_vecs/glove.840B.300d.txt", 28 | "word_embedding_dim": "300", 29 | "model_dim": "600", 30 | "seq_length": "50", 31 | "eval_seq_length": "50", 32 | "batch_size": "32", 33 | "ckpt_path": os.path.join("/scr/", getpass.getuser(), "/"), # Launching user's home scr dir 34 | "log_path": os.path.join("/scr/", getpass.getuser(), "/") # Launching user's home scr dir 35 | } 36 | 37 | # Tunable parameters. 38 | SWEEP_PARAMETERS = { 39 | "learning_rate": (EXP, 0.0002, 0.01), # RNN likes higher end of range, but below 009. 40 | "l2_lambda": (EXP, 8e-7, 2e-5), 41 | "semantic_classifier_keep_rate": (LIN, 0.80, 0.95), # NB: Keep rates may depend considerably on dims. 
42 | "embedding_keep_rate": (LIN, 0.8, 0.95), 43 | "scheduled_sampling_exponent_base": (SS_BASE, 1e-5, 8e-5), 44 | "transition_cost_scale": (LIN, 0.5, 4.0), 45 | "tracking_lstm_hidden_dim": (EXP, 24, 128), 46 | "num_sentence_pair_combination_layers": (LIN, 1, 3) 47 | } 48 | 49 | sweep_name = "sweep_" + \ 50 | FIXED_PARAMETERS["data_type"] + "_" + FIXED_PARAMETERS["model_type"] 51 | sweep_runs = 4 52 | queue = "jag" 53 | 54 | # - # 55 | print "# NAME: " + sweep_name 56 | print "# NUM RUNS: " + str(sweep_runs) 57 | print "# SWEEP PARAMETERS: " + str(SWEEP_PARAMETERS) 58 | print "# FIXED_PARAMETERS: " + str(FIXED_PARAMETERS) 59 | print 60 | 61 | for run_id in range(sweep_runs): 62 | params = {} 63 | params.update(FIXED_PARAMETERS) 64 | for param in SWEEP_PARAMETERS: 65 | config = SWEEP_PARAMETERS[param] 66 | t = config[0] 67 | mn = config[1] 68 | mx = config[2] 69 | 70 | r = random.uniform(0, 1) 71 | if t == EXP: 72 | lmn = np.log(mn) 73 | lmx = np.log(mx) 74 | sample = np.exp(lmn + (lmx - lmn) * r) 75 | elif t==SS_BASE: 76 | lmn = np.log(mn) 77 | lmx = np.log(mx) 78 | sample = 1 - np.exp(lmn + (lmx - lmn) * r) 79 | else: 80 | sample = mn + (mx - mn) * r 81 | 82 | if isinstance(mn, int): 83 | sample = int(round(sample, 0)) 84 | 85 | params[param] = sample 86 | 87 | name = sweep_name + "_" + str(run_id) 88 | flags = "" 89 | for param in params: 90 | value = params[param] 91 | val_str = "" 92 | flags += " --" + param + " " + str(value) 93 | if param not in FIXED_PARAMETERS: 94 | if isinstance(value, int): 95 | val_disp = str(value) 96 | else: 97 | val_disp = "%.2g" % value 98 | name += "-" + param + val_disp 99 | flags += " --experiment_name " + name 100 | print "export SPINN_FLAGS=\"" + flags + "\"; export DEVICE=gpuX; qsub -v SPINN_FLAGS,DEVICE ../scripts/train_spinn_classifier.sh -q " + queue + " -l host=jagupardX" 101 | print 102 | -------------------------------------------------------------------------------- /scripts/make_sst_sweep.py: -------------------------------------------------------------------------------- 1 | # Create a script to run a random hyperparameter search. 2 | 3 | import copy 4 | import getpass 5 | import os 6 | import random 7 | import numpy as np 8 | 9 | LIN = "LIN" 10 | EXP = "EXP" 11 | SS_BASE = "SS_BASE" 12 | 13 | # Instructions: Configure the variables in this block, then run 14 | # the following on a machine with qsub access: 15 | # python make_sweep.py > my_sweep.sh 16 | # bash my_sweep.sh 17 | 18 | # - # 19 | 20 | # Non-tunable flags that must be passed in. 21 | 22 | FIXED_PARAMETERS = { 23 | "data_type": "sst", 24 | "model_type": "Model0", 25 | "training_data_path": "sst-data/train_expanded.txt", 26 | "eval_data_path": "sst-data/dev.txt:sst-data/train_sample.txt", 27 | "embedding_data_path": "/scr/nlp/data/glove_vecs/glove.840B.300d.txt", 28 | "word_embedding_dim": "300", 29 | "model_dim": "300", 30 | "seq_length": "100", 31 | "eval_seq_length": "100", 32 | "batch_size": "32", 33 | "ckpt_path": os.path.join("/scr/", getpass.getuser(), "/"), # Launching user's home scr dir 34 | "log_path": os.path.join("/scr/", getpass.getuser(), "/") # Launching user's home scr dir 35 | } 36 | 37 | # Tunable parameters. 
38 | SWEEP_PARAMETERS = { 39 | "learning_rate": (EXP, 0.00005, 0.001), 40 | "l2_lambda": (EXP, 4e-6, 8e-5), 41 | "semantic_classifier_keep_rate": (LIN, 0.3, 0.6), 42 | "embedding_keep_rate": (LIN, 0.3, 0.6), 43 | "scheduled_sampling_exponent_base": (SS_BASE, 1e-5, 1e-4), 44 | "transition_cost_scale": (LIN, 18.0, 28.0), 45 | "tracking_lstm_hidden_dim": (EXP, 1, 32) 46 | } 47 | 48 | 49 | sweep_name = "sweep_" + \ 50 | FIXED_PARAMETERS["data_type"] + "_" + FIXED_PARAMETERS["model_type"] 51 | sweep_runs = 6 52 | queue = "jag" 53 | 54 | # - # 55 | print "# NAME: " + sweep_name 56 | print "# NUM RUNS: " + str(sweep_runs) 57 | print "# SWEEP PARAMETERS: " + str(SWEEP_PARAMETERS) 58 | print "# FIXED_PARAMETERS: " + str(FIXED_PARAMETERS) 59 | print 60 | 61 | for run_id in range(sweep_runs): 62 | params = {} 63 | params.update(FIXED_PARAMETERS) 64 | for param in SWEEP_PARAMETERS: 65 | config = SWEEP_PARAMETERS[param] 66 | t = config[0] 67 | mn = config[1] 68 | mx = config[2] 69 | 70 | r = random.uniform(0, 1) 71 | if t == EXP: 72 | lmn = np.log(mn) 73 | lmx = np.log(mx) 74 | sample = np.exp(lmn + (lmx - lmn) * r) 75 | elif t==SS_BASE: 76 | lmn = np.log(mn) 77 | lmx = np.log(mx) 78 | sample = 1 - np.exp(lmn + (lmx - lmn) * r) 79 | else: 80 | sample = mn + (mx - mn) * r 81 | 82 | if isinstance(mn, int): 83 | sample = int(round(sample, 0)) 84 | 85 | params[param] = sample 86 | 87 | name = sweep_name + "_" + str(run_id) 88 | flags = "" 89 | for param in params: 90 | value = params[param] 91 | val_str = "" 92 | flags += " --" + param + " " + str(value) 93 | if param not in FIXED_PARAMETERS: 94 | if isinstance(value, int): 95 | val_disp = str(value) 96 | else: 97 | val_disp = "%.2g" % value 98 | name += "-" + param + val_disp 99 | flags += " --experiment_name " + name 100 | print "export SPINN_FLAGS=\"" + flags + "\"; export DEVICE=gpuX; qsub -v SPINN_FLAGS,DEVICE ../scripts/train_spinn_classifier.sh -q " + queue + " -l host=jagupardX" 101 | print 102 | -------------------------------------------------------------------------------- /scripts/make_theano_patch.sh: -------------------------------------------------------------------------------- 1 | bash -c "cd /scr/jgauthie/tmp/theano-nshrdlu && diff -x \"*.pyc\" -uwr theano-base/theano theano" > theano.patch 2 | -------------------------------------------------------------------------------- /scripts/pick_gpu.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import sys 3 | import random 4 | import re 5 | 6 | # Print the name of a device to use, either 'cpu' or 'gpu0', 'gpu1',... 7 | # GPUs with usage under the constant threshold will be chosen first, 8 | # but subject to that constraint, selection is random. 9 | # 10 | # Warning: This is hacky and brittle, and can break if nvidia-smi changes 11 | # in the way it formats its output. 
12 | # 13 | # Maintainer: sbowman@stanford.edu 14 | 15 | USAGE_THRESHOLD = 0.8 16 | 17 | proc = subprocess.Popen("nvidia-smi", stdout=subprocess.PIPE, 18 | stderr=subprocess.PIPE) 19 | output, error = proc.communicate() 20 | if error: 21 | sys.stderr.write(error) 22 | sys.stdout.write("cpu"); sys.exit(1)  # Bail out; otherwise we would fall through and print a second device name. 23 | 24 | usage_re = re.compile(r"(?<= )\d{1,8}(?=MiB /)") 25 | matches = usage_re.findall(output) 26 | usage_amts = [int(usage_amt) for usage_amt in matches] 27 | 28 | total_re = re.compile(r"(?<=/)\s*\d{1,8}(?=MiB)") 29 | matches = total_re.findall(output) 30 | total_amts = [int(total) for total in matches] 31 | 32 | pct_used = [float(usage_amt)/float(total) for (usage_amt, total) in zip(usage_amts, total_amts)] 33 | 34 | open_gpus = [index for index in range(len(pct_used)) if pct_used[index] < USAGE_THRESHOLD] 35 | 36 | if open_gpus: 37 | sys.stdout.write("gpu" + str(random.choice(open_gpus))) 38 | else: 39 | sys.stdout.write("cpu") 40 | -------------------------------------------------------------------------------- /scripts/train_spinn_classifier.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ### Generic job script for all experiments. 4 | 5 | # Usage example: 6 | # export SPINN_FLAGS="--learning_rate 0.01 --batch_size 256"; export DEVICE=gpu0; qsub -v SPINN_FLAGS,DEVICE scripts/train_spinn_classifier.sh -l host=jagupard10 7 | 8 | # Change to the submission directory. 9 | cd $PBS_O_WORKDIR 10 | echo Launching from working directory: $PBS_O_WORKDIR 11 | echo Flags: $SPINN_FLAGS 12 | echo Device: $DEVICE 13 | 14 | # Log what we're running and where. 15 | echo $PBS_JOBID - `hostname` - $DEVICE - at `git log --pretty=format:'%h' -n 1` - $SPINN_FLAGS >> ~/spinn_machine_assignments.txt 16 | 17 | # Use Jon's Theano install. 18 | source /u/nlp/packages/anaconda/bin/activate conda-common 19 | export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 20 | export PYTHONPATH=/scr/jgauthie/tmp/theano-nshrdlu:$PYTHONPATH 21 | 22 | THEANO_FLAGS=allow_gc=False,cuda.root=/usr/bin/cuda,warn_float64=warn,device=$DEVICE,floatX=float32 python -m spinn.models.classifier $SPINN_FLAGS 23 | -------------------------------------------------------------------------------- /scripts/train_spinn_fat_classifier.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ### Generic job script for all experiments. 4 | 5 | # Usage example: 6 | # export SPINN_FLAGS="--learning_rate 0.01 --batch_size 256"; export DEVICE=gpu0; qsub -v SPINN_FLAGS,DEVICE scripts/train_spinn_fat_classifier.sh -l host=jagupard10 7 | 8 | # Change to the submission directory. 9 | cd $PBS_O_WORKDIR 10 | echo Launching from working directory: $PBS_O_WORKDIR 11 | echo Flags: $SPINN_FLAGS 12 | echo Device: $DEVICE 13 | 14 | # Log what we're running and where. 15 | echo $PBS_JOBID - `hostname` - $DEVICE - at `git log --pretty=format:'%h' -n 1` - $SPINN_FLAGS >> ~/spinn_machine_assignments.txt 16 | 17 | # Use Jon's Theano install.
18 | source /u/nlp/packages/anaconda/bin/activate conda-common 19 | export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 20 | export PYTHONPATH=/scr/jgauthie/tmp/theano-nshrdlu:$PYTHONPATH 21 | export PATH=/usr/local/cuda/bin:$PATH 22 | 23 | THEANO_FLAGS=allow_gc=False,cuda.root=/usr/bin/cuda,warn_float64=warn,device=$DEVICE,floatX=float32 python -m spinn.models.fat_classifier $SPINN_FLAGS 24 | -------------------------------------------------------------------------------- /writing/gist/gist.tex: -------------------------------------------------------------------------------- 1 | \documentclass[11pt,letterpaper]{article} 2 | \usepackage{../acl2015} 3 | \usepackage{times} 4 | \usepackage{latexsym} 5 | % \setlength\titlebox{5cm} % Expanding the titlebox 6 | 7 | %%% Custom additions %%% 8 | \usepackage{url} 9 | \usepackage[leqno, fleqn]{amsmath} 10 | \usepackage{amssymb} 11 | \usepackage{qtree} 12 | \usepackage{graphicx} 13 | \usepackage{booktabs} 14 | \usepackage{multirow} 15 | \usepackage{colortbl} 16 | \usepackage{caption} 17 | \usepackage{subcaption} 18 | \usepackage{color} 19 | \usepackage{xcolor} 20 | \usepackage{tikz} 21 | \usepackage{tikz-qtree} 22 | \usepackage{ifthen} 23 | \usepackage{framed} 24 | 25 | \newcount\colveccount 26 | \newcommand*\colvec[1]{ 27 | \global\colveccount#1 28 | \begin{bmatrix} 29 | \colvecnext 30 | } 31 | \def\colvecnext#1{ 32 | #1 33 | \global\advance\colveccount-1 34 | \ifnum\colveccount>0 35 | \\ 36 | \expandafter\colvecnext 37 | \else 38 | \end{bmatrix} 39 | \fi 40 | } 41 | 42 | \newcommand{\nateq}{\equiv} 43 | \newcommand{\natind}{\mathbin{\#}} 44 | \newcommand{\natneg}{\mathbin{^{\wedge}}} 45 | \newcommand{\natfor}{\sqsubset} 46 | \newcommand{\natrev}{\sqsupset} 47 | \newcommand{\natalt}{\mathbin{|}} 48 | \newcommand{\natcov}{\mathbin{\smallsmile}} 49 | 50 | \newcommand{\plneg}{\mathop{\textit{not}}} 51 | \newcommand{\pland}{\mathbin{\textit{and}}} 52 | \newcommand{\plor}{\mathbin{\textit{or}}} 53 | 54 | \newcommand{\shift}{\textsc{shift}} 55 | \newcommand{\reduce}{\textsc{reduce}} 56 | 57 | % Strikeout 58 | \newlength{\howlong}\newcommand{\strikeout}[1]{\settowidth{\howlong}{#1}#1\unitlength0.5ex% 59 | \begin{picture}(0,0)\put(0,1){\line(-1,0){\howlong\divide\unitlength}}\end{picture}} 60 | 61 | \newcommand{\True}{\texttt{T}} 62 | \newcommand{\False}{\texttt{F}} 63 | \usepackage{stmaryrd} 64 | \newcommand{\sem}[1]{\ensuremath{\llbracket#1\rrbracket}} 65 | 66 | \newcommand{\mynote}[1]{{\color{blue}#1}} 67 | \newcommand{\tbchecked}[1]{{\color{red}#1}} 68 | 69 | \usepackage{gb4e} 70 | \noautomath 71 | 72 | \def\ii#1{\textit{#1}} 73 | \newcommand{\word}[1]{\emph{#1}} 74 | \newcommand{\fulllabel}[2]{\b{#1}\newline\textsc{#2}} 75 | 76 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 77 | %%%%% Code to simulate natbib's citealt, which prints citations with 78 | %%%%% no parentheses: 79 | 80 | \makeatletter 81 | \def\citealt{\def\citename##1{{\frenchspacing##1} }\@internalcitec} 82 | \def\@citexc[#1]#2{\if@filesw\immediate\write\@auxout{\string\citation{#2}}\fi 83 | \def\@citea{}\@citealt{\@for\@citeb:=#2\do 84 | {\@citea\def\@citea{;\penalty\@m\ }\@ifundefined 85 | {b@\@citeb}{{\bf ?}\@warning 86 | {Citation `\@citeb' on page \thepage \space undefined}}% 87 | {\csname b@\@citeb\endcsname}}}{#1}} 88 | \def\@internalcitec{\@ifnextchar [{\@tempswatrue\@citexc}{\@tempswafalse\@citexc[]}} 89 | \def\@citealt#1#2{{#1\if@tempswa, #2\fi}} 90 | \makeatother 91 | 92 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 
93 | 94 | 95 | %%% %%% 96 | 97 | \title{NSHRDLU:\thanks{Name is provisional, and stolen from Geoff Hinton.}\\Work in progress on the joint learning of parsing and semantic encoding} 98 | 99 | \author{ 100 | Samuel R.\ Bowman$^{\ast\dag}$ \\ 101 | \texttt{sbowman@stanford.edu} \\ 102 | \And 103 | Jon Gauthier$^{\dag\ddag}$ \\ 104 | \texttt{angeli@stanford.edu} \\ 105 | \AND 106 | Christopher D.\ Manning$^{\ast\dag\S}$\\ 107 | \texttt{manning@stanford.edu}\\ 108 | \And 109 | Christopher Potts$^{\ast}$\\ 110 | \texttt{cgpotts@stanford.edu} 111 | \AND\\[-3ex] 112 | {$^{\ast}$Stanford Linguistics\quad 113 | $^{\dag}$Stanford NLP Group}\\ 114 | {$^{\ddag}$Stanford Symbolic Systems\quad 115 | $^{\S}$Stanford Computer Science} 116 | } 117 | 118 | \date{} 119 | 120 | \makeatletter 121 | \newcommand{\@BIBLABEL}{\@emptybiblabel} 122 | \newcommand{\@emptybiblabel}[1]{} 123 | \definecolor{black}{rgb}{0,0,0} 124 | \makeatother 125 | \usepackage[breaklinks, colorlinks, linkcolor=black, urlcolor=black, citecolor=black, draft]{hyperref} 126 | 127 | \def\t#1{#1} 128 | \def\b#1{\t{\textbf{#1}}} 129 | \def\colspaceS{2.25mm} 130 | \def\colspaceM{4.0mm} 131 | \def\colspaceL{4.25mm} 132 | 133 | \newcommand\todo[1]{\textcolor{red}{\textbf{TODO:} #1}} 134 | \newcommand\note[1]{\textcolor{blue}{\textbf{NOTE:} #1}} 135 | 136 | 137 | \begin{document} 138 | \maketitle 139 | 140 | \section{Introduction} 141 | 142 | This project aims to use ideas from greedy transition-based parsing to build neural network models that can jointly learn to parse sentences and to use those parses to guide semantic composition. 143 | 144 | Table~\ref{models-table} shows the sequence of model designs that we plan to build, and Figure~\ref{m1-views} depicts some representative models. 145 | 146 | We see three reasons to pursue this approach: 147 | \begin{itemize} 148 | \item Simply adapting a greedy transition-based approach to sentence encoding makes it possible to exploit semantic compositionality in the same manner as in a TreeRNN, but using a static graph structure that can take advantage of existing neural network libraries like Theano for both automatic differentiation and for highly optimized matrix computations on both CPUs and GPUs. Model 0 pursues this property directly, and all subsequent models share it. 149 | \item When a parsing system and an interpretation system are trained jointly and forced to share representations, it is likely that the performance of both models will benefit: the semantic composition function will have better access to syntactic type information that could provide additional evidence on the functional behavior of rare words, and the syntactic parser will have access to information about the semantic interpretability of constituents, making semantically-conditioned parsing decisions like PP attachment easier to learn. Models 1--3 have this property. 150 | \item When a system that uses constituent structure information as a latent variable is initialized from scratch and trained solely to perform some semantic task, it will likely learn some coherent syntax for the genre of language on which it is trained. If this syntax is reasonably stable across similar training sets and across random initializations of training, it can reasonably be said to be latent in the text itself, offering a new kind of evidence about natural language syntax. Model 4 aims to collect this kind of evidence. 
151 | 152 | \end{itemize} 153 | 154 | All five models function as sentence encoding models: their inputs are sentences (input as sequences of words, with the help of a learned embedding matrix), and their outputs are single sentence encoding vectors which can be used as the inputs to downstream models for tasks like sentiment analysis, translation, or inference. Provided that these downstream models are differentiable---as is the case for neural network models and simpler regression models---their gradient signals can be used to train the sentence encoding models. 155 | 156 | \begin{table*}[t] 157 | \centering\small 158 | \begin{tabular}{cccccc} 159 | \toprule 160 | Name & Stack Representation & Input Representation & Ops Classifier & \multicolumn{2}{l}{Op Predictions Used In} \\ 161 | & & & & Training & Testing \\ 162 | \midrule 163 | Model 0 & Discrete & Op. sequence & N & -- & -- \\ 164 | Model 1 & Discrete & \bf Discrete Buffer & \bf Y: Directly supervised & \bf N & Y \\ 165 | Model 2 & Discrete & Discrete Buffer & Y: Directly supervised & \bf Y & Y \\ 166 | Model 2S & Discrete & Discrete Buffer & Y: Directly supervised & \bf Sched. sampling & Y \\ 167 | Model 3 & \bf Soft & \bf Soft Buffer & Y: Directly supervised & Y & Y \\ 168 | Model 4 & Soft & Soft Buffer & \bf Y: Indirectly supervised & Y & Y \\ 169 | \bottomrule 170 | \end{tabular} 171 | \protect\caption{\protect\label{models-table}Model variants, ordered by increasing reliance on learning. Bolding indicates the differences between each model and its parent model.} 172 | \end{table*} 173 | 174 | \subsection{Publication plans} 175 | 176 | \subsubsection{Winter 2016} 177 | 178 | Our first paper will be focused on Model 0 and Model 2S, with no discussion of soft stack models. Our primary goal with this paper will be to achieve state-of-the-art performance on SNLI with Model 0, and performance with Model 2S that surpasses that of a baseline LSTM encoder. We will pursue the state of the art both in the Siamese architecture regime with sentence encoders and in the \citealt{rocktaschel2015reasoning}-style word-by-word attention regime (reimplemented as constituent-by-constituent attention; see \S\ref{sec:c-by-c}). 179 | 180 | \subsubsection{Spring 2016} 181 | 182 | Our second paper will be focused on the soft stack models, Models 3 and 4. It will investigate the role of differentiable stack data structures in sentence encoding and the possibility of learning consistent and useful syntactic parsers using only semantic supervision. 183 | 184 | \section{Models} 185 | 186 | \subsection{Model 0} 187 | 188 | Model 0, depicted in Figure~\ref{fig:model:0}, is the simplest instantiation of our design, using only a conventional stack and a learned composition function to incrementally build a sentence representation over a series of timesteps. For a sentence of $N$ words, its input is a sequence of $2N-1$ inputs. These inputs come in two types. At some timesteps, the input will be a word embedding vector. This triggers the \shift~operation, in which the vector is pushed onto the top of a stack. At other timesteps, the input will be the special \reduce~token, which triggers the reduction operation. In that operation, the top two word vectors are popped from the stack, fed into a learned composition function that maps them to a single vector (in the simplest case, this is a single neural network layer), and then pushed back onto the stack. 189 | 190 | If we add no additional features, this model computes the same function as a plain TreeRNN.
However, we expect it to be substantially faster than conventional TreeRNN implementations. Unlike a TreeRNN, the Model 0 computation graph is essentially static across examples, so examples of varying structures and lengths can be batched together and run on the same graph in a single step. This simply requires ensuring that the graph is run for enough timesteps to finish all of the sentences. This involves some wasted computation, since the composition function will be run $2N-1$ times (with the output of composition at non-\reduce~steps discarded), rather than $N-1$ times in a TreeRNN. However, this loss can be dramatically offset by the gains of batching, which stem from the ability to exploit highly optimized CPU and GPU libraries for batched matrix multiplication. 191 | 192 | \subsection{Model 1} 193 | 194 | \input{../model1_fig.tex} 195 | 196 | Model 1, depicted in Figures~\ref{fig:model:1d} and \ref{fig:model:1b}, adapts Model 0 to use a stack and a buffer, making it more closely resemble a shift--reduce parser, and laying the groundwork for a model which can parse novel sentences at test time. 197 | 198 | The model runs for a fixed number of transition steps: $2N - 3$. In its starting configuration, it contains a stack that is prepopulated with the first two words of the sentence (since \shift~\shift~is the only legal operation sequence for the first two timesteps of a true shift-reduce parser), as well as a buffer (a queue) prepopulated with all of the remaining words in the sentence. Both the stack and buffer represent words using their embeddings. 199 | 200 | At each timestep at test time, the model combines views of the stack and buffer (the top element of the buffer and the top two elements of the stack, highlighted in yellow) into the input to a tracking LSTM (red). This LSTM's output is fed into a sigmoid operation classifier (blue), which chooses between the \shift~and \reduce~operations. If \shift~is chosen, one word embedding is popped from the buffer and pushed onto the stack. If \reduce~is chosen, the buffer is left as is, and the top two elements of the stack are popped and composed using a learned composition function (green), with the result placed back on top of the stack. 201 | 202 | \paragraph{Supervision} The model is trained using two objective functions simultaneously. The semantic objective function is computed by feeding the value from the top of the stack at the final timestep---the full sentence encoding---into a downstream neural network model for some semantic task, like a sentiment classifier or an entailment classifier. The gradients from that classifier propagate to every part of the model except the operation classifier (blue). The syntactic objective function takes the form of direct supervision on the operation classifier (blue), which encourages that classifier to produce the same sequence of operations that an existing parser would produce for that sentence. The gradients from the syntactic objective function propagate to every part of the model but the downstream semantic model. 203 | 204 | At training time, following the strategy used in LSTM text decoders, the decisions made by the operation classifier (blue) are discarded, and the model instead uses the correct operation as specified in the (already parsed) training corpus. At test time, this signal is not available, and the model uses its own predicted operations.
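For concreteness, here is a minimal NumPy sketch of the discrete shift--reduce feedforward that Models 0 and 1 share (illustrative only: \texttt{encode}, \texttt{make\_compose}, and the other names are placeholders, not the API of our codebase):

\begin{verbatim}
import numpy as np

def encode(tokens, transitions, embeddings, compose):
    # tokens: word IDs in sentence order; transitions: 0 = SHIFT, 1 = REDUCE.
    stack, buffer_ptr = [], 0
    for transition in transitions:
        if transition == 0:   # SHIFT: push the next word embedding.
            stack.append(embeddings[tokens[buffer_ptr]])
            buffer_ptr += 1
        else:                 # REDUCE: pop two children, compose, push parent.
            right, left = stack.pop(), stack.pop()
            stack.append(compose(left, right))
    return stack.pop()        # The root vector is the sentence encoding.

def make_compose(W, b):
    # Simplest case: a single tanh layer over the concatenated children.
    return lambda l, r: np.tanh(np.concatenate([l, r]).dot(W) + b)
\end{verbatim}

In the batched Theano implementations, the same loop is instead unrolled for a fixed number of steps over a whole minibatch, with the transition bits selecting between the two behaviors at every step.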
205 | 206 | \subsection{Model 2} 207 | 208 | Model 2 makes a small change to Model 1 that is likely to substantially change the dynamics of learning: It uses the operation sequence predicted by the operation classifier (blue) at training time as well as at test time. It may be possible to accelerate Model 2 training by initializing it with parameters learned by Model 1. 209 | 210 | By exposing Model 2 to the results of its own decisions during training, we encourage it to become more robust to its own prediction errors. \citealt{bengio2015scheduled} applied a similar strategy\footnote{The authors experiment with several strategies which interpolate between oracle-driven training and oracle-free training (Models 1 and 2 in our presentation, respectively). It may be useful to adopt a similar interpolating approach.} to an image captioning model. They suggest that the resulting model can avoid propagating prediction errors through long sequences due to this training regime. 211 | 212 | \subsection{Model 3} 213 | 214 | Model 3 modifies Model 2 by introducing the soft stack/soft queue from \cite{grefenstette2015learning} in place of the hard, conventional stack and buffer. The soft stack makes it possible for the model to predict smooth distributions over operations of the form (0.93 \shift, 0.07 \reduce), instead of making hard decisions. These soft decisions allow for gradient information to flow from the stack and the buffer back into the operation classifier (blue). This is crucial to our ultimate goal, as it makes it possible for semantic considerations to influence the model's parsing decisions. 215 | 216 | Model 3 still receives a direct supervision signal from some existing parser. In order to train the soft stack, we must represent the hard supervision signal from the parser by a soft prediction which matches the soft stack operation output. The supervision signal simply assigns 100\% weight to the ground-truth operation. 217 | 218 | \subsection{Model 4} 219 | 220 | Model 4 modifies Model 3 by removing the direct supervision signal from the operation classifier (blue), instead forcing the operation classifier to learn solely from the gradient provided by the downstream supervision task. It may be possible to accelerate or otherwise improve Model 4 training by initializing it with parameters learned by Model 3. 221 | 222 | By removing an external parser signal, we allow Model 4 to fully exploit the soft stack representation. It is free to predict soft parse operations in the case of ambiguous parses. This distinguishes it from Model 3, which is encouraged to replicate the 100\%-certain ground truth parse predictions provided by the external parser. 223 | 224 | \section{Other possible model features} 225 | 226 | \subsection{Contextually-informed composition} 227 | 228 | The composition function in the basic model (green) combines only the top elements of the stack, without using any further information. We can encourage the composition function to learn to do some amount of context-sensitive interpretation/disambiguation by adding a connection from the tracking LSTM (red) directly into the composition function. 229 | 230 | For Model 0, no tracking LSTM is needed for the ordinary operation of the model, but we can simply add one for this purpose, taking as inputs the top two values of the stack at each time point and emitting as output a context vector that can be used to condition the composition function.
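One simple instantiation of this connection (a sketch; the exact parameterization in the code may differ) concatenates the tracking LSTM's context vector $c$ with the two children inside the composition function:
\begin{equation*}
h_{\textrm{parent}} = \tanh\left(\mathbf{W} \left[ h_l ; h_r ; c \right] + b\right),
\end{equation*}
where $[\cdot\,;\cdot]$ denotes vector concatenation.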
231 | 232 | So far, we have found this technique to yield a non-trivial performance gain. A clearer evaluation on converged models is forthcoming. 233 | 234 | \subsection{Constituent-by-constituent attention}\label{sec:c-by-c} 235 | 236 | \input{../tree_attn_fig.tex} 237 | 238 | We aim to build on the results of \citealt{rocktaschel2015reasoning} and \citealt{wang2015learning}, who find that neural soft attention is an extremely effective technique for learning natural language inference. In particular, both papers use versions of word-by-word attention, in which a latent alignment is learned from every word in the hypothesis to one or more words of the premise. We propose to borrow this basic idea, but to adapt it to a tree-structured setting, proposing \textit{constituent-by-constituent} attention. While these models perform attention over a matrix $\mathbf{Y}$ of word-in-context representations from the premise encoder, we will perform attention instead over our own primary data structure, $\mathbf{Y}^{st}$, the matrix of vectors that have appeared at the top of the stack during premise encoding, which correspond one-to-one to the constituents in the tree structure representing the premise. Similarly, while the previous models perform one instance of soft attention conditioned on each word in the hypothesis, we perform one instance of soft attention conditioned on each stack top in the hypothesis encoder, representing the constituents of the hypothesis tree. 239 | 240 | In our model, attention is performed at each step $t$ of the hypothesis encoder. At step $t$, the query vector that drives attention will be $S^t_0$, the top of the hypothesis encoder's stack. 241 | 242 | \todo{[AR, RG] Write up the attention equations that you use.} 243 | 244 | \todo{[Anyone -- Sam can] Draw a diagram indicating the structure of constituent-by-constituent attention.} 245 | 246 | \subsection{Encoding the contents of the stack and buffer} 247 | 248 | \note{Not currently planned.} 249 | 250 | The tracking LSTM (red) needs access to the top of the buffer and the top two elements of the stack in order to make even minimally informed decisions about whether to shift or reduce. It could benefit further from additional information about broader sentential context. This can be provided by running new LSTMs along the elements of each of the stack and the buffer (following \citealt{dyer-EtAl:2015:ACL-IJCNLP}) and feeding the results into the tracking LSTM. 251 | 252 | \subsection{Typed \reduce~operations} 253 | 254 | \note{Not currently planned.} 255 | 256 | Shift--reduce parsers for natural language typically operate with a restricted set of typed \reduce~operations (also known as ``arc'' operations). These operations specify the precise relation between the elements being merged. It would be possible to train any of the parse-supervised models (1--3) to learn such typed arc operations, expanding the op set dramatically to something like \{\shift, \reduce-NP, \reduce-S, \reduce-PP, ...\} (in the case of constituency parse supervision). The model could then learn a distinct composition function depending on the relation of the two elements being merged. 257 | 258 | 259 | \section{Implementation notes} 260 | 261 | The size of the stack should be $N$ for sentences of $N$ words, in case no \reduce~occurs until all $N$ words have been shifted. The size of the buffer should be $N - 2$, since the first two words are placed directly on the stack. 262 | 263 | \subsection{Data preparation} 264 | 265 | For Models 0--3, all training data must be parsed in advance into unlabeled binary constituency trees. 
In addition, Model 0 requires that parses be available at test time as well. For both SST and SNLI, we use the parses included with the corpus distributions whenever parses are needed. 266 | 267 | For Model 0, training data can be prepared by linearizing the provided parse, then deleting left brackets and replacing right brackets with \reduce~instructions. This is demonstrated here with the example sentence \ii{the cat sat down}: 268 | 269 | \begin{quote}\small 270 | ( ( the cat ) ( sat down ) )$\Rightarrow$\\ 271 | the cat \reduce~sat down \reduce~\reduce 272 | \end{quote} 273 | 274 | The input for Models 1--4 is simply the word sequence from the parse, with the first two words moved into the stack. The syntactic supervision labels for Models 1--3 are simply a binarized version of the Model 0 inputs (each word replaced by \shift), with the first two tokens (which are necessarily \shift~\shift) omitted: 275 | 276 | \begin{quote}\small 277 | ( ( the cat ) ( sat down ) )$\Rightarrow$ \\ 278 | stack: $\langle$the, cat$\rangle$\\ 279 | buffer: $\langle$sat, down$\rangle$\\ 280 | ops: \reduce~\shift~\shift~\reduce~\reduce 281 | \end{quote} 282 | 283 | \subsection{Memory management and the thin stack} 284 | 285 | \todo{[JG] Explain.} 286 | 287 | \section{Experiments} 288 | 289 | \subsection{Step time} 290 | 291 | \todo{Does anyone know of a better baseline TreeRNN implementation that we can use on Jagupard? We can use the CoreNLP SST model, but using a Java model as a baseline seems worrying unless we're guaranteed that it's competitively fast.} 292 | 293 | We compare model step time to that of the plain RNN of \citealt{li2015tree}. We use the small \citealt{pang2004sentimental} sentiment corpus that they use, and train with identical hyperparameters: ... 294 | 295 | Evaluation metrics: time per minibatch on a Jagupard machine, with and without GPU access. 296 | 297 | \subsection{Sentiment} 298 | 299 | \note{We're running sentiment experiments to evaluate the behavior of converged models, since SNLI takes too long to train to convergence. However, we don't currently plan to publish SNLI results.} 300 | 301 | Learning to jointly parse and to predict sentiment labels over SST 302 | \cite{socher2013recursive}. 303 | 304 | Evaluation metrics: accuracy, parsing F1 (for all models but Model 0). 305 | 306 | 307 | \subsection{Natural language inference} 308 | 309 | Learning to jointly parse and to predict logical relations between sentences over SNLI \cite{snli:emnlp2015}. 310 | 311 | Evaluation metrics: accuracy, parsing F1 (for all models but Model 0). 312 | 313 | \subsection{Parser quality evaluation} 314 | 315 | \todo{[RG] Write up the parser experiment results.} 316 | 317 | \section{Discussion} 318 | 319 | \subsection{Inferred Model 4 parses?} 320 | 321 | \subsubsection*{Acknowledgments} 322 | 323 | (Some of) the Tesla K40(s) used for this research was/were donated by the NVIDIA Corporation. 
324 | 325 | \bibliographystyle{../acl} 326 | \bibliography{../MLSemantics} 327 | 328 | \end{document} 329 | -------------------------------------------------------------------------------- /writing/hard_stack_paper/batching_fig.tex: -------------------------------------------------------------------------------- 1 | %!TEX root = paper.tex 2 | 3 | \begin{figure}[t] 4 | 5 | \begin{subfigure}[t]{\columnwidth} 6 | \begin{center} 7 | \scalebox{1}{ 8 | \begin{tikzpicture} 9 | \tikzstyle{word}=[fill=yellow!40,text height=2mm,line width=1pt] 10 | \tikzstyle{nonleaf}=[fill=yellow!40,text height=2mm,line width=1pt,draw=black] 11 | \tikzstyle{alt}=[fill=orange!40] 12 | \pgfsetarrowsend{latex} 13 | \tikzstyle{fwd} = [draw=black, line width=1pt] 14 | 15 | \def\dx{23pt} 16 | \def\dy{11pt} 17 | \def\sy{7*\dy} 18 | \def\oxb{5.5*\dx} 19 | \def\by{1*\dy} 20 | \def\ox{0*\oxb} 21 | 22 | \begin{scope}[shift={(0in,0in)}, frontier/.style={distance from root=60pt}] 23 | 24 | \node[word,alt] (w1) at (\ox+-3*\dx,\by+0*\dy) {the}; 25 | \node[word,alt] (w2) at (\ox+-1*\dx,\by+0*\dy) {old}; 26 | \node[word,alt] (w3) at (\ox+1*\dx,\by+0*\dy) {cat}; 27 | \node[word,alt] (w4) at (\ox+3*\dx,\by+0*\dy) {ate}; 28 | 29 | \node[word,alt] (n1) at (\ox+-3*\dx,\by+3*\dy) {~~~~~}; 30 | \node[word,alt] (n2) at (\ox+-1*\dx,\by+3*\dy) {~~~~~}; 31 | \node[word,alt] (n3) at (\ox+1*\dx,\by+3*\dy) {~~~~~}; 32 | \node[word,alt] (n4) at (\ox+3*\dx,\by+3*\dy) {~~~~~}; 33 | \node[] (n5) at (\ox+4.5*\dx,\by+3*\dy) {...}; 34 | 35 | \draw [fwd] (w1) -- (n1); 36 | \draw [fwd] (w2) -- (n2); 37 | \draw [fwd] (w3) -- (n3); 38 | \draw [fwd] (w4) -- (n4); 39 | 40 | \draw [fwd] (n1) -- (n2); 41 | \draw [fwd] (n2) -- (n3); 42 | \draw [fwd] (n3) -- (n4); 43 | \draw [fwd] (n4) -- (n5); 44 | 45 | 46 | \end{scope} 47 | 48 | \begin{scope}[shift={(0.25in,-0.2in)}, frontier/.style={distance from root=60pt}] 49 | 50 | \node[word] (w1) at (\ox+-3*\dx,\by+0*\dy) {the}; 51 | \node[word] (w2) at (\ox+-1*\dx,\by+0*\dy) {cat}; 52 | \node[word] (w3) at (\ox+1*\dx,\by+0*\dy) {sat}; 53 | \node[word] (w4) at (\ox+3*\dx,\by+0*\dy) {down}; 54 | 55 | \node[word] (n1) at (\ox+-3*\dx,\by+3*\dy) {~~~~~}; 56 | \node[word] (n2) at (\ox+-1*\dx,\by+3*\dy) {~~~~~}; 57 | \node[word] (n3) at (\ox+1*\dx,\by+3*\dy) {~~~~~}; 58 | \node[word] (n4) at (\ox+3*\dx,\by+3*\dy) {~~~~~}; 59 | \node[] (n5) at (\ox+4.5*\dx,\by+3*\dy) {...}; 60 | 61 | 62 | \draw [fwd] (w1) -- (n1); 63 | \draw [fwd] (w2) -- (n2); 64 | \draw [fwd] (w3) -- (n3); 65 | \draw [fwd] (w4) -- (n4); 66 | 67 | \draw [fwd] (n1) -- (n2); 68 | \draw [fwd] (n2) -- (n3); 69 | \draw [fwd] (n3) -- (n4); 70 | \draw [fwd] (n4) -- (n5); 71 | 72 | \end{scope} 73 | 74 | \end{tikzpicture}} 75 | \end{center} 76 | 77 | 78 | \caption{\label{fig:batching:good}A conventional sequence-based RNN for two sentences.} 79 | \end{subfigure} 80 | 81 | \begin{subfigure}[t]{\columnwidth} 82 | \begin{center} 83 | \scalebox{1}{ 84 | \begin{tikzpicture} 85 | \tikzstyle{word}=[fill=yellow!40,text height=2mm,line width=1pt] 86 | \tikzstyle{nonleaf}=[fill=yellow!40,text height=2mm,line width=1pt] 87 | \tikzstyle{alt}=[fill=orange!40] 88 | \pgfsetarrowsend{latex} 89 | \tikzset{edge from parent/.append style={<-, line width=1pt, >=latex}} 90 | 91 | \begin{scope}[shift={(0in,0in)}] 92 | 93 | \Tree [.\node[](root){...}; [.\node[nonleaf,alt](2thekittenate){the old cat ate}; [.\node[nonleaf,alt](2thekitten){the old cat}; \node[word,alt](2the){the}; [.\node[nonleaf,alt](2bigkitten){old cat}; \node[word,alt](2kitten){old}; \node[word,alt](2kitten){cat}; ] ] 
\node[word,alt](2ate){ate}; ] ] 94 | 95 | \end{scope} 96 | 97 | 98 | \begin{scope}[shift={(1.5in,0in)}] 99 | 100 | \Tree [.\node[](root){...}; [.\node[nonleaf](1thecatsatdown){the cat sat down}; [.\node[nonleaf](1thecat){the cat}; \node[word](1the){the}; \node[word](1cat){cat}; ] [.\node[nonleaf](1satdown){sat down}; \node[word](1sat){sat}; \node[word](1down){down}; ] ] ] 101 | 102 | \end{scope} 103 | 104 | \end{tikzpicture}} 105 | \end{center} 106 | 107 | \caption{\label{fig:batching:bad}A conventional TreeRNN for two sentences.} 108 | \end{subfigure} 109 | 110 | \caption{\label{fig:batching} An illustration of two standard designs for sentence encoders. The TreeRNN, unlike the sequence-based RNN, requires a substantially different connection structure for each sentence, making batched computation impractical.} 111 | \end{figure} 112 | -------------------------------------------------------------------------------- /writing/hard_stack_paper/bowman2016spinn.bib: -------------------------------------------------------------------------------- 1 | @inproceedings{bowman2016spinn, 2 | title = {A Fast Unified Model for Parsing and Sentence Understanding}, 3 | author = {Samuel R. Bowman and Jon Gauthier and Abhinav Rastogi and Raghav Gupta and Christopher D. Manning and Christopher Potts}, 4 | booktitle = {Association for Computational Linguistics (ACL)}, 5 | year = {2016} 6 | } 7 | -------------------------------------------------------------------------------- /writing/hard_stack_paper/runtime.tsv: -------------------------------------------------------------------------------- 1 | Batch size CPU GPU RNN 2 | 2048 2.737505333 0.7130436667 3 | 1024 1.697412 0.4858816667 4 | 512 23.259665533 0.991202 0.2629633333 5 | 256 12.059698 0.9430405333 0.2198066667 6 | 128 6.380825333 0.73615333 0.1706013333 7 | 64 3.192774 0.7601923333 0.182761 8 | 32 1.597907667 0.5116996667 0.1506473333 9 | -------------------------------------------------------------------------------- /writing/model0_fig.tex: -------------------------------------------------------------------------------- 1 | %!TEX root = hard_stack_paper/paper.tex 2 | 3 | 4 | \begin{figure*}[t] 5 | 6 | \begin{subfigure}[t]{\textwidth} 7 | \centering 8 | \scalebox{0.6}{ 9 | \begin{tikzpicture} 10 | \def\dx{21pt} 11 | \def\dy{11pt} 12 | \def\sy{13*\dy} 13 | \def\oxb{8*\dx} 14 | \def\by{1*\dy} 15 | \def\ox{0*\oxb} 16 | 17 | \tikzstyle{label}=[text width=35mm,align=center,text height=2mm] 18 | \tikzstyle{word}=[text width=35mm,align=center,text height=2mm] 19 | \tikzstyle{tracker}=[fill=red!40,text width=15mm,align=center,text height=2mm] 20 | \tikzstyle{softmax}=[fill=blue!40,text width=15mm,align=center,text height=2mm] 21 | \tikzstyle{comp}=[fill=green!40,text width=20mm,align=center,text height=2mm] 22 | \tikzstyle{compoff}=[fill=green!10!black!10,text width=20mm,align=center,text height=2mm] 23 | \tikzstyle{result}=[line width=1pt,draw=black,text width=15mm,align=center,text height=2mm] 24 | \tikzstyle{sbox}=[line width=1pt,draw=black,text width=25mm,align=center,text height=13.3mm] 25 | \tikzstyle{bbox}=[line width=1pt,draw=black,text width=25mm,align=center,text height=6.5mm] 26 | \tikzstyle{focus1}=[fill=yellow!40,text width=25mm,align=center,text height=2mm] 27 | \tikzstyle{focus2}=[fill=yellow!40,text width=25mm,align=center,text height=5.5mm] 28 | 29 | \node[label] (sl) at (\ox-0.35*\oxb+0*\dx,\by+0.5*\dy) {buffer}; 30 | 31 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=0$}; 32 | 33 | \node[focus1] (0bb) at (\ox+0*\dx,2*\dy) {}; 34 | 
\node[word] (0b3) at (\ox+0*\dx,\by-1*\dy) {}; 35 | \node[word] (0b2) at (\ox+0*\dx,\by+0*\dy) {down}; 36 | \node[word] (0b1) at (\ox+0*\dx,\by+1*\dy) {sat}; 37 | \node[bbox] (0bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 38 | 39 | \node[label] (sl) at (\ox-0.35*\oxb+0*\dx,\sy+0.5*\dy) {stack}; 40 | 41 | \node[focus2] (0sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 42 | \node[word] (0s1) at (\ox+0*\dx,\sy-1*\dy) {cat}; 43 | \node[word] (0s2) at (\ox+0*\dx,\sy+0*\dy) {the}; 44 | \node[word] (0s3) at (\ox+0*\dx,\sy+1*\dy) {}; 45 | \node[sbox] (0sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 46 | 47 | \node[comp] (0c) at (\ox+0.5*\oxb,\sy-1.5*\dy) {composition}; 48 | 49 | \node[tracker] (0t) at (\ox+0*\dx,5*\dy) {tracking}; 50 | % \node[softmax] (0sm) at (\ox+3*\dx,7*\dy) {$\sigma$}; 51 | \node[result] (0so) at (\ox+3*\dx,9*\dy) {\reduce}; 52 | 53 | \def\ox{1*\oxb} 54 | 55 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=1$}; 56 | 57 | \node[focus1] (1bb) at (\ox+0*\dx,2*\dy) {}; 58 | \node[word] (1b3) at (\ox+0*\dx,\by-1*\dy) {}; 59 | \node[word] (1b2) at (\ox+0*\dx,\by+0*\dy) {down}; 60 | \node[word] (1b1) at (\ox+0*\dx,\by+1*\dy) {sat}; 61 | \node[bbox] (1bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 62 | 63 | \node[focus2] (1sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 64 | \node[word] (1s1) at (\ox+0*\dx,\sy-1*\dy) {the cat}; 65 | \node[word] (1s2) at (\ox+0*\dx,\sy+0*\dy) {}; 66 | \node[word] (1s3) at (\ox+0*\dx,\sy+1*\dy) {}; 67 | \node[sbox] (1sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 68 | 69 | \node[compoff] (1c) at (\ox+0.5*\oxb,\sy-1.5*\dy) {composition}; 70 | 71 | \node[tracker] (1t) at (\ox+0*\dx,5*\dy) {tracking}; 72 | % \node[softmax] (1sm) at (\ox+3*\dx,7*\dy) {$\sigma$}; 73 | \node[result] (1so) at (\ox+3*\dx,9*\dy) {\shift}; 74 | 75 | \def\ox{2*\oxb} 76 | 77 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=2$}; 78 | 79 | \node[focus1] (2bb) at (\ox+0*\dx,2*\dy) {}; 80 | \node[word] (2b3) at (\ox+0*\dx,\by-1*\dy) {}; 81 | \node[word] (2b2) at (\ox+0*\dx,\by+0*\dy) {}; 82 | \node[word] (2b1) at (\ox+0*\dx,\by+1*\dy) {down}; 83 | \node[bbox] (2bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 84 | 85 | \node[focus2] (2sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 86 | \node[word] (2s1) at (\ox+0*\dx,\sy-1*\dy) {sat}; 87 | \node[word] (2s2) at (\ox+0*\dx,\sy+0*\dy) {the cat}; 88 | \node[word] (2s3) at (\ox+0*\dx,\sy+1*\dy) {}; 89 | \node[sbox] (2sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 90 | 91 | \node[tracker] (2t) at (\ox+0*\dx,5*\dy) {tracking}; 92 | 93 | 94 | \pgfsetarrowsend{latex} 95 | \tikzstyle{fwd} = [draw=black, line width=1pt] 96 | \tikzstyle{gated} = [draw=black!50, line width=0.8pt] 97 | 98 | \draw [fwd] (0sb) -- (0t); 99 | \draw [fwd] (0bb) -- (0t); 100 | \draw [fwd] (0sb) -- (0c); 101 | 102 | \draw [fwd] (0t) -- (1t); 103 | \draw [fwd] (0t) to[out=70,in=-160] (0c); 104 | \draw [fwd] (0sb) -- (1sb); 105 | \draw [fwd] (0bb) -- (1bb); 106 | \draw [fwd] (0so) -- (1sb); 107 | \draw [fwd] (0so) -- (1bb); 108 | \draw [gated] (0bb) to[out=20,in=-110] (1sb); 109 | \draw [fwd] (0c) -- (1sb); 110 | 111 | \draw [fwd] (1sb) -- (1t); 112 | \draw [fwd] (1bb) -- (1t); 113 | \draw [fwd] (1sb) -- (1c); 114 | 115 | \draw [fwd] (1t) -- (2t); 116 | \draw [fwd] (1t) to[out=70,in=-160] (1c); 117 | \draw [fwd] (1sb) -- (2sb); 118 | \draw [fwd] (1bb) -- (2bb); 119 | \draw [fwd] (1so) -- (2sb); 120 | \draw [fwd] (1so) -- (2bb); 121 | \draw [fwd] (1bb) to[out=20,in=-110] (2sb); 122 | \draw [gated] (1c) -- (2sb); 123 | 124 | \draw [fwd] (2sb) -- (2t); 125 | \draw [fwd] (2bb) -- (2t); 126 | 127 | 128 | \end{tikzpicture}} 129 | 130 | \caption{The model unrolled for two transitions on 
the input \word{the cat sat down}.}\label{fig:model:0} 131 | 132 | \end{subfigure}\\\\\\ 133 | \begin{subfigure}[t]{\textwidth} 134 | \centering 135 | \scalebox{0.6}{ 136 | \begin{tikzpicture} 137 | \def\dx{21pt} 138 | \def\dy{11pt} 139 | \def\sy{7*\dy} 140 | \def\oxb{5.5*\dx} 141 | \def\by{1*\dy} 142 | \def\ox{0*\oxb} 143 | 144 | \tikzstyle{label}=[text width=35mm,align=center,text height=2mm] 145 | \tikzstyle{word}=[text width=35mm,align=center,text height=2mm] 146 | \tikzstyle{tracker}=[fill=red!40,text width=15mm,align=center,text height=2mm] 147 | \tikzstyle{softmax}=[text width=40mm,align=center,text height=2mm] 148 | \tikzstyle{comp}=[fill=green!40,text width=20mm,align=center,text height=2mm] 149 | \tikzstyle{result}=[line width=1pt,draw=black,text width=15mm,align=center,text height=2mm] 150 | \tikzstyle{sbox}=[line width=1pt,draw=black,text width=33mm,align=center,text height=13.3mm] 151 | \tikzstyle{bbox}=[line width=1pt,draw=black,text width=33mm,align=center,text height=6.5mm] 152 | \tikzstyle{focus1}=[fill=yellow!40,text width=33mm,align=center,text height=2mm] 153 | \tikzstyle{focus2}=[fill=yellow!40,text width=33mm,align=center,text height=5.5mm] 154 | 155 | \def\ox{0*\oxb} 156 | 157 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=0$}; 158 | 159 | \node[label] (sl) at (\ox-0.65*\oxb+0*\dx,\by+0.5*\dy) {buffer}; 160 | 161 | \node[focus1] (1bb) at (\ox+0*\dx,2*\dy) {}; 162 | \node[word] (1b3) at (\ox+0*\dx,\by-1*\dy) {}; 163 | \node[word] (1b2) at (\ox+0*\dx,\by+0*\dy) {down}; 164 | \node[word] (1b1) at (\ox+0*\dx,\by+1*\dy) {sat}; 165 | \node[bbox] (1bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 166 | 167 | \node[label] (sl) at (\ox-0.65*\oxb+0*\dx,\sy+0.5*\dy) {stack}; 168 | 169 | \node[focus2] (1sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 170 | \node[word] (1s1) at (\ox+0*\dx,\sy-1*\dy) {cat}; 171 | \node[word] (1s2) at (\ox+0*\dx,\sy+0*\dy) {the}; 172 | \node[word] (1s3) at (\ox+0*\dx,\sy+1*\dy) {}; 173 | \node[sbox] (1sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 174 | 175 | \node[result] (1so) at (\ox+0.5*\dx,4*\dy) {\reduce}; 176 | 177 | \def\ox{1*\oxb} 178 | 179 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=1$}; 180 | 181 | \node[focus1] (2bb) at (\ox+0*\dx,2*\dy) {}; 182 | \node[word] (2b3) at (\ox+0*\dx,\by-1*\dy) {}; 183 | \node[word] (2b2) at (\ox+0*\dx,\by+0*\dy) {down}; 184 | \node[word] (2b1) at (\ox+0*\dx,\by+1*\dy) {sat}; 185 | \node[bbox] (2bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 186 | 187 | \node[focus2] (2sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 188 | \node[word] (2s1) at (\ox+0*\dx,\sy-1*\dy) {the cat}; 189 | \node[word] (2s2) at (\ox+0*\dx,\sy+0*\dy) {}; 190 | \node[word] (2s3) at (\ox+0*\dx,\sy+1*\dy) {}; 191 | \node[sbox] (2sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 192 | 193 | \node[result] (2so) at (\ox+0.5*\dx,4*\dy) {\shift}; 194 | 195 | \def\ox{2*\oxb} 196 | 197 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=2$}; 198 | 199 | \node[focus1] (3bb) at (\ox+0*\dx,2*\dy) {}; 200 | \node[word] (3b3) at (\ox+0*\dx,\by-1*\dy) {}; 201 | \node[word] (3b2) at (\ox+0*\dx,\by+0*\dy) {}; 202 | \node[word] (3b1) at (\ox+0*\dx,\by+1*\dy) {down}; 203 | \node[bbox] (3bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 204 | 205 | \node[focus2] (3sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 206 | \node[word] (3s1) at (\ox+0*\dx,\sy-1*\dy) {sat}; 207 | \node[word] (3s2) at (\ox+0*\dx,\sy+0*\dy) {the cat}; 208 | \node[word] (3s3) at (\ox+0*\dx,\sy+1*\dy) {}; 209 | \node[sbox] (3sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 210 | 211 | \node[result] (3so) at (\ox+0.5*\dx,4*\dy) {\shift}; 212 | 213 | \def\ox{3*\oxb} 214 | 215 | \node[label] (1l) at 
(\ox+0*\dx,\sy+3*\dy) {$t=3$}; 216 | 217 | \node[focus1] (4bb) at (\ox+0*\dx,2*\dy) {}; 218 | \node[word] (4b3) at (\ox+0*\dx,\by-1*\dy) {}; 219 | \node[word] (4b2) at (\ox+0*\dx,\by+0*\dy) {}; 220 | \node[word] (4b1) at (\ox+0*\dx,\by+1*\dy) {}; 221 | \node[bbox] (4bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 222 | 223 | \node[focus2] (4sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 224 | \node[word] (4s1) at (\ox+0*\dx,\sy-1*\dy) {down}; 225 | \node[word] (4s2) at (\ox+0*\dx,\sy+0*\dy) {sat}; 226 | \node[word] (4s3) at (\ox+0*\dx,\sy+1*\dy) {the cat}; 227 | \node[sbox] (4sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 228 | 229 | \node[result] (4so) at (\ox+0.5*\dx,4*\dy) {\reduce}; 230 | 231 | \def\ox{4*\oxb} 232 | 233 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=4$}; 234 | 235 | \node[focus1] (5bb) at (\ox+0*\dx,2*\dy) {}; 236 | \node[word] (5b3) at (\ox+0*\dx,\by-1*\dy) {}; 237 | \node[word] (5b2) at (\ox+0*\dx,\by+0*\dy) {}; 238 | \node[word] (5b1) at (\ox+0*\dx,\by+1*\dy) {}; 239 | \node[bbox] (5bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 240 | 241 | \node[focus2] (5sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 242 | \node[word] (5s1) at (\ox+0*\dx,\sy-1*\dy) {sat down}; 243 | \node[word] (5s2) at (\ox+0*\dx,\sy+0*\dy) {the cat}; 244 | \node[word] (5s3) at (\ox+0*\dx,\sy+1*\dy) {}; 245 | \node[sbox] (5sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 246 | 247 | \node[result] (5so) at (\ox+0.5*\dx,4*\dy) {\reduce}; 248 | 249 | \def\ox{5*\oxb} 250 | 251 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=5$}; 252 | 253 | \node[focus2] (6sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 254 | \node[word] (6s1) at (\ox+0*\dx,\sy-1*\dy) {(the cat) (sat down)}; 255 | \node[word] (6s2) at (\ox+0*\dx,\sy+0*\dy) {}; 256 | \node[word] (6s3) at (\ox+0*\dx,\sy+1*\dy) {}; 257 | \node[sbox] (6sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 258 | 259 | \node[softmax] (6sm) at (\ox+0.5*\dx,2*\dy) {output to model for semantic task}; 260 | 261 | \pgfsetarrowsend{latex} 262 | \tikzstyle{fwd} = [draw=black, line width=1pt] 263 | 264 | \draw [fwd] (1sb) -- (1so); 265 | \draw [fwd] (1bb) -- (1so); 266 | 267 | \draw [fwd] (1sb) -- (2sb); 268 | \draw [fwd] (1bb) -- (2bb); 269 | \draw [fwd] (1so) -- (2sb); 270 | \draw [fwd] (1so) -- (2bb); 271 | \draw [fwd] (1bb) -- (2sb); 272 | 273 | \draw [fwd] (2sb) -- (2so); 274 | \draw [fwd] (2bb) -- (2so); 275 | 276 | \draw [fwd] (2sb) -- (3sb); 277 | \draw [fwd] (2bb) -- (3bb); 278 | \draw [fwd] (2so) -- (3sb); 279 | \draw [fwd] (2so) -- (3bb); 280 | \draw [fwd] (2bb) -- (3sb); 281 | 282 | \draw [fwd] (3sb) -- (3so); 283 | \draw [fwd] (3bb) -- (3so); 284 | 285 | \draw [fwd] (3sb) -- (4sb); 286 | \draw [fwd] (3bb) -- (4bb); 287 | \draw [fwd] (3so) -- (4sb); 288 | \draw [fwd] (3so) -- (4bb); 289 | \draw [fwd] (3bb) -- (4sb); 290 | 291 | \draw [fwd] (4sb) -- (4so); 292 | \draw [fwd] (4bb) -- (4so); 293 | 294 | \draw [fwd] (4sb) -- (5sb); 295 | \draw [fwd] (4bb) -- (5bb); 296 | \draw [fwd] (4so) -- (5sb); 297 | \draw [fwd] (4so) -- (5bb); 298 | \draw [fwd] (4bb) -- (5sb); 299 | 300 | \draw [fwd] (5sb) -- (5so); 301 | \draw [fwd] (5bb) -- (5so); 302 | 303 | \draw [fwd] (5sb) -- (6sb); 304 | \draw [fwd] (5so) -- (6sb); 305 | 306 | \draw [fwd] (6sb) -- (6sm); 307 | 308 | \end{tikzpicture}} 309 | 310 | \caption{The fully unrolled model for \word{the cat sat down} with some layers omitted for clarity. \todo{[SB] Add in the first two steps.}}\label{fig:model:1b} 311 | 312 | \end{subfigure} 313 | \caption{\label{m1-views}Two views of the transition-based sentence model. In both views, the lower boxes represent the input buffer, and the upper boxes represent the stack. 
Yellow highlighting indicates which portions of these data structures are visible to the tracking LSTM and to the composition function. Thin gray arrows indicate connections which are blocked by a gating function, and so contribute no information. \todo{[SB] Clean up the arrangements of these figures now that we aren't reporting on Models 1/2.}} 314 | 315 | \end{figure*} 316 | -------------------------------------------------------------------------------- /writing/model1_fig.tex: -------------------------------------------------------------------------------- 1 | %!TEX root = hard_stack_paper/paper.tex 2 | 3 | \begin{figure*}[t] 4 | \begin{subfigure}[t]{\textwidth} 5 | \centering 6 | \scalebox{0.72}{% 7 | \begin{tikzpicture} 8 | \def\dx{21pt} 9 | \def\dy{11pt} 10 | \def\sy{12*\dy} 11 | \def\oxb{8*\dx} 12 | \def\by{0pt} 13 | \def\ox{0*\oxb} 14 | 15 | \tikzstyle{label}=[text width=35mm,align=center,text height=2mm] 16 | \tikzstyle{word}=[text width=35mm,align=center,text height=2mm] 17 | \tikzstyle{tracker}=[fill=red!40,text width=15mm,align=center,text height=2mm] 18 | \tikzstyle{softmax}=[fill=blue!40,text width=15mm,align=center,text height=2mm] 19 | \tikzstyle{comp}=[fill=green!40,text width=20mm,align=center,text height=2mm] 20 | \tikzstyle{compoff}=[fill=green!10!black!10,text width=20mm,align=center,text height=2mm] 21 | \tikzstyle{result}=[line width=1pt,draw=black,text width=15mm,align=center,text height=2mm] 22 | \tikzstyle{sbox}=[line width=1pt,draw=black,text width=25mm,align=center,text height=13.3mm] 23 | \tikzstyle{bbox}=[line width=1pt,draw=black,text width=25mm,align=center,text height=13.3mm] 24 | \tikzstyle{focus1}=[fill=yellow!40,text width=25mm,align=center,text height=2mm] 25 | \tikzstyle{focus2}=[fill=yellow!40,text width=25mm,align=center,text height=5.5mm] 26 | 27 | \node[label] (sl) at (\ox-0.35*\oxb+0*\dx,\by+1*\dy) {buffer}; 28 | 29 | \node[focus1] (0bb) at (\ox+0*\dx,2.5*\dy) {}; 30 | \node[word] (0b3) at (\ox+0*\dx,\by-0.5*\dy) {}; 31 | \node[word] (0b2) at (\ox+0*\dx,\by+1.5*\dy) {down}; 32 | \node[word] (0b1) at (\ox+0*\dx,\by+2.5*\dy) {sat}; 33 | \node[bbox] (0bb) at (\ox+0*\dx,\by+1.0*\dy) {}; 34 | 35 | \node[label] (sl) at (\ox-0.35*\oxb+0*\dx,\sy+0.5*\dy) {stack}; 36 | 37 | \node[focus2] (0sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 38 | \node[word] (0s1) at (\ox+0*\dx,\sy-1*\dy) {cat}; 39 | \node[word] (0s2) at (\ox+0*\dx,\sy+0*\dy) {the}; 40 | \node[word] (0s3) at (\ox+0*\dx,\sy+1*\dy) {}; 41 | \node[sbox] (0sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 42 | 43 | \node[comp] (0c) at (\ox+0.5*\oxb,\sy-1.5*\dy) {composition}; 44 | 45 | \node[tracker] (0t) at (\ox+0*\dx,5*\dy) {tracking}; 46 | \node[softmax] (0sm) at (\ox+3.25*\dx,6.25*\dy) {transition}; 47 | \node[result] (0so) at (\ox+3.25*\dx,8.5*\dy) {\reduce}; 48 | 49 | \def\ox{1*\oxb} 50 | 51 | \node[focus1] (1bb) at (\ox+0*\dx,2.5*\dy) {}; 52 | \node[word] (1b3) at (\ox+0*\dx,\by-0.5*\dy) {}; 53 | \node[word] (1b2) at (\ox+0*\dx,\by+1.5*\dy) {down}; 54 | \node[word] (1b1) at (\ox+0*\dx,\by+2.5*\dy) {sat}; 55 | \node[bbox] (1bb) at (\ox+0*\dx,\by+1*\dy) {}; 56 | 57 | \node[focus2] (1sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 58 | \node[word] (1s1) at (\ox+0*\dx,\sy-1*\dy) {the cat}; 59 | \node[word] (1s2) at (\ox+0*\dx,\sy+0*\dy) {}; 60 | \node[word] (1s3) at (\ox+0*\dx,\sy+1*\dy) {}; 61 | \node[sbox] (1sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 62 | 63 | \node[compoff] (1c) at (\ox+0.5*\oxb,\sy-1.5*\dy) {composition}; 64 | 65 | \node[tracker] (1t) at (\ox+0*\dx,5*\dy) {tracking}; 66 | \node[softmax] (1sm) at (\ox+3.25*\dx,6.25*\dy) 
{transition}; 67 | \node[result] (1so) at (\ox+3.25*\dx,8.5*\dy) {\shift}; 68 | 69 | \def\ox{2*\oxb} 70 | 71 | \node[focus1] (2bb) at (\ox+0*\dx,2.5*\dy) {}; 72 | \node[word] (2b3) at (\ox+0*\dx,\by-0.5*\dy) {}; 73 | \node[word] (2b2) at (\ox+0*\dx,\by+1.5*\dy) {}; 74 | \node[word] (2b1) at (\ox+0*\dx,\by+2.5*\dy) {down}; 75 | \node[bbox] (2bb) at (\ox+0*\dx,\by+1*\dy) {}; 76 | 77 | \node[focus2] (2sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 78 | \node[word] (2s1) at (\ox+0*\dx,\sy-1*\dy) {sat}; 79 | \node[word] (2s2) at (\ox+0*\dx,\sy+0*\dy) {the cat}; 80 | \node[word] (2s3) at (\ox+0*\dx,\sy+1*\dy) {}; 81 | \node[sbox] (2sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 82 | 83 | \node[tracker] (2t) at (\ox+0*\dx,5*\dy) {tracking}; 84 | 85 | 86 | \pgfsetarrowsend{latex} 87 | \tikzstyle{fwd} = [draw=black, line width=1pt] 88 | \tikzstyle{gated} = [draw=black!50, line width=0.8pt] 89 | 90 | \draw [fwd] (0sb) -- (0t); 91 | \draw [fwd] (0bb) -- (0t); 92 | \draw [fwd] (0t) -- (0sm); 93 | \draw [fwd] (0sm) -- (0so); 94 | \draw [fwd] (0sb) -- (0c); 95 | 96 | \draw [fwd] (0t) -- (1t); 97 | \draw [fwd] (0t) to[out=50,in=-175] (0c); 98 | \draw [fwd] (0sb) -- (1sb); 99 | \draw [fwd] (0bb) -- (1bb); 100 | \draw [fwd] (0so) to[out=5,in=-140] (1sb); 101 | \draw [fwd] (0so) to[out=-5,in=170] (1bb); 102 | \draw [gated] (0bb) to[out=15,in=-125] (1sb); 103 | \draw [fwd] (0c) -- (1sb); 104 | 105 | \draw [fwd] (1sb) -- (1t); 106 | \draw [fwd] (1bb) -- (1t); 107 | \draw [fwd] (1t) -- (1sm); 108 | \draw [fwd] (1sm) -- (1so); 109 | \draw [fwd] (1sb) -- (1c); 110 | 111 | \draw [fwd] (1t) -- (2t); 112 | \draw [fwd] (1t) to[out=50,in=-175] (1c); 113 | \draw [fwd] (1sb) -- (2sb); 114 | \draw [fwd] (1bb) -- (2bb); 115 | \draw [fwd] (1so) to[out=5,in=-140] (2sb); 116 | \draw [fwd] (1so) to[out=-5,in=170] (2bb); 117 | \draw [fwd] (1bb) to[out=15,in=-125] (2sb); 118 | \draw [gated] (1c) -- (2sb); 119 | 120 | \draw [fwd] (2sb) -- (2t); 121 | \draw [fwd] (2bb) -- (2t); 122 | 123 | 124 | \end{tikzpicture}} 125 | 126 | \caption{The SPINN model unrolled for two transitions during the processing of the sentence \word{the cat sat down}. `Tracking', `transition', and `composition' are neural network layers. 
Gray arrows indicate connections which are blocked by a gating function.}\label{fig:model:1d} 127 | 128 | \end{subfigure}\\\\\\ 129 | \begin{subfigure}[t]{\textwidth} 130 | \centering 131 | \scalebox{0.5}{% 132 | \begin{tikzpicture} 133 | \def\dx{20pt} 134 | \def\dy{11pt} 135 | \def\sy{7*\dy} 136 | \def\oxb{5.5*\dx} 137 | \def\by{0pt} 138 | \def\oxs{0pt} 139 | 140 | \tikzstyle{label}=[text width=15mm,align=center,text height=2mm] 141 | \tikzstyle{word}=[text width=32mm,align=center,text height=2mm] 142 | \tikzstyle{tracker}=[fill=red!40,text width=15mm,align=center,text height=2mm] 143 | \tikzstyle{softmax}=[text width=40mm,align=center,text height=2mm] 144 | \tikzstyle{comp}=[fill=green!40,text width=20mm,align=center,text height=2mm] 145 | \tikzstyle{result}=[line width=1pt,draw=black,text width=15mm,align=center,text height=2mm] 146 | \tikzstyle{sbox}=[line width=1pt,draw=black,text width=30mm,align=center,text height=13.3mm] 147 | \tikzstyle{bbox}=[line width=1pt,draw=black,text width=30mm,align=center,text height=13.3mm] 148 | \tikzstyle{focus1}=[fill=yellow!40,text width=30mm,align=center,text height=2mm] 149 | \tikzstyle{focus2}=[fill=yellow!40,text width=30mm,align=center,text height=5.5mm] 150 | 151 | \def\ox{0*\oxb+\oxs} 152 | 153 | \node[label] (sl) at (\ox-0.58*\oxb+0*\dx,\by+0.5*\dy) {buffer}; 154 | 155 | \node[label] (sl) at (\ox-0.58*\oxb+0*\dx,\sy+0.5*\dy) {stack}; 156 | 157 | \node[label] (00l) at (\ox+0*\dx,\sy+3*\dy) {$t=0$}; 158 | 159 | \node[focus1] (00bb) at (\ox+0*\dx,2*\dy) {}; 160 | \node[word] (00b3) at (\ox+0*\dx,\by-1*\dy) {down}; 161 | \node[word] (00b2) at (\ox+0*\dx,\by+0*\dy) {sat}; 162 | \node[word] (00b1) at (\ox+0*\dx,\by+1*\dy) {cat}; 163 | \node[word] (00b1) at (\ox+0*\dx,\by+2*\dy) {the}; 164 | \node[bbox] (00bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 165 | 166 | \node[focus2] (00sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 167 | \node[word] (00s1) at (\ox+0*\dx,\sy-1*\dy) {}; 168 | \node[word] (00s2) at (\ox+0*\dx,\sy+0*\dy) {}; 169 | \node[word] (00s3) at (\ox+0*\dx,\sy+1*\dy) {}; 170 | \node[sbox] (00sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 171 | 172 | \node[result] (00so) at (\ox+0.7*\dx,4*\dy) {\shift}; 173 | 174 | \def\ox{1*\oxb+\oxs} 175 | 176 | \node[label] (0l) at (\ox+0*\dx,\sy+3*\dy) {$t=1$}; 177 | 178 | \node[focus1] (0bb) at (\ox+0*\dx,2*\dy) {}; 179 | \node[word] (00b3) at (\ox+0*\dx,\by-1*\dy) {}; 180 | \node[word] (00b2) at (\ox+0*\dx,\by+0*\dy) {down}; 181 | \node[word] (00b1) at (\ox+0*\dx,\by+1*\dy) {sat}; 182 | \node[word] (00b1) at (\ox+0*\dx,\by+2*\dy) {cat}; 183 | \node[bbox] (0bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 184 | 185 | \node[focus2] (0sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 186 | \node[word] (0s1) at (\ox+0*\dx,\sy-1*\dy) {the}; 187 | \node[word] (0s2) at (\ox+0*\dx,\sy+0*\dy) {}; 188 | \node[word] (0s3) at (\ox+0*\dx,\sy+1*\dy) {}; 189 | \node[sbox] (0sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 190 | 191 | \node[result] (0so) at (\ox+0.7*\dx,4*\dy) {\shift}; 192 | 193 | \def\ox{2*\oxb+\oxs} 194 | 195 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=2$}; 196 | 197 | \node[focus1] (1bb) at (\ox+0*\dx,2*\dy) {}; 198 | \node[word] (1b3) at (\ox+0*\dx,\by-1*\dy) {}; 199 | \node[word] (1b2) at (\ox+0*\dx,\by+1*\dy) {down}; 200 | \node[word] (1b1) at (\ox+0*\dx,\by+2*\dy) {sat}; 201 | \node[bbox] (1bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 202 | 203 | \node[focus2] (1sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 204 | \node[word] (1s1) at (\ox+0*\dx,\sy-1*\dy) {cat}; 205 | \node[word] (1s2) at (\ox+0*\dx,\sy+0*\dy) {the}; 206 | \node[word] (1s3) at (\ox+0*\dx,\sy+1*\dy) {}; 207 | 
\node[sbox] (1sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 208 | 209 | \node[result] (1so) at (\ox+0.7*\dx,4*\dy) {\reduce}; 210 | 211 | \def\ox{3*\oxb+\oxs} 212 | 213 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=3$}; 214 | 215 | \node[focus1] (2bb) at (\ox+0*\dx,2*\dy) {}; 216 | \node[word] (2b3) at (\ox+0*\dx,\by-1*\dy) {}; 217 | \node[word] (2b2) at (\ox+0*\dx,\by+1*\dy) {down}; 218 | \node[word] (2b1) at (\ox+0*\dx,\by+2*\dy) {sat}; 219 | \node[bbox] (2bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 220 | 221 | \node[focus2] (2sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 222 | \node[word] (2s1) at (\ox+0*\dx,\sy-1*\dy) {the cat}; 223 | \node[word] (2s2) at (\ox+0*\dx,\sy+0*\dy) {}; 224 | \node[word] (2s3) at (\ox+0*\dx,\sy+1*\dy) {}; 225 | \node[sbox] (2sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 226 | 227 | \node[result] (2so) at (\ox+0.7*\dx,4*\dy) {\shift}; 228 | 229 | \def\ox{4*\oxb+\oxs} 230 | 231 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=4$}; 232 | 233 | \node[focus1] (3bb) at (\ox+0*\dx,2*\dy) {}; 234 | \node[word] (3b3) at (\ox+0*\dx,\by-1*\dy) {}; 235 | \node[word] (3b2) at (\ox+0*\dx,\by+1*\dy) {}; 236 | \node[word] (3b1) at (\ox+0*\dx,\by+2*\dy) {down}; 237 | \node[bbox] (3bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 238 | 239 | \node[focus2] (3sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 240 | \node[word] (3s1) at (\ox+0*\dx,\sy-1*\dy) {sat}; 241 | \node[word] (3s2) at (\ox+0*\dx,\sy+0*\dy) {the cat}; 242 | \node[word] (3s3) at (\ox+0*\dx,\sy+1*\dy) {}; 243 | \node[sbox] (3sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 244 | 245 | \node[result] (3so) at (\ox+0.7*\dx,4*\dy) {\shift}; 246 | 247 | \def\ox{5*\oxb+\oxs} 248 | 249 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=5$}; 250 | 251 | \node[focus1] (4bb) at (\ox+0*\dx,2*\dy) {}; 252 | \node[word] (4b3) at (\ox+0*\dx,\by-1*\dy) {}; 253 | \node[word] (4b2) at (\ox+0*\dx,\by+0*\dy) {}; 254 | \node[word] (4b1) at (\ox+0*\dx,\by+1*\dy) {}; 255 | \node[bbox] (4bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 256 | 257 | \node[focus2] (4sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 258 | \node[word] (4s1) at (\ox+0*\dx,\sy-1*\dy) {down}; 259 | \node[word] (4s2) at (\ox+0*\dx,\sy+0*\dy) {sat}; 260 | \node[word] (4s3) at (\ox+0*\dx,\sy+1*\dy) {the cat}; 261 | \node[sbox] (4sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 262 | 263 | \node[result] (4so) at (\ox+0.7*\dx,4*\dy) {\reduce}; 264 | 265 | \def\ox{6*\oxb+\oxs} 266 | 267 | \node[label] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=6$}; 268 | 269 | \node[focus1] (5bb) at (\ox+0*\dx,2*\dy) {}; 270 | \node[word] (5b3) at (\ox+0*\dx,\by-1*\dy) {}; 271 | \node[word] (5b2) at (\ox+0*\dx,\by+0*\dy) {}; 272 | \node[word] (5b1) at (\ox+0*\dx,\by+1*\dy) {}; 273 | \node[bbox] (5bb) at (\ox+0*\dx,\by+0.5*\dy) {}; 274 | 275 | \node[focus2] (5sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 276 | \node[word] (5s1) at (\ox+0*\dx,\sy-1*\dy) {sat down}; 277 | \node[word] (5s2) at (\ox+0*\dx,\sy+0*\dy) {the cat}; 278 | \node[word] (5s3) at (\ox+0*\dx,\sy+1*\dy) {}; 279 | \node[sbox] (5sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 280 | 281 | \node[result] (5so) at (\ox+0.7*\dx,4*\dy) {\reduce}; 282 | 283 | \def\ox{7*\oxb+\oxs} 284 | 285 | \node[label,text width=25mm] (1l) at (\ox+0*\dx,\sy+3*\dy) {$t=7=T$}; 286 | 287 | \node[focus2] (6sb) at (\ox+0*\dx,\sy-0.5*\dy) {}; 288 | \node[word] (6s1) at (\ox+0*\dx,\sy-1*\dy) {(the cat) (sat down)}; 289 | \node[word] (6s2) at (\ox+0*\dx,\sy+0*\dy) {}; 290 | \node[word] (6s3) at (\ox+0*\dx,\sy+1*\dy) {}; 291 | \node[sbox] (6sb) at (\ox+0*\dx,\sy+0.5*\dy) {}; 292 | 293 | \node[softmax] (6sm) at (\ox+0.5*\dx,2*\dy) {output to model for semantic task}; 294 | 295 | \pgfsetarrowsend{latex} 
296 | \tikzstyle{fwd} = [draw=black, line width=1pt] 297 | 298 | \draw [fwd] (00sb) -- (00so); 299 | \draw [fwd] (00bb) -- (00so); 300 | 301 | \draw [fwd] (00sb) -- (0sb); 302 | \draw [fwd] (00bb) -- (0bb); 303 | \draw [fwd] (00so) -- (0sb); 304 | \draw [fwd] (00so) -- (0bb); 305 | \draw [fwd] (00bb) -- (0sb); 306 | 307 | \draw [fwd] (0sb) -- (0so); 308 | \draw [fwd] (0bb) -- (0so); 309 | 310 | \draw [fwd] (0sb) -- (1sb); 311 | \draw [fwd] (0bb) -- (1bb); 312 | \draw [fwd] (0so) -- (1sb); 313 | \draw [fwd] (0so) -- (1bb); 314 | \draw [fwd] (0bb) -- (1sb); 315 | 316 | \draw [fwd] (1sb) -- (1so); 317 | \draw [fwd] (1bb) -- (1so); 318 | 319 | \draw [fwd] (1sb) -- (2sb); 320 | \draw [fwd] (1bb) -- (2bb); 321 | \draw [fwd] (1so) -- (2sb); 322 | \draw [fwd] (1so) -- (2bb); 323 | \draw [fwd] (1bb) -- (2sb); 324 | 325 | \draw [fwd] (2sb) -- (2so); 326 | \draw [fwd] (2bb) -- (2so); 327 | 328 | \draw [fwd] (2sb) -- (3sb); 329 | \draw [fwd] (2bb) -- (3bb); 330 | \draw [fwd] (2so) -- (3sb); 331 | \draw [fwd] (2so) -- (3bb); 332 | \draw [fwd] (2bb) -- (3sb); 333 | 334 | \draw [fwd] (3sb) -- (3so); 335 | \draw [fwd] (3bb) -- (3so); 336 | 337 | \draw [fwd] (3sb) -- (4sb); 338 | \draw [fwd] (3bb) -- (4bb); 339 | \draw [fwd] (3so) -- (4sb); 340 | \draw [fwd] (3so) -- (4bb); 341 | \draw [fwd] (3bb) -- (4sb); 342 | 343 | \draw [fwd] (4sb) -- (4so); 344 | \draw [fwd] (4bb) -- (4so); 345 | 346 | \draw [fwd] (4sb) -- (5sb); 347 | \draw [fwd] (4bb) -- (5bb); 348 | \draw [fwd] (4so) -- (5sb); 349 | \draw [fwd] (4so) -- (5bb); 350 | \draw [fwd] (4bb) -- (5sb); 351 | 352 | \draw [fwd] (5sb) -- (5so); 353 | \draw [fwd] (5bb) -- (5so); 354 | 355 | \draw [fwd] (5sb) -- (6sb); 356 | \draw [fwd] (5so) -- (6sb); 357 | 358 | \draw [fwd] (6sb) -- (6sm); 359 | \end{tikzpicture}} 360 | 361 | \caption{The fully unrolled SPINN for \word{the cat sat down}, with neural network layers omitted for clarity.}\label{fig:model:1b} 362 | \end{subfigure} 363 | \caption{\label{fig:m1-views}Two views of the Stack-augmented Parser-Interpreter Neural Network (SPINN).} 364 | \end{figure*} 365 | -------------------------------------------------------------------------------- /writing/titlepage.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanfordnlp/spinn/f2f95f585328ec5cd7da071753f5e0582e7600a0/writing/titlepage.pdf -------------------------------------------------------------------------------- /writing/titlepage.tex: -------------------------------------------------------------------------------- 1 | % !TEX TS-program = pdflatex 2 | % !TEX encoding = UTF-8 Unicode 3 | 4 | % This is a simple template for a LaTeX document using the "article" class. 5 | % See "book", "report", "letter" for other types of document. 6 | 7 | \documentclass[11pt]{article} % use larger type; default would be 10pt 8 | 9 | \usepackage[utf8]{inputenc} % set input encoding (not needed with XeLaTeX) 10 | 11 | %%% Examples of Article customizations 12 | % These packages are optional, depending whether you want the features they provide. 13 | % See the LaTeX Companion or other references for full information. 14 | 15 | %%% PAGE DIMENSIONS 16 | \usepackage{geometry} % to change the page dimensions 17 | \geometry{a4paper} % or letterpaper (US) or a5paper or.... 
18 | % \geometry{margin=2in} % for example, change the margins to 2 inches all round 19 | % \geometry{landscape} % set up the page for landscape 20 | % read geometry.pdf for detailed page layout information 21 | 22 | \usepackage{graphicx} % support the \includegraphics command and options 23 | 24 | % \usepackage[parfill]{parskip} % Activate to begin paragraphs with an empty line rather than an indent 25 | 26 | %%% PACKAGES 27 | \usepackage{booktabs} % for much better looking tables 28 | \usepackage{array} % for better arrays (eg matrices) in maths 29 | \usepackage{paralist} % very flexible & customisable lists (eg. enumerate/itemize, etc.) 30 | \usepackage{verbatim} % adds environment for commenting out blocks of text & for better verbatim 31 | \usepackage{subfig} % make it possible to include more than one captioned figure/table in a single float 32 | % These packages are all incorporated in the memoir class to one degree or another... 33 | 34 | %%% HEADERS & FOOTERS 35 | \usepackage{fancyhdr} % This should be set AFTER setting up the page geometry 36 | \pagestyle{fancy} % options: empty , plain , fancy 37 | \renewcommand{\headrulewidth}{0pt} % customise the layout... 38 | \lhead{}\chead{}\rhead{} 39 | \lfoot{}\cfoot{\thepage}\rfoot{} 40 | 41 | %%% SECTION TITLE APPEARANCE 42 | \usepackage{sectsty} 43 | \allsectionsfont{\sffamily\mdseries\upshape} % (See the fntguide.pdf for font help) 44 | % (This matches ConTeXt defaults) 45 | 46 | %%% ToC (table of contents) APPEARANCE 47 | \usepackage[nottoc,notlof,notlot]{tocbibind} % Put the bibliography in the ToC 48 | \usepackage[titles,subfigure]{tocloft} % Alter the style of the Table of Contents 49 | \renewcommand{\cftsecfont}{\rmfamily\mdseries\upshape} 50 | \renewcommand{\cftsecpagefont}{\rmfamily\mdseries\upshape} % No bold! 51 | 52 | %%% END Article customizations 53 | 54 | %%% The "real" document content comes below... 55 | 56 | \title{Title page for:\\A Fast Unified Model for Parsing and Sentence Understanding} 57 | %\date{} % Activate to display a given date or no date (if empty), 58 | % otherwise the current date is printed 59 | 60 | \begin{document} 61 | \maketitle 62 | A version of this paper will be posted as a non-archival manuscript to arXiv.org under the same title. 
63 | \end{document} 64 | -------------------------------------------------------------------------------- /writing/tree_attn_fig.tex: -------------------------------------------------------------------------------- 1 | %!TEX root = hard_stack_paper/paper.tex 2 | 3 | \begin{figure}[t] 4 | \centering 5 | \scalebox{0.9}{ 6 | \begin{tikzpicture} 7 | \tikzstyle{word}=[fill=yellow!40,text height=2mm] 8 | \tikzstyle{nonleaf}=[fill=yellow!40,text height=2mm] 9 | 10 | \begin{scope}[shift={(0in,0in)}, frontier/.style={distance from root=60pt}] 11 | 12 | \Tree [.\node[nonleaf](1thecatsatdown){the cat sat down}; [.\node[nonleaf](1thecat){the cat}; \node[word](1the){the}; \node[word](1cat){cat}; ] [.\node[nonleaf](1satdown){sat down}; \node[word](1sat){sat}; \node[word](1down){down}; ] ] 13 | 14 | \end{scope} 15 | 16 | \begin{scope}[shift={(2in,0in)}, frontier/.style={distance from root=60pt}] 17 | 18 | \Tree [.\node[nonleaf](2thekittenmeowed){the kitten meowed}; [.\node[nonleaf](2thekitten){the kitten}; \node[word](2the){the}; \node[word](2kitten){kitten}; ] \node[word](2meowed){meowed}; ] 19 | 20 | \end{scope} 21 | 22 | \usetikzlibrary{arrows,scopes} 23 | 24 | \tikzstyle{attn} = [line width=1pt,draw=black!80,opacity=0.25,line cap=round] 25 | \tikzstyle{heavy} = [line width=3pt] 26 | \tikzstyle{light} = [line width=0.33pt] 27 | \tikzstyle{focus} = [draw=red,opacity=1.0] 28 | 29 | \draw [attn,heavy] (1the) to[in=155,out=25] (2the); 30 | \draw [attn,light] (1thecat) to[in=155,out=25] (2the); 31 | 32 | \draw [attn,heavy] (1cat) to[in=155,out=25] (2kitten); 33 | \draw [attn,light] (1thecat) to[in=155,out=25] (2kitten); 34 | \draw [attn,light] (1thecatsatdown) to[in=155,out=25] (2kitten); 35 | 36 | \draw [attn,light] (1satdown) to[in=155,out=25] (2thekittenmeowed); 37 | \draw [attn,light] (1thecat) to[in=155,out=25] (2thekittenmeowed); 38 | \draw [attn,heavy] (1thecatsatdown) to[in=155,out=25] (2thekittenmeowed); 39 | 40 | \draw [attn] (1satdown) to[in=155,out=25] (2meowed); 41 | \draw [attn] (1sat) to[in=155,out=25] (2meowed); 42 | \draw [attn,light] (1down) to[in=155,out=25] (2meowed); 43 | 44 | \draw [attn,focus] (1cat) to[in=155,out=25,] (2thekitten); 45 | \draw [attn,heavy,focus] (1thecat) to[in=155,out=25] (2thekitten); 46 | \draw [attn,light,focus] (1the) to[in=155,out=25] (2thekitten); 47 | \draw [attn,light,focus] (1thecatsatdown) to[in=155,out=25] (2thekitten); 48 | 49 | \end{tikzpicture}} 50 | 51 | \caption{Soft alignments between the nodes of two trees, with alignments to ``the kitten'' highlighted in red.}\label{fig:tree_attn} 52 | \end{figure} 53 | --------------------------------------------------------------------------------