├── .gitignore ├── LICENSE ├── README.md ├── configs └── squad │ └── r-net │ ├── hkust+bert.jsonnet │ ├── hkust+elmo.jsonnet │ ├── hkust.jsonnet │ └── original.jsonnet ├── img ├── em.png └── f1.png ├── main.py ├── modules ├── __init__.py ├── dropout.py ├── gate.py ├── pair_encoder │ ├── __init__.py │ ├── attentions.py │ ├── cells.py │ └── pair_encoder.py ├── pointer_network │ ├── __init__.py │ └── pointer_network.py ├── rnn │ ├── __init__.py │ └── stacked_rnn.py └── utils.py ├── qa ├── __init__.py └── squad │ ├── __init__.py │ ├── dataset.py │ └── rnet.py ├── requirements.txt └── tests ├── __init__.py ├── fixtures └── rnet │ ├── experiment.jsonnet │ └── experiment_dynamic.jsonnet └── models ├── __init__.py ├── r-net_dynamic_test.py └── r-net_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | .idea 3 | .vscode 4 | log 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | checkpoint/ 8 | *.py[cod] 9 | *$py.class 10 | *.ckpt 11 | # C extensions 12 | *.so 13 | .pytest_cache 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT LICENSE 2 | 3 | Copyright (C) <2017> 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files 6 | (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, 7 | publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, 8 | subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
11 | 12 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 13 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 15 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | An unofficial implementation of R-net in [PyTorch](https://github.com/pytorch/pytorch) and [AllenNLP](https://github.com/allenai/allennlp). 2 | 3 | [Natural Language Computing Group, MSRA: R-NET: Machine Reading Comprehension with Self-matching Networks](https://www.microsoft.com/en-us/research/wp-content/uploads/2017/05/r-net.pdf) 4 | 5 | Actually, I didn't reproduce the model of this paper exactly because some details are not very clear to me and the dynamic attention in self-matching requires too much memory. 6 | Instead, I implemented the variant of R-Net according to [HKUST-KnowComp/R-Net](https://github.com/HKUST-KnowComp/R-Net) (in Tensorflow). 7 | 8 | The biggest difference between the original R-net and HKUST R-net is: 9 | * The original R-net performs attention **at each RNN step**, which means that the hidden states are involved in the attention calculation. I call it dynamic attention. 10 | * In HKUST R-Net, attentions (in pair encoder and self-matching encoder) are calculated **before** running the RNN. I call it static attention. 11 | 12 | Some details in [HKUST-KnowComp/R-Net](https://github.com/HKUST-KnowComp/R-Net) that improve performance: 13 | * Question and Passage share the same GRU sentence encoder instead of using two separate GRU encoders. 
14 | * The sentence encoder has three layers, but its output is the concatenation of the three layers instead of the output of the top layer. 15 | * The GRUs in the pair encoder and the self-matching encoder have only one layer instead of three layers. 16 | * Variational dropout is applied to (1) the inputs of the RNNs and (2) the inputs of the attentions. 17 | 18 | Furthermore, this repo added ELMo word embeddings, which further improved the model's performance. 19 | 20 | ### Dependency 21 | 22 | * Python == 3.6 23 | * [AllenNLP](https://github.com/allenai/allennlp) == 0.7.2 24 | * PyTorch == 1.0 25 | 26 | 27 | 28 | ### Usage 29 | 30 | ``` 31 | git clone https://github.com/matthew-z/R-net.git 32 | cd R-net 33 | python main.py train configs/squad/r-net/hkust.jsonnet # HKUST R-Net 34 | ``` 35 | Note that the default batch size (128) may be too large for 11GB GPUs. If you hit an out-of-memory error, try 64 by adding the following argument: 36 | ```-o '{"iterator.batch_size": 64}'``` 37 | 38 | ### Configuration 39 | 40 | The models and hyperparameters are declared in `configs/` 41 | 42 | * the HKUST R-Net: `configs/squad/r-net/hkust.jsonnet` (79.4 F1) 43 | * +ELMo: `configs/squad/r-net/hkust+elmo.jsonnet` (82.2 F1) 44 | * the original R-Net: `configs/squad/r-net/original.jsonnet` (currently not working) 45 | 46 | 47 | ### Performance 48 | 49 | This implementation of HKUST R-Net can obtain 79.4 F1 and 70.5 EM on the validation set. 50 | + ELMo: 82.2 F1 and 74.4 EM. 51 | 52 | The visualization of R-Net + ELMo Training: 53 | Red: training score, Green: validation score 54 | 55 | 56 | 57 | 58 | Note that the validation score is higher than the training score because each validation question has three acceptable answers, which makes validation easier than training. 
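Why several acceptable answers raise the score can be sketched in a few lines. The snippet below is only an illustration modeled on the usual SQuAD-style scoring, where a prediction is compared with every reference answer and the best score is kept; the helper names `normalize`, `f1` and `best_over_references` are hypothetical and are not part of this repo.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize(text):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def f1(prediction, reference):
    """Token-level F1 between one prediction and one reference answer."""
    pred_tokens = normalize(prediction).split()
    ref_tokens = normalize(reference).split()
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def best_over_references(metric, prediction, references):
    """Score against every reference and keep the best one."""
    return max(metric(prediction, ref) for ref in references)


# One reference (as in training): only a partial match is possible here.
single = best_over_references(f1, "Denver Broncos", ["Broncos"])
# Several references (as in validation): any of them may match fully.
multi = best_over_references(f1, "Denver Broncos", ["Broncos", "the Denver Broncos"])
```

Because the maximum is taken over the references, adding a reference can only keep the score the same or raise it, which is consistent with the validation curve sitting above the training curve.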
59 | 60 | ### Future Work 61 | * Add BERT: A preliminary implementation is in `configs/squad/r-net/hkust+bert.jsonnet` 62 | * Add ensemble training 63 | * Add FP16 training 64 | 65 | ### Acknowledgement 66 | 67 | Thanks to [HKUST-KnowComp/R-Net](https://github.com/HKUST-KnowComp/R-Net) for sharing their Tensorflow implementation of R-net. This repo is based on their work. 68 | -------------------------------------------------------------------------------- /configs/squad/r-net/hkust+bert.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_size = 768; 2 | local hidden_size = 75; 3 | local attention_size = 75; 4 | local num_layers = 3; 5 | local dropout = 0.3; 6 | local bidirectional = true; 7 | 8 | { 9 | dataset_reader: { 10 | type: 'squad_truncated', 11 | truncate_train_only: false, 12 | max_passage_len: 300, 13 | token_indexers: { 14 | bert: { 15 | type: 'bert-pretrained', 16 | pretrained_model: 'bert-base-cased', 17 | do_lowercase: false, 18 | use_starting_offsets: false, 19 | }, 20 | }, 21 | }, 22 | 23 | train_data_path: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-train-v1.1.json', 24 | validation_data_path: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-dev-v1.1.json', 25 | 26 | model: { 27 | type: 'r_net', 28 | share_encoder: true, 29 | text_field_embedder: { 30 | allow_unmatched_keys: true, 31 | embedder_to_indexer_map: { 32 | bert: ['bert', 'bert-offsets'], 33 | }, 34 | token_embedders: { 35 | bert: { 36 | type: 'bert-pretrained', 37 | pretrained_model: 'bert-base-cased', 38 | }, 39 | }, 40 | }, 41 | 42 | 43 | question_encoder: { 44 | type: 'concat_rnn', 45 | input_size: embedding_size, 46 | hidden_size: hidden_size, 47 | num_layers: num_layers, 48 | bidirectional: bidirectional, 49 | dropout: dropout, 50 | }, 51 | 52 | passage_encoder: { 53 | type: 'concat_rnn', 54 | input_size: embedding_size, 55 | hidden_size: hidden_size, 56 | num_layers: num_layers, 57 | 
bidirectional: bidirectional, 58 | dropout: dropout, 59 | }, 60 | 61 | pair_encoder: { 62 | type: 'static_pair_encoder', 63 | memory_size: hidden_size * 2 * num_layers, 64 | input_size: hidden_size * 2 * num_layers, 65 | hidden_size: hidden_size, 66 | attention_size: attention_size, 67 | bidirectional: bidirectional, 68 | dropout: dropout, 69 | batch_first: true, 70 | 71 | }, 72 | 73 | self_encoder: { 74 | type: 'static_self_encoder', 75 | memory_size: hidden_size * 2, 76 | input_size: hidden_size * 2, 77 | hidden_size: hidden_size, 78 | attention_size: attention_size, 79 | bidirectional: bidirectional, 80 | dropout: dropout, 81 | batch_first: true, 82 | 83 | }, 84 | 85 | output_layer: { 86 | type: 'pointer_network', 87 | question_size: hidden_size * 2 * num_layers, 88 | passage_size: hidden_size * 2, 89 | attention_size: attention_size, 90 | dropout: dropout, 91 | batch_first: true, 92 | }, 93 | }, 94 | 95 | iterator: { 96 | type: 'basic', 97 | // sorting_keys: [['passage', 'num_tokens'], ['question', 'num_tokens']], 98 | batch_size: 64, 99 | // padding_noise: 0.2, 100 | // biggest_batch_first: true 101 | }, 102 | 103 | trainer: { 104 | num_epochs: 120, 105 | num_serialized_models_to_keep: 5, 106 | grad_norm: 5.0, 107 | patience: 10, 108 | validation_metric: '+f1', 109 | cuda_device: [0], 110 | learning_rate_scheduler: { 111 | type: 'reduce_on_plateau', 112 | factor: 0.5, 113 | mode: 'max', 114 | patience: 3, 115 | }, 116 | optimizer: { 117 | type: 'adadelta', 118 | lr: 1, 119 | rho: 0.95, 120 | }, 121 | }, 122 | } 123 | -------------------------------------------------------------------------------- /configs/squad/r-net/hkust+elmo.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_size = 1524; 2 | local hidden_size = 75; 3 | local attention_size = 75; 4 | local num_layers = 3; 5 | local dropout = 0.3; 6 | local bidirectional = true; 7 | 8 | { 9 | dataset_reader: { 10 | type: 'squad_truncated', 11 | 
token_indexers: { 12 | tokens: { 13 | type: 'single_id', 14 | lowercase_tokens: true, 15 | }, 16 | token_characters: { 17 | type: 'characters', 18 | character_tokenizer: { 19 | byte_encoding: 'utf-8', 20 | start_tokens: [259], 21 | end_tokens: [260], 22 | }, 23 | }, 24 | elmo: { 25 | type: 'elmo_characters', 26 | }, 27 | }, 28 | }, 29 | 30 | train_data_path: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-train-v1.1.json', 31 | validation_data_path: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-dev-v1.1.json', 32 | model: { 33 | type: 'r_net', 34 | share_encoder: true, 35 | text_field_embedder: { 36 | token_embedders: { 37 | tokens: { 38 | type: 'embedding', 39 | pretrained_file: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.840B.300d.txt.gz', 40 | embedding_dim: 300, 41 | trainable: false, 42 | }, 43 | elmo: { 44 | type: 'elmo_token_embedder', 45 | options_file: 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json', 46 | weight_file: 'https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5', 47 | do_layer_norm: false, 48 | dropout: 0.0, 49 | }, 50 | token_characters: { 51 | type: 'character_encoding', 52 | embedding: { 53 | num_embeddings: 262, 54 | embedding_dim: 8, 55 | }, 56 | encoder: { 57 | type: 'gru', 58 | input_size: 8, 59 | hidden_size: 100, 60 | bidirectional: true, 61 | dropout: dropout, 62 | }, 63 | }, 64 | }, 65 | }, 66 | 67 | question_encoder: { 68 | type: 'concat_rnn', 69 | input_size: embedding_size, 70 | hidden_size: hidden_size, 71 | num_layers: num_layers, 72 | bidirectional: bidirectional, 73 | dropout: dropout, 74 | }, 75 | 76 | passage_encoder: { 77 | type: 'concat_rnn', 78 | input_size: embedding_size, 79 | hidden_size: hidden_size, 80 | num_layers: num_layers, 81 | bidirectional: bidirectional, 82 | dropout: dropout, 83 | }, 84 | 
85 | pair_encoder: { 86 | type: 'static_pair_encoder', 87 | memory_size: hidden_size * 2 * num_layers, 88 | input_size: hidden_size * 2 * num_layers, 89 | hidden_size: hidden_size, 90 | attention_size: attention_size, 91 | bidirectional: bidirectional, 92 | dropout: dropout, 93 | batch_first: true, 94 | 95 | }, 96 | 97 | self_encoder: { 98 | type: 'static_self_encoder', 99 | memory_size: hidden_size * 2, 100 | input_size: hidden_size * 2, 101 | hidden_size: hidden_size, 102 | attention_size: attention_size, 103 | bidirectional: bidirectional, 104 | dropout: dropout, 105 | batch_first: true, 106 | 107 | }, 108 | 109 | output_layer: { 110 | type: 'pointer_network', 111 | question_size: hidden_size * 2 * num_layers, 112 | passage_size: hidden_size * 2, 113 | attention_size: attention_size, 114 | dropout: dropout, 115 | batch_first: true, 116 | }, 117 | }, 118 | 119 | iterator: { 120 | type: 'basic', 121 | batch_size: 64, 122 | }, 123 | 124 | trainer: { 125 | num_epochs: 120, 126 | num_serialized_models_to_keep: 5, 127 | grad_norm: 5.0, 128 | patience: 10, 129 | validation_metric: '+f1', 130 | cuda_device: [0], 131 | learning_rate_scheduler: { 132 | type: 'reduce_on_plateau', 133 | factor: 0.5, 134 | mode: 'max', 135 | patience: 3, 136 | }, 137 | optimizer: { 138 | type: 'adadelta', 139 | lr: 0.5, 140 | rho: 0.95, 141 | }, 142 | }, 143 | } 144 | -------------------------------------------------------------------------------- /configs/squad/r-net/hkust.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_size = 500; 2 | local hidden_size = 75; 3 | local attention_size = 75; 4 | local num_layers = 3; 5 | local dropout = 0.3; 6 | local bidirectional = true; 7 | 8 | { 9 | dataset_reader: { 10 | type: 'squad', 11 | token_indexers: { 12 | tokens: { 13 | type: 'single_id', 14 | lowercase_tokens: true, 15 | }, 16 | token_characters: { 17 | type: 'characters', 18 | character_tokenizer: { 19 | byte_encoding: 'utf-8', 20 | 
start_tokens: [259], 21 | end_tokens: [260], 22 | }, 23 | // min_padding_length: 5, 24 | }, 25 | }, 26 | }, 27 | 28 | train_data_path: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-train-v1.1.json', 29 | validation_data_path: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-dev-v1.1.json', 30 | model: { 31 | type: 'r_net', 32 | share_encoder: true, 33 | text_field_embedder: { 34 | token_embedders: { 35 | tokens: { 36 | type: 'embedding', 37 | pretrained_file: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.840B.300d.txt.gz', 38 | embedding_dim: 300, 39 | trainable: false, 40 | }, 41 | token_characters: { 42 | type: 'character_encoding', 43 | embedding: { 44 | num_embeddings: 262, 45 | embedding_dim: 8, 46 | }, 47 | encoder: { 48 | type: 'gru', 49 | input_size: 8, 50 | hidden_size: 100, 51 | bidirectional: true, 52 | dropout: dropout, 53 | }, 54 | }, 55 | }, 56 | }, 57 | 58 | question_encoder: { 59 | type: 'concat_rnn', 60 | input_size: embedding_size, 61 | hidden_size: hidden_size, 62 | num_layers: num_layers, 63 | bidirectional: bidirectional, 64 | dropout: dropout, 65 | }, 66 | 67 | passage_encoder: { 68 | type: 'concat_rnn', 69 | input_size: embedding_size, 70 | hidden_size: hidden_size, 71 | num_layers: num_layers, 72 | bidirectional: bidirectional, 73 | dropout: dropout, 74 | }, 75 | 76 | pair_encoder: { 77 | type: 'static_pair_encoder', 78 | memory_size: hidden_size * 2 * num_layers, 79 | input_size: hidden_size * 2 * num_layers, 80 | hidden_size: hidden_size, 81 | attention_size: attention_size, 82 | bidirectional: bidirectional, 83 | dropout: dropout, 84 | batch_first: true, 85 | 86 | }, 87 | 88 | self_encoder: { 89 | type: 'static_self_encoder', 90 | memory_size: hidden_size * 2, 91 | input_size: hidden_size * 2, 92 | hidden_size: hidden_size, 93 | attention_size: attention_size, 94 | bidirectional: bidirectional, 95 | dropout: dropout, 96 | batch_first: true, 97 | 98 | }, 99 | 100 | output_layer: { 
101 | type: 'pointer_network', 102 | question_size: hidden_size * 2 * num_layers, 103 | passage_size: hidden_size * 2, 104 | attention_size: attention_size, 105 | dropout: dropout, 106 | batch_first: true, 107 | }, 108 | }, 109 | 110 | iterator: { 111 | type: 'basic', 112 | // sorting_keys: [['passage', 'num_tokens'], ['question', 'num_tokens']], 113 | batch_size: 128, 114 | // padding_noise: 0.2, 115 | }, 116 | 117 | trainer: { 118 | num_epochs: 120, 119 | num_serialized_models_to_keep: 5, 120 | grad_norm: 5.0, 121 | patience: 10, 122 | validation_metric: '+f1', 123 | cuda_device: [0], 124 | learning_rate_scheduler: { 125 | type: 'reduce_on_plateau', 126 | factor: 0.5, 127 | mode: 'max', 128 | patience: 3, 129 | }, 130 | optimizer: { 131 | type: 'adadelta', 132 | lr: 0.5, 133 | rho: 0.95, 134 | }, 135 | }, 136 | } 137 | -------------------------------------------------------------------------------- /configs/squad/r-net/original.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_size = 500; 2 | local hidden_size = 75; 3 | local attention_size = 75; 4 | local num_layers = 3; 5 | local dropout = 0.3; 6 | local bidirectional = true; 7 | 8 | { 9 | dataset_reader: { 10 | type: 'squad', 11 | token_indexers: { 12 | tokens: { 13 | type: 'single_id', 14 | lowercase_tokens: true, 15 | }, 16 | token_characters: { 17 | type: 'characters', 18 | character_tokenizer: { 19 | byte_encoding: 'utf-8', 20 | start_tokens: [259], 21 | end_tokens: [260], 22 | }, 23 | // min_padding_length: 5, 24 | }, 25 | }, 26 | }, 27 | 28 | train_data_path: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-train-v1.1.json', 29 | validation_data_path: 'https://s3-us-west-2.amazonaws.com/allennlp/datasets/squad/squad-dev-v1.1.json', 30 | model: { 31 | type: 'r_net', 32 | share_encoder: false, 33 | text_field_embedder: { 34 | token_embedders: { 35 | tokens: { 36 | type: 'embedding', 37 | pretrained_file: 
'https://s3-us-west-2.amazonaws.com/allennlp/datasets/glove/glove.840B.300d.txt.gz', 38 | embedding_dim: 300, 39 | trainable: false, 40 | }, 41 | token_characters: { 42 | type: 'character_encoding', 43 | embedding: { 44 | num_embeddings: 262, 45 | embedding_dim: 8, 46 | }, 47 | encoder: { 48 | type: 'gru', 49 | input_size: 8, 50 | hidden_size: 100, 51 | bidirectional: true, 52 | dropout: dropout, 53 | }, 54 | }, 55 | }, 56 | }, 57 | 58 | question_encoder: { 59 | type: 'gru', 60 | input_size: embedding_size, 61 | hidden_size: hidden_size, 62 | num_layers: num_layers, 63 | bidirectional: bidirectional, 64 | dropout: dropout, 65 | }, 66 | 67 | passage_encoder: { 68 | type: 'gru', 69 | input_size: embedding_size, 70 | hidden_size: hidden_size, 71 | num_layers: num_layers, 72 | bidirectional: bidirectional, 73 | dropout: dropout, 74 | }, 75 | 76 | pair_encoder: { 77 | type: 'dynamic_pair_encoder', 78 | memory_size: hidden_size * 2 * num_layers, 79 | input_size: hidden_size * 2 * num_layers, 80 | hidden_size: hidden_size, 81 | attention_size: attention_size, 82 | bidirectional: bidirectional, 83 | dropout: dropout, 84 | batch_first: true, 85 | }, 86 | 87 | self_encoder: { 88 | type: 'dynamic_self_encoder', 89 | memory_size: hidden_size * 2, 90 | input_size: hidden_size * 2, 91 | hidden_size: hidden_size, 92 | attention_size: attention_size, 93 | bidirectional: bidirectional, 94 | dropout: dropout, 95 | batch_first: true, 96 | 97 | }, 98 | 99 | output_layer: { 100 | type: 'pointer_network', 101 | question_size: hidden_size * 2 * num_layers, 102 | passage_size: hidden_size * 2, 103 | attention_size: attention_size, 104 | dropout: dropout, 105 | batch_first: true, 106 | }, 107 | }, 108 | 109 | iterator: { 110 | type: 'basic', 111 | batch_size: 128, 112 | }, 113 | 114 | trainer: { 115 | num_epochs: 120, 116 | num_serialized_models_to_keep: 5, 117 | grad_norm: 5.0, 118 | patience: 10, 119 | validation_metric: '+f1', 120 | cuda_device: [0], 121 | learning_rate_scheduler: { 122 
| type: 'reduce_on_plateau', 123 | factor: 0.5, 124 | mode: 'max', 125 | patience: 3, 126 | }, 127 | optimizer: { 128 | type: 'adadelta', 129 | lr: 0.5, 130 | rho: 0.95, 131 | }, 132 | }, 133 | } 134 | -------------------------------------------------------------------------------- /img/em.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthew-z/R-net/670f3796cd01595903ef781ec7eb0e55c020e77b/img/em.png -------------------------------------------------------------------------------- /img/f1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthew-z/R-net/670f3796cd01595903ef781ec7eb0e55c020e77b/img/f1.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | from pathlib import Path 4 | 5 | from allennlp.commands import main, Subcommand 6 | from allennlp.commands.train import train_model 7 | from allennlp.common import Params 8 | from allennlp.common.util import import_submodules 9 | from allennlp.models import Model 10 | 11 | 12 | class MyTrain(Subcommand): 13 | def add_subparser(self, name: str, parser: argparse._SubParsersAction) -> argparse.ArgumentParser: 14 | # pylint: disable=protected-access 15 | description = '''Train the specified model on the specified dataset.''' 16 | subparser = parser.add_parser(name, description=description, help='Train a model') 17 | 18 | subparser.add_argument('param_path', 19 | type=str, 20 | help='path to parameter file describing the model to be trained') 21 | 22 | subparser.add_argument('-s', '--serialization-dir', 23 | required=False, 24 | default="", 25 | type=str, 26 | help='directory in which to save the model and its logs') 27 | 28 | subparser.add_argument('-r', '--recover', 29 | action='store_true', 30 | default=False, 31 
| help='recover training from the state in serialization_dir') 32 | 33 | subparser.add_argument('-f', '--force', 34 | action='store_true', 35 | required=False, 36 | help='overwrite the output directory if it exists') 37 | 38 | subparser.add_argument('-o', '--overrides', 39 | type=str, 40 | default="", 41 | help='a JSON structure used to override the experiment configuration') 42 | 43 | 44 | subparser.add_argument('-e', '--ext-vars', 45 | type=str, 46 | default=None, 47 | help='Used to provide ext variable to jsonnet') 48 | 49 | subparser.add_argument('--fp16', 50 | action='store_true', 51 | required=False, 52 | help='use fp 16 training') 53 | 54 | subparser.add_argument('--file-friendly-logging', 55 | action='store_true', 56 | default=False, 57 | help='outputs tqdm status on separate lines and slows tqdm refresh rate') 58 | 59 | subparser.set_defaults(func=train_model_from_args) 60 | 61 | return subparser 62 | 63 | 64 | def train_model_from_args(args: argparse.Namespace): 65 | """ 66 | Just converts from an ``argparse.Namespace`` object to string paths. 67 | """ 68 | 69 | start_time = datetime.datetime.now().strftime('%b-%d_%H-%M') 70 | 71 | if args.serialization_dir: 72 | serialization_dir = args.serialization_dir 73 | else: 74 | path = Path(args.param_path.replace("configs/", "results/")).resolve() 75 | serialization_dir = path.with_name(path.stem) / start_time 76 | 77 | 78 | train_model_from_file(args.param_path, 79 | serialization_dir, 80 | args.overrides, 81 | args.file_friendly_logging, 82 | args.recover, 83 | args.force, 84 | args.ext_vars) 85 | 86 | def train_model_from_file(parameter_filename: str, 87 | serialization_dir: str, 88 | overrides: str = "", 89 | file_friendly_logging: bool = False, 90 | recover: bool = False, 91 | force: bool = False, 92 | ext_vars=None) -> Model: 93 | """ 94 | A wrapper around :func:`train_model` which loads the params from a file. 
95 | 96 | Parameters 97 | ---------- 98 | parameter_filename : ``str`` 99 | A json parameter file specifying an AllenNLP experiment. 100 | serialization_dir : ``str`` 101 | The directory in which to save results and logs. We just pass this along to 102 | :func:`train_model`. 103 | overrides : ``str`` 104 | A JSON string that we will use to override values in the input parameter file. 105 | file_friendly_logging : ``bool``, optional (default=False) 106 | If ``True``, we make our output more friendly to saved model files. We just pass this 107 | along to :func:`train_model`. 108 | recover : ``bool``, optional (default=False) 109 | If ``True``, we will try to recover a training run from an existing serialization 110 | directory. This is only intended for use when something actually crashed during the middle 111 | of a run. For continuing training a model on new data, see the ``fine-tune`` command. 112 | """ 113 | # Load the experiment config from a file and pass it to ``train_model``. 114 | params = Params.from_file(parameter_filename, overrides, ext_vars=ext_vars) 115 | return train_model(params, serialization_dir, file_friendly_logging, recover, force) 116 | 117 | 118 | if __name__ == "__main__": 119 | import_submodules("qa") 120 | import_submodules("modules") 121 | main(prog="ReadingZoo", subcommand_overrides={"train": MyTrain()}) -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthew-z/R-net/670f3796cd01595903ef781ec7eb0e55c020e77b/modules/__init__.py -------------------------------------------------------------------------------- /modules/dropout.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class RNNDropout(nn.Module): 5 | def __init__(self, p, batch_first=False): 6 | super().__init__() 7 | self.dropout = nn.Dropout(p) 8 | 
self.batch_first = batch_first 9 | 10 | def forward(self, inputs): 11 | 12 | if not self.training: 13 | return inputs 14 | if self.batch_first: 15 | mask = inputs.new_ones(inputs.size(0), 1, inputs.size(2), requires_grad=False) 16 | else: 17 | mask = inputs.new_ones(1, inputs.size(1), inputs.size(2), requires_grad=False) 18 | return self.dropout(mask) * inputs 19 | -------------------------------------------------------------------------------- /modules/gate.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from modules.dropout import RNNDropout 4 | 5 | 6 | class Gate(nn.Module): 7 | def __init__(self, input_size, dropout=0.3): 8 | super().__init__() 9 | self.gate = nn.Sequential( 10 | RNNDropout(dropout), 11 | nn.Linear(input_size, input_size, bias=False), 12 | nn.Sigmoid() 13 | ) 14 | 15 | def forward(self, inputs): 16 | return inputs * self.gate(inputs) 17 | -------------------------------------------------------------------------------- /modules/pair_encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthew-z/R-net/670f3796cd01595903ef781ec7eb0e55c020e77b/modules/pair_encoder/__init__.py -------------------------------------------------------------------------------- /modules/pair_encoder/attentions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from modules.dropout import RNNDropout 5 | from allennlp.nn.util import masked_softmax 6 | from torch import Tensor 7 | 8 | def unroll_attention_cell(cell, inputs, memory, memory_mask, batch_first=False, initial_state=None, backward=False): 9 | if batch_first: 10 | inputs = inputs.transpose(0, 1) 11 | output = [] 12 | state = initial_state 13 | steps = range(inputs.size(0)) 14 | if backward: 15 | steps = range(inputs.size(0)-1, -1, -1) 16 | for t in steps: 17 | state = 
cell(inputs[t], memory=memory, memory_mask=memory_mask, state=state) 18 | output.append(state) 19 | if backward: 20 | output = output[::-1] 21 | output = torch.stack(output, dim=1 if batch_first else 0) 22 | return output, state 23 | 24 | 25 | def bidirectional_unroll_attention_cell(cell_fw, cell_bw, inputs, memory, memory_mask, batch_first=False, 26 | initial_state=None): 27 | if initial_state is None: 28 | initial_state = [None, None] 29 | 30 | output_fw, state_fw = unroll_attention_cell( 31 | cell_fw, inputs, memory, memory_mask, 32 | batch_first=batch_first, 33 | initial_state=initial_state[0], backward=False) 34 | 35 | output_bw, state_bw = unroll_attention_cell( 36 | cell_bw, inputs, memory, memory_mask, 37 | batch_first=batch_first, 38 | initial_state=initial_state[1], backward=True) 39 | 40 | return torch.cat([output_fw, output_bw], dim=-1), (state_fw, state_bw) 41 | 42 | 43 | # class StaticAddAttention(nn.Module): 44 | # def __init__(self, memory_size, input_size, attention_size, dropout=0.2, batch_first=False): 45 | # super().__init__() 46 | # self.batch_first = batch_first 47 | 48 | # self.attention_w = nn.Sequential( 49 | # nn.Linear(input_size + memory_size, attention_size, bias=False), 50 | # nn.Dropout(dropout), 51 | # nn.Tanh(), 52 | # nn.Linear(attention_size, 1, bias=False), 53 | # nn.Dropout(dropout), 54 | # ) 55 | 56 | # def forward(self, inputs: Tensor, memory: Tensor, memory_mask: Tensor): 57 | # if not self.batch_first: 58 | # raise NotImplementedError 59 | 60 | # T = inputs.size(0) 61 | # memory_mask = memory_mask.unsqueeze(0) 62 | # memory_key = memory.unsqueeze(0).expand(T, -1, -1, -1) 63 | # input_key = inputs.unsqueeze(1).expand(-1, T, -1, -1) 64 | # attention_logits = self.attention_w(torch.cat([input_key, memory_key], -1)).squeeze(-1) 65 | # logits, score = softmax_mask(attention_logits, memory_mask, dim=1) 66 | # context = torch.sum(score.unsqueeze(-1) * inputs.unsqueeze(0), dim=1) 67 | # new_input = torch.cat([context, inputs], 
dim=-1) 68 | # return new_input 69 | 70 | 71 | class StaticDotAttention(nn.Module): 72 | def __init__(self, memory_size, input_size, attention_size, batch_first=False, dropout=0.2): 73 | 74 | super().__init__() 75 | 76 | self.input_linear = nn.Sequential( 77 | RNNDropout(dropout, batch_first=True), 78 | nn.Linear(input_size, attention_size, bias=False), 79 | nn.ReLU() 80 | ) 81 | 82 | self.memory_linear = nn.Sequential( 83 | RNNDropout(dropout, batch_first=True), 84 | nn.Linear(memory_size, attention_size, bias=False), 85 | nn.ReLU() 86 | ) 87 | self.attention_size = attention_size 88 | self.batch_first = batch_first 89 | 90 | def forward(self, inputs: Tensor, memory: Tensor, memory_mask: Tensor): 91 | if not self.batch_first: 92 | inputs = inputs.transpose(0, 1) 93 | memory = memory.transpose(0, 1) 94 | memory_mask = memory_mask.transpose(0, 1) 95 | 96 | input_ = self.input_linear(inputs) 97 | memory_ = self.memory_linear(memory) 98 | 99 | logits = torch.bmm(input_, memory_.transpose(2, 1)) / (self.attention_size ** 0.5) 100 | 101 | memory_mask = memory_mask.unsqueeze(1).expand(-1, inputs.size(1), -1) 102 | score = masked_softmax(logits, memory_mask, dim=-1) 103 | 104 | context = torch.bmm(score, memory) 105 | new_input = torch.cat([context, inputs], dim=-1) 106 | 107 | if not self.batch_first: 108 | return new_input.transpose(0, 1) 109 | return new_input 110 | 111 | -------------------------------------------------------------------------------- /modules/pair_encoder/cells.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import Tensor 4 | from allennlp.nn.util import masked_softmax 5 | 6 | class _PairEncodeCell(nn.Module): 7 | def __init__(self, input_size, cell, attention_size, memory_size=None, use_state_in_attention=True, batch_first=False): 8 | super().__init__() 9 | if memory_size is None: 10 | memory_size = input_size 11 | 12 | self.cell = cell 13 | self.use_state = 
use_state_in_attention 14 | 15 | attention_input_size = input_size + memory_size 16 | if use_state_in_attention: 17 | attention_input_size += cell.hidden_size 18 | 19 | self.attention_w = nn.Sequential( 20 | nn.Dropout(), 21 | nn.Linear(attention_input_size, attention_size, bias=False), 22 | nn.Tanh(), 23 | nn.Linear(attention_size, 1, bias=False)) 24 | 25 | self.batch_first = batch_first 26 | 27 | def forward(self, inputs: Tensor, memory: Tensor = None, memory_mask: Tensor = None, state: Tensor = None): 28 | """ 29 | :param inputs: B x H 30 | :param memory: T x B x H if not batch_first 31 | :param memory_mask: T x B if not batch_first 32 | :param state: B x H 33 | :return: 34 | """ 35 | if self.batch_first: 36 | memory = memory.transpose(0, 1) 37 | memory_mask = memory_mask.transpose(0, 1) 38 | 39 | assert inputs.size(0) == memory.size(1) == memory_mask.size( 40 | 1), "inputs batch size does not match memory batch size" 41 | 42 | memory_time_length = memory.size(0) 43 | 44 | if state is None: 45 | state = inputs.new_zeros(inputs.size(0), self.cell.hidden_size, requires_grad=False) 46 | 47 | if self.use_state: 48 | hx = state 49 | if isinstance(state, tuple): 50 | hx = state[0] 51 | attention_input = torch.cat([inputs, hx], dim=-1) 52 | attention_input = attention_input.unsqueeze(0).expand(memory_time_length, -1, -1) # T B H 53 | else: 54 | attention_input = inputs.unsqueeze(0).expand(memory_time_length, -1, -1) 55 | 56 | attention_logits = self.attention_w(torch.cat([attention_input, memory], dim=-1)).squeeze(-1) 57 | 58 | attention_scores = masked_softmax(attention_logits, memory_mask, dim=0) 59 | 60 | attention_vector = torch.sum(attention_scores.unsqueeze(-1) * memory, dim=0) 61 | 62 | new_input = torch.cat([inputs, attention_vector], dim=-1) 63 | 64 | return self.cell(new_input, state) 65 | 66 | 67 | class PairEncodeCell(_PairEncodeCell): 68 | def __init__(self, input_size, cell, attention_size, 69 | memory_size=None, batch_first=False): 70 | super().__init__( 
71 | input_size, cell, attention_size, 72 | memory_size=memory_size, use_state_in_attention=True, batch_first=batch_first) 73 | 74 | 75 | class SelfMatchCell(_PairEncodeCell): 76 | def __init__(self, input_size, cell, attention_size, 77 | memory_size=None, batch_first=False): 78 | super().__init__( 79 | input_size, cell, attention_size, 80 | memory_size=memory_size, use_state_in_attention=False, batch_first=batch_first) 81 | -------------------------------------------------------------------------------- /modules/pair_encoder/pair_encoder.py: -------------------------------------------------------------------------------- 1 | from overrides import overrides 2 | from torch import nn 3 | 4 | from modules.dropout import RNNDropout 5 | from modules.gate import Gate 6 | from modules.pair_encoder.attentions import bidirectional_unroll_attention_cell, unroll_attention_cell, \ 7 | StaticDotAttention 8 | from modules.pair_encoder.cells import PairEncodeCell, SelfMatchCell 9 | from allennlp.common import Registrable 10 | 11 | from modules.utils import get_rnn 12 | from allennlp.modules import Seq2SeqEncoder 13 | 14 | 15 | class AttentionEncoder(nn.Module, Registrable): 16 | def forward(self, inputs, inputs_mask, memory, memory_mask): 17 | raise NotImplementedError 18 | 19 | 20 | class DynamicAttentionEncoder(AttentionEncoder): 21 | def __init__(self, memory_size, input_size, hidden_size, attention_size, 22 | bidirectional, dropout, cell_factory=PairEncodeCell, batch_first=False): 23 | super().__init__() 24 | self.batch_first = batch_first 25 | cell_fn = lambda: cell_factory(input_size, cell=nn.GRUCell(input_size + memory_size, hidden_size), 26 | attention_size=attention_size, memory_size=memory_size, 27 | batch_first=batch_first) 28 | 29 | num_directions = 2 if bidirectional else 1 30 | self.bidirectional = bidirectional 31 | 32 | self.cells = nn.ModuleList([cell_fn() for _ in range(num_directions)]) 33 | self.dropout = RNNDropout(dropout) 34 | self.gate = 
Gate(input_size=hidden_size*2 if bidirectional else hidden_size, dropout=dropout) 35 | 36 | @overrides 37 | def forward(self, inputs, inputs_mask, memory, memory_mask): 38 | if self.bidirectional: 39 | cell_fw, cell_bw = self.cells 40 | output, _ = bidirectional_unroll_attention_cell( 41 | cell_fw, cell_bw, inputs, memory, memory_mask, 42 | batch_first=self.batch_first) 43 | else: 44 | cell, = self.cells 45 | output, _ = unroll_attention_cell( 46 | cell, inputs, memory, memory_mask, batch_first=self.batch_first) 47 | 48 | return self.gate(output) 49 | 50 | 51 | @AttentionEncoder.register("dynamic_pair_encoder") 52 | class DynamicPairEncoder(DynamicAttentionEncoder): 53 | def __init__(self, memory_size, input_size, hidden_size, attention_size, 54 | bidirectional, dropout, batch_first=False): 55 | super().__init__(memory_size, input_size, hidden_size, attention_size, 56 | bidirectional, dropout, PairEncodeCell, batch_first=batch_first) 57 | 58 | 59 | @AttentionEncoder.register("dynamic_self_encoder") 60 | class DynamicSelfEncoder(DynamicAttentionEncoder): 61 | def __init__(self, memory_size, input_size, hidden_size, attention_size, 62 | bidirectional, dropout, batch_first=False): 63 | super().__init__(memory_size, input_size, hidden_size, attention_size, 64 | bidirectional, dropout, SelfMatchCell, batch_first=batch_first) 65 | 66 | 67 | 68 | 69 | @AttentionEncoder.register("pass_through") 70 | class PassThrough(AttentionEncoder): 71 | def __init__(self, memory_size, input_size, hidden_size): 72 | super().__init__() 73 | 74 | 75 | @overrides 76 | def forward(self, inputs, inputs_mask, memory, memory_mask): 77 | return inputs 78 | 79 | 80 | 81 | @AttentionEncoder.register("static_pair_encoder") 82 | class StaticPairEncoder(AttentionEncoder): 83 | def __init__(self, memory_size, input_size, hidden_size, attention_size, bidirectional, dropout, 84 | attention_factory=StaticDotAttention, rnn_type="GRU", batch_first=True): 85 | super().__init__() 86 | rnn_fn = 
Seq2SeqEncoder.by_name(rnn_type.lower()) 87 | 88 | self.attention = attention_factory(memory_size, input_size, attention_size, 89 | dropout=dropout, batch_first=batch_first) 90 | 91 | self.gate = nn.Sequential( 92 | Gate(input_size + memory_size, dropout=dropout), 93 | RNNDropout(dropout, batch_first=batch_first) 94 | ) 95 | 96 | self.encoder = rnn_fn(input_size=memory_size + input_size, hidden_size=hidden_size, 97 | bidirectional=bidirectional, batch_first=batch_first) 98 | 99 | @overrides 100 | def forward(self, inputs, inputs_mask, memory, memory_mask): 101 | """ 102 | Memory: T B H 103 | input: T B H 104 | """ 105 | new_inputs = self.gate(self.attention(inputs, memory, memory_mask)) 106 | outputs = self.encoder(new_inputs, inputs_mask) 107 | return outputs 108 | 109 | 110 | @AttentionEncoder.register("static_self_encoder") 111 | class StaticSelfMatchEncoder(StaticPairEncoder): 112 | pass 113 | -------------------------------------------------------------------------------- /modules/pointer_network/__init__.py: -------------------------------------------------------------------------------- 1 | from .pointer_network import PointerNetwork -------------------------------------------------------------------------------- /modules/pointer_network/pointer_network.py: -------------------------------------------------------------------------------- 1 | from allennlp.common import Registrable 2 | from torch import nn 3 | import torch 4 | from allennlp.nn.util import masked_softmax 5 | from modules.dropout import RNNDropout 6 | from overrides import overrides 7 | 8 | class QAOutputLayer(nn.Module, Registrable): 9 | pass 10 | 11 | @QAOutputLayer.register("pointer_network") 12 | class PointerNetwork(QAOutputLayer): 13 | def __init__(self, question_size, passage_size, attention_size=75, 14 | cell_type=nn.GRUCell, dropout=0, batch_first=False): 15 | super().__init__() 16 | 17 | self.batch_first = batch_first 18 | 19 | # TODO: what is V_q? 
(section 3.4) 20 | v_q_size = question_size 21 | 22 | self.cell = cell_type(passage_size, question_size) 23 | self.dropout = dropout 24 | 25 | self.passage_linear = nn.Sequential( 26 | RNNDropout(dropout), 27 | nn.Linear(question_size + passage_size, attention_size, bias=False), 28 | nn.Tanh(), 29 | nn.Linear(attention_size, 1, bias=False), 30 | ) 31 | 32 | self.question_linear = nn.Sequential( 33 | RNNDropout(dropout), 34 | nn.Linear(question_size + v_q_size, attention_size, bias=False), 35 | nn.Tanh(), 36 | nn.Linear(attention_size, 1, bias=False), 37 | ) 38 | 39 | self.V_q = nn.Parameter(torch.randn(1, 1, v_q_size), requires_grad=True) 40 | 41 | @overrides 42 | def forward(self, question, question_mask, passage, passage_mask): 43 | """ 44 | :param question: T B H 45 | :param question_mask: T B 46 | :param passage: 47 | :param passage_mask: 48 | :return: B x 2 49 | """ 50 | 51 | if self.batch_first: 52 | question = question.transpose(0, 1) 53 | question_mask = question_mask.transpose(0, 1) 54 | passage = passage.transpose(0, 1) 55 | passage_mask = passage_mask.transpose(0, 1) 56 | 57 | state = self._question_pooling(question, question_mask) 58 | cell_input, ans_start_logits = self._passage_attention(passage, passage_mask, state) 59 | state = self.cell(cell_input, hx=state) 60 | _, ans_end_logits = self._passage_attention(passage, passage_mask, state) 61 | 62 | 63 | return ans_start_logits.transpose(0, 1), ans_end_logits.transpose(0, 1) 64 | 65 | def _question_pooling(self, question, question_mask): 66 | V_q = self.V_q.expand(question.size(0), question.size(1), -1) 67 | logits = self.question_linear(torch.cat([question, V_q], dim=-1)).squeeze(-1) 68 | score = masked_softmax(logits, question_mask, dim=0) 69 | state = torch.sum(score.unsqueeze(-1) * question, dim=0) 70 | return state 71 | 72 | def _passage_attention(self, passage, passage_mask, state): 73 | state_expand = state.unsqueeze(0).expand(passage.size(0), -1, -1) 74 | logits = 
self.passage_linear(torch.cat([passage, state_expand], dim=-1)).squeeze(-1) 75 | score = masked_softmax(logits, passage_mask, dim=0) 76 | cell_input = torch.sum(score.unsqueeze(-1) * passage, dim=0) 77 | return cell_input, logits 78 | -------------------------------------------------------------------------------- /modules/rnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthew-z/R-net/670f3796cd01595903ef781ec7eb0e55c020e77b/modules/rnn/__init__.py -------------------------------------------------------------------------------- /modules/rnn/stacked_rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from allennlp.modules import Seq2SeqEncoder 3 | 4 | from modules.dropout import RNNDropout 5 | 6 | 7 | @Seq2SeqEncoder.register("concat_rnn") 8 | class ConcatRNN(Seq2SeqEncoder): 9 | def __init__(self, input_size, hidden_size, num_layers, bidirectional, rnn_type="GRU", dropout=0., stateful=False, batch_first=True): 10 | super().__init__(stateful=stateful) 11 | rnn_cls = Seq2SeqEncoder.by_name(rnn_type.lower()) 12 | self.rnn_list = torch.nn.ModuleList( 13 | [rnn_cls(input_size=input_size, hidden_size=hidden_size, bidirectional=bidirectional, dropout=dropout)]) 14 | 15 | for _ in range(num_layers - 1): 16 | self.rnn_list.append( 17 | rnn_cls(input_size=hidden_size * 2, hidden_size=hidden_size, bidirectional=bidirectional, 18 | dropout=dropout)) 19 | 20 | self.dropout = RNNDropout(dropout, batch_first=batch_first) 21 | 22 | def forward(self, inputs, mask, hidden=None): 23 | outputs_list = [] 24 | 25 | for layer in self.rnn_list: 26 | outputs = layer(self.dropout(inputs), mask, hidden) 27 | outputs_list.append(outputs) 28 | inputs = outputs 29 | 30 | return torch.cat(outputs_list, -1) -------------------------------------------------------------------------------- /modules/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def reverse_padded_sequence_fast(inputs, lengths, batch_first=False): 5 | """Reverses sequences according to their lengths. 6 | Inputs should have size ``T x B x *`` if ``batch_first`` is False, or 7 | ``B x T x *`` if True. T is the length of the longest sequence (or larger), 8 | B is the batch size, and * is any number of dimensions (including 0). 9 | Arguments: 10 | inputs (Variable): padded batch of variable length sequences. 11 | lengths (list[int]): list of sequence lengths 12 | batch_first (bool, optional): if True, inputs should be B x T x *. 13 | Returns: 14 | A Variable with the same size as inputs, but with each sequence 15 | reversed according to its length. 16 | """ 17 | if not batch_first: 18 | inputs = inputs.transpose(0, 1) 19 | if inputs.size(0) != len(lengths): 20 | raise ValueError('inputs incompatible with lengths.') 21 | reversed_indices = [list(range(inputs.size(1))) for _ in range(inputs.size(0))] 22 | for i, length in enumerate(lengths): 23 | if length > 0: 24 | reversed_indices[i][:length] = reversed_indices[i][length-1::-1] 25 | reversed_indices = torch.LongTensor(reversed_indices).unsqueeze(2).expand_as(inputs) 26 | if inputs.is_cuda: 27 | reversed_indices = reversed_indices.cuda() 28 | reversed_inputs = torch.gather(inputs, 1, reversed_indices) 29 | if not batch_first: 30 | reversed_inputs = reversed_inputs.transpose(0, 1) 31 | return reversed_inputs 32 | 33 | 34 | def get_rnn(rnn_type): 35 | return getattr(torch.nn, rnn_type) -------------------------------------------------------------------------------- /qa/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthew-z/R-net/670f3796cd01595903ef781ec7eb0e55c020e77b/qa/__init__.py -------------------------------------------------------------------------------- /qa/squad/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthew-z/R-net/670f3796cd01595903ef781ec7eb0e55c020e77b/qa/squad/__init__.py -------------------------------------------------------------------------------- /qa/squad/dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Tuple, Dict, List 3 | 4 | from allennlp.common.file_utils import cached_path 5 | from allennlp.data import DatasetReader, Tokenizer, TokenIndexer, Token, Instance 6 | from allennlp.data.dataset_readers.reading_comprehension.util import make_reading_comprehension_instance, \ 7 | char_span_to_token_span 8 | from allennlp.data.token_indexers import SingleIdTokenIndexer 9 | from allennlp.data.tokenizers import WordTokenizer 10 | from overrides import overrides 11 | import json 12 | 13 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 14 | 15 | 16 | 17 | @DatasetReader.register("squad_truncated") 18 | class SquadReader(DatasetReader): 19 | """ 20 | Modified from allennlp.data.dataset_readers.reading_comprehension.squad 21 | 22 | Reads a JSON-formatted SQuAD file and returns a ``Dataset`` where the ``Instances`` have four 23 | fields: ``question``, a ``TextField``, ``passage``, another ``TextField``, and ``span_start`` 24 | and ``span_end``, both ``IndexFields`` into the ``passage`` ``TextField``. We also add a 25 | ``MetadataField`` that stores the instance's ID, the original passage text, gold answer strings, 26 | and token offsets into the original passage, accessible as ``metadata['id']``, 27 | ``metadata['original_passage']``, ``metadata['answer_texts']`` and 28 | ``metadata['token_offsets']``. This is so that we can more easily use the official SQuAD 29 | evaluation script to get metrics. 
30 | 31 | Parameters 32 | ---------- 33 | tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``) 34 | We use this ``Tokenizer`` for both the question and the passage. See :class:`Tokenizer`. 35 | Default is ``WordTokenizer()``. 36 | token_indexers : ``Dict[str, TokenIndexer]``, optional 37 | We similarly use this for both the question and the passage. See :class:`TokenIndexer`. 38 | Default is ``{"tokens": SingleIdTokenIndexer()}``. 39 | """ 40 | def __init__(self, 41 | tokenizer: Tokenizer = None, 42 | token_indexers: Dict[str, TokenIndexer] = None, 43 | lazy: bool = False, 44 | max_passage_len=400, 45 | truncate_train_only=True) -> None: 46 | super().__init__(lazy) 47 | self._tokenizer = tokenizer or WordTokenizer() 48 | self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} 49 | self.max_passage_len = max_passage_len 50 | self.truncate_train_only = truncate_train_only 51 | 52 | @overrides 53 | def _read(self, file_path: str): 54 | if "train" in file_path: 55 | is_train = True 56 | else: 57 | is_train = False 58 | 59 | # if `file_path` is a URL, redirect to the cache 60 | file_path = cached_path(file_path) 61 | 62 | logger.info("Reading file at %s", file_path) 63 | with open(file_path) as dataset_file: 64 | dataset_json = json.load(dataset_file) 65 | dataset = dataset_json['data'] 66 | logger.info("Reading the dataset") 67 | for article in dataset: 68 | for paragraph_json in article['paragraphs']: 69 | paragraph = paragraph_json["context"] 70 | tokenized_paragraph = self._tokenizer.tokenize(paragraph) 71 | 72 | if len(tokenized_paragraph) > self.max_passage_len: 73 | if is_train or not self.truncate_train_only: 74 | continue 75 | 76 | for question_answer in paragraph_json['qas']: 77 | question_text = question_answer["question"].strip().replace("\n", "") 78 | answer_texts = [answer['text'] for answer in question_answer['answers']] 79 | span_starts = [answer['answer_start'] for answer in question_answer['answers']] 80 | span_ends = 
[start + len(answer) for start, answer in zip(span_starts, answer_texts)] 81 | instance = self.text_to_instance(question_text, 82 | paragraph, 83 | zip(span_starts, span_ends), 84 | answer_texts, 85 | tokenized_paragraph) 86 | yield instance 87 | 88 | @overrides 89 | def text_to_instance(self, # type: ignore 90 | question_text: str, 91 | passage_text: str, 92 | char_spans: List[Tuple[int, int]] = None, 93 | answer_texts: List[str] = None, 94 | passage_tokens: List[Token] = None) -> Instance: 95 | # pylint: disable=arguments-differ 96 | if not passage_tokens: 97 | passage_tokens = self._tokenizer.tokenize(passage_text) 98 | char_spans = char_spans or [] 99 | 100 | # We need to convert character indices in `passage_text` to token indices in 101 | # `passage_tokens`, as the latter is what we'll actually use for supervision. 102 | token_spans: List[Tuple[int, int]] = [] 103 | passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens] 104 | for char_span_start, char_span_end in char_spans: 105 | (span_start, span_end), error = char_span_to_token_span(passage_offsets, 106 | (char_span_start, char_span_end)) 107 | if error: 108 | logger.debug("Passage: %s", passage_text) 109 | logger.debug("Passage tokens: %s", passage_tokens) 110 | logger.debug("Question text: %s", question_text) 111 | logger.debug("Answer span: (%d, %d)", char_span_start, char_span_end) 112 | logger.debug("Token span: (%d, %d)", span_start, span_end) 113 | logger.debug("Tokens in answer: %s", passage_tokens[span_start:span_end + 1]) 114 | logger.debug("Answer: %s", passage_text[char_span_start:char_span_end]) 115 | token_spans.append((span_start, span_end)) 116 | 117 | return make_reading_comprehension_instance(self._tokenizer.tokenize(question_text), 118 | passage_tokens, 119 | self._token_indexers, 120 | passage_text, 121 | token_spans, 122 | answer_texts) 123 | -------------------------------------------------------------------------------- /qa/squad/rnet.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, List, Any 2 | 3 | import torch 4 | from allennlp.data import Vocabulary 5 | from allennlp.models import BidirectionalAttentionFlow 6 | from allennlp.models.model import Model 7 | from allennlp.modules import TextFieldEmbedder, Seq2SeqEncoder 8 | from allennlp.nn import InitializerApplicator, RegularizerApplicator 9 | from allennlp.nn import util 10 | from allennlp.nn.util import get_text_field_mask 11 | from allennlp.training.metrics import CategoricalAccuracy, BooleanAccuracy, SquadEmAndF1 12 | from torch.nn.functional import nll_loss 13 | 14 | from modules.pair_encoder.pair_encoder import AttentionEncoder 15 | from modules.pointer_network.pointer_network import QAOutputLayer 16 | 17 | @Model.register("r_net") 18 | class RNet(Model): 19 | def __init__(self, 20 | vocab: Vocabulary, 21 | text_field_embedder: TextFieldEmbedder, 22 | question_encoder: Seq2SeqEncoder, 23 | passage_encoder: Seq2SeqEncoder, 24 | pair_encoder: AttentionEncoder, 25 | self_encoder: AttentionEncoder, 26 | output_layer: QAOutputLayer, 27 | initializer: InitializerApplicator = InitializerApplicator(), 28 | regularizer: Optional[RegularizerApplicator] = None, 29 | share_encoder: bool = False): 30 | 31 | super().__init__(vocab, regularizer) 32 | self.text_field_embedder = text_field_embedder 33 | self.question_encoder = question_encoder 34 | self.passage_encoder = passage_encoder 35 | self.pair_encoder = pair_encoder 36 | self.self_encoder = self_encoder 37 | self.output_layer = output_layer 38 | 39 | self._span_start_accuracy = CategoricalAccuracy() 40 | self._span_end_accuracy = CategoricalAccuracy() 41 | self._span_accuracy = BooleanAccuracy() 42 | self._squad_metrics = SquadEmAndF1() 43 | self.share_encoder = share_encoder 44 | self.loss = torch.nn.CrossEntropyLoss() 45 | initializer(self) 46 | 47 | def forward(self, 48 | question: Dict[str, torch.LongTensor], 49 | passage: 
Dict[str, torch.LongTensor], 50 | span_start: torch.IntTensor = None, 51 | span_end: torch.IntTensor = None, 52 | metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: 53 | 54 | question_embeded = self.text_field_embedder(question) 55 | passage_embeded = self.text_field_embedder(passage) 56 | 57 | question_mask = get_text_field_mask(question).byte() 58 | passage_mask = get_text_field_mask(passage).byte() 59 | 60 | question_encoded = self.question_encoder( 61 | question_embeded, question_mask) 62 | 63 | if self.share_encoder: 64 | passage_encoded = self.question_encoder(passage_embeded, passage_mask) 65 | else: 66 | passage_encoded = self.passage_encoder(passage_embeded, passage_mask) 67 | 68 | passage_encoded = self.pair_encoder( 69 | passage_encoded, passage_mask, question_encoded, question_mask) 70 | passage_encoded = self.self_encoder( 71 | passage_encoded, passage_mask, passage_encoded, passage_mask) 72 | 73 | span_start_logits, span_end_logits = self.output_layer( 74 | question_encoded, question_mask, passage_encoded, passage_mask) 75 | 76 | # Calculate the loss and make predictions 77 | # The following code is copied from allennlp.models.BidirectionalAttentionFlow 78 | span_start_probs = util.masked_softmax(span_start_logits, passage_mask) 79 | span_end_probs = util.masked_softmax(span_end_logits, passage_mask) 80 | 81 | span_start_logits = util.replace_masked_values( 82 | span_start_logits, passage_mask, -1e7) 83 | span_end_logits = util.replace_masked_values( 84 | span_end_logits, passage_mask, -1e7) 85 | 86 | best_span = self.get_best_span(span_start_logits, span_end_logits) 87 | 88 | output_dict = { 89 | "span_start_logits": span_start_logits, 90 | "span_start_probs": span_start_probs, 91 | "span_end_logits": span_end_logits, 92 | "span_end_probs": span_end_probs, 93 | "best_span": best_span, 94 | } 95 | 96 | 97 | if span_start is not None: 98 | loss = nll_loss(util.masked_log_softmax( 99 | span_start_logits, passage_mask), span_start.squeeze(-1)) 
100 | self._span_start_accuracy( 101 | span_start_logits, span_start.squeeze(-1)) 102 | loss += nll_loss(util.masked_log_softmax(span_end_logits, 103 | passage_mask), span_end.squeeze(-1)) 104 | self._span_end_accuracy(span_end_logits, span_end.squeeze(-1)) 105 | self._span_accuracy(best_span, torch.stack( 106 | [span_start, span_end], -1)) 107 | output_dict["loss"] = loss 108 | 109 | # Compute the EM and F1 on SQuAD and add the tokenized input to the output. 110 | if metadata is not None: 111 | output_dict['best_span_str'] = [] 112 | question_tokens = [] 113 | passage_tokens = [] 114 | batch_size = question_embeded.size(0) 115 | for i in range(batch_size): 116 | question_tokens.append(metadata[i]['question_tokens']) 117 | passage_tokens.append(metadata[i]['passage_tokens']) 118 | passage_str = metadata[i]['original_passage'] 119 | offsets = metadata[i]['token_offsets'] 120 | predicted_span = tuple(best_span[i].detach().cpu().numpy()) 121 | start_offset = offsets[predicted_span[0]][0] 122 | end_offset = offsets[predicted_span[1]][1] 123 | best_span_string = passage_str[start_offset:end_offset] 124 | output_dict['best_span_str'].append(best_span_string) 125 | answer_texts = metadata[i].get('answer_texts', []) 126 | if answer_texts: 127 | self._squad_metrics(best_span_string, answer_texts) 128 | output_dict['question_tokens'] = question_tokens 129 | output_dict['passage_tokens'] = passage_tokens 130 | return output_dict 131 | 132 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 133 | exact_match, f1_score = self._squad_metrics.get_metric(reset) 134 | return {'start_acc': self._span_start_accuracy.get_metric(reset), 135 | 'end_acc': self._span_end_accuracy.get_metric(reset), 136 | 'span_acc': self._span_accuracy.get_metric(reset), 137 | 'em': exact_match, 138 | 'f1': f1_score, 139 | } 140 | 141 | @staticmethod 142 | def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor: 143 | if span_start_logits.dim() != 2 
or span_end_logits.dim() != 2: 144 | raise ValueError("Input shapes must be (batch_size, passage_length)") 145 | batch_size, passage_length = span_start_logits.size() 146 | max_span_log_prob = [-1e20] * batch_size 147 | span_start_argmax = [0] * batch_size 148 | best_word_span = span_start_logits.new_zeros((batch_size, 2), dtype=torch.long) 149 | 150 | span_start_logits = span_start_logits.detach().cpu().numpy() 151 | span_end_logits = span_end_logits.detach().cpu().numpy() 152 | 153 | for b in range(batch_size): # pylint: disable=invalid-name 154 | for j in range(passage_length): 155 | val1 = span_start_logits[b, span_start_argmax[b]] 156 | if val1 < span_start_logits[b, j]: 157 | span_start_argmax[b] = j 158 | val1 = span_start_logits[b, j] 159 | 160 | val2 = span_end_logits[b, j] 161 | 162 | if val1 + val2 > max_span_log_prob[b]: 163 | best_word_span[b, 0] = span_start_argmax[b] 164 | best_word_span[b, 1] = j 165 | max_span_log_prob[b] = val1 + val2 166 | return best_word_span 167 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | allennlp>=0.7.2 2 | torch>=1.0.0 -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthew-z/R-net/670f3796cd01595903ef781ec7eb0e55c020e77b/tests/__init__.py -------------------------------------------------------------------------------- /tests/fixtures/rnet/experiment.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_size = 10; 2 | local hidden_size = 5; 3 | local attention_size = 5; 4 | local num_layers = 2; 5 | local dropout = 0.3; 6 | local bidirectional = true; 7 | local rnn_type = 'gru'; 8 | 9 | local classifier = { 10 | input_dim: hidden_size * 2, 11 | num_layers: 1, 12 | 
hidden_dims: [ 13 | 5, 14 | ], 15 | activations: [ 16 | 'linear', 17 | ], 18 | dropout: [ 19 | 0.0, 20 | ], 21 | }; 22 | 23 | { 24 | dataset_reader: { 25 | type: 'squad', 26 | token_indexers: { 27 | tokens: { 28 | type: 'single_id', 29 | lowercase_tokens: true, 30 | }, 31 | token_characters: { 32 | type: 'characters', 33 | character_tokenizer: { 34 | byte_encoding: 'utf-8', 35 | start_tokens: [259], 36 | end_tokens: [260], 37 | }, 38 | min_padding_length: 5, 39 | }, 40 | }, 41 | }, 42 | 43 | 44 | train_data_path: 'tests/fixtures/data/squad.json', 45 | validation_data_path: 'tests/fixtures/data/squad.json', 46 | model: { 47 | type: 'r_net', 48 | text_field_embedder: { 49 | token_embedders: { 50 | tokens: { 51 | type: 'embedding', 52 | embedding_dim: 2, 53 | trainable: true, 54 | }, 55 | token_characters: { 56 | type: 'character_encoding', 57 | embedding: { 58 | num_embeddings: 262, 59 | embedding_dim: 8, 60 | }, 61 | encoder: { 62 | type: 'cnn', 63 | embedding_dim: 8, 64 | num_filters: 8, 65 | ngram_filter_sizes: [5], 66 | }, 67 | dropout: 0.2, 68 | }, 69 | }, 70 | 71 | }, 72 | 73 | question_encoder: { 74 | type: 'concat_rnn', 75 | input_size: embedding_size, 76 | hidden_size: hidden_size, 77 | num_layers: num_layers, 78 | bidirectional: bidirectional, 79 | dropout: dropout, 80 | }, 81 | 82 | passage_encoder: { 83 | type: 'concat_rnn', 84 | input_size: embedding_size, 85 | hidden_size: hidden_size, 86 | num_layers: num_layers, 87 | bidirectional: bidirectional, 88 | dropout: dropout, 89 | }, 90 | 91 | pair_encoder: { 92 | type: 'dynamic_pair_encoder', 93 | memory_size: hidden_size * 2 * num_layers, 94 | input_size: hidden_size * 2 * num_layers, 95 | hidden_size: hidden_size, 96 | attention_size: attention_size, 97 | bidirectional: bidirectional, 98 | dropout: dropout, 99 | batch_first: true, 100 | }, 101 | 102 | self_encoder: { 103 | type: 'static_self_encoder', 104 | memory_size: hidden_size * 2, 105 | input_size: hidden_size * 2, 106 | hidden_size: hidden_size, 
107 | attention_size: attention_size, 108 | bidirectional: bidirectional, 109 | dropout: dropout, 110 | batch_first: true, 111 | }, 112 | 113 | output_layer: { 114 | type: 'pointer_network', 115 | question_size: hidden_size * 2 * num_layers, 116 | passage_size: hidden_size * 2, 117 | attention_size: attention_size, 118 | dropout: dropout, 119 | batch_first: true, 120 | }, 121 | }, 122 | 123 | 124 | iterator: { 125 | type: 'bucket', 126 | sorting_keys: [['passage', 'num_tokens'], ['question', 'num_tokens']], 127 | batch_size: 64, 128 | padding_noise: 0.0, 129 | 130 | }, 131 | 132 | trainer: { 133 | num_epochs: 1, 134 | grad_norm: 5.0, 135 | patience: 10, 136 | validation_metric: '+em', 137 | cuda_device: -1, 138 | learning_rate_scheduler: { 139 | type: 'reduce_on_plateau', 140 | factor: 0.5, 141 | mode: 'max', 142 | patience: 2, 143 | }, 144 | optimizer: { 145 | type: 'adam', 146 | betas: [0.9, 0.9], 147 | }, 148 | }, 149 | } 150 | -------------------------------------------------------------------------------- /tests/fixtures/rnet/experiment_dynamic.jsonnet: -------------------------------------------------------------------------------- 1 | local embedding_size = 10; 2 | local hidden_size = 5; 3 | local attention_size = 5; 4 | local num_layers = 2; 5 | local dropout = 0.3; 6 | local bidirectional = true; 7 | local rnn_type = 'gru'; 8 | 9 | local classifier = { 10 | input_dim: hidden_size * 2, 11 | num_layers: 1, 12 | hidden_dims: [ 13 | 5, 14 | ], 15 | activations: [ 16 | 'linear', 17 | ], 18 | dropout: [ 19 | 0.0, 20 | ], 21 | }; 22 | 23 | { 24 | dataset_reader: { 25 | type: 'squad', 26 | token_indexers: { 27 | tokens: { 28 | type: 'single_id', 29 | lowercase_tokens: true, 30 | }, 31 | token_characters: { 32 | type: 'characters', 33 | character_tokenizer: { 34 | byte_encoding: 'utf-8', 35 | start_tokens: [259], 36 | end_tokens: [260], 37 | }, 38 | min_padding_length: 5, 39 | }, 40 | }, 41 | }, 42 | 43 | 44 | train_data_path: 'tests/fixtures/data/squad.json', 45 | 
validation_data_path: 'tests/fixtures/data/squad.json', 46 | model: { 47 | type: 'r_net', 48 | text_field_embedder: { 49 | token_embedders: { 50 | tokens: { 51 | type: 'embedding', 52 | embedding_dim: 2, 53 | trainable: true, 54 | }, 55 | token_characters: { 56 | type: 'character_encoding', 57 | embedding: { 58 | num_embeddings: 262, 59 | embedding_dim: 8, 60 | }, 61 | encoder: { 62 | type: 'cnn', 63 | embedding_dim: 8, 64 | num_filters: 8, 65 | ngram_filter_sizes: [5], 66 | }, 67 | dropout: 0.2, 68 | }, 69 | }, 70 | 71 | }, 72 | 73 | question_encoder: { 74 | type: 'gru', 75 | input_size: embedding_size, 76 | hidden_size: hidden_size, 77 | num_layers: num_layers, 78 | bidirectional: bidirectional, 79 | dropout: dropout, 80 | }, 81 | 82 | passage_encoder: { 83 | type: 'gru', 84 | input_size: embedding_size, 85 | hidden_size: hidden_size, 86 | num_layers: num_layers, 87 | bidirectional: bidirectional, 88 | dropout: dropout, 89 | }, 90 | 91 | pair_encoder: { 92 | type: 'dynamic_pair_encoder', 93 | memory_size: hidden_size * 2, 94 | input_size: hidden_size * 2, 95 | hidden_size: hidden_size, 96 | attention_size: attention_size, 97 | bidirectional: bidirectional, 98 | dropout: dropout, 99 | batch_first: true, 100 | 101 | }, 102 | 103 | self_encoder: { 104 | type: 'dynamic_self_encoder', 105 | memory_size: hidden_size * 2, 106 | input_size: hidden_size * 2, 107 | hidden_size: hidden_size, 108 | attention_size: attention_size, 109 | bidirectional: bidirectional, 110 | dropout: dropout, 111 | batch_first: true, 112 | 113 | }, 114 | 115 | output_layer: { 116 | type: 'pointer_network', 117 | question_size: hidden_size * 2, 118 | passage_size: hidden_size * 2, 119 | attention_size: attention_size, 120 | dropout: dropout, 121 | batch_first: true, 122 | }, 123 | }, 124 | 125 | 126 | iterator: { 127 | type: 'bucket', 128 | sorting_keys: [['passage', 'num_tokens'], ['question', 'num_tokens']], 129 | batch_size: 64, 130 | padding_noise: 0.0, 131 | 132 | }, 133 | 134 | trainer: { 
135 | num_epochs: 3, 136 | grad_norm: 5.0, 137 | patience: 10, 138 | validation_metric: '+em', 139 | cuda_device: -1, 140 | learning_rate_scheduler: { 141 | type: 'reduce_on_plateau', 142 | factor: 0.5, 143 | mode: 'max', 144 | patience: 2, 145 | }, 146 | optimizer: { 147 | type: 'adam', 148 | betas: [0.9, 0.9], 149 | }, 150 | }, 151 | } 152 | -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matthew-z/R-net/670f3796cd01595903ef781ec7eb0e55c020e77b/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/r-net_dynamic_test.py: -------------------------------------------------------------------------------- 1 | from allennlp.common.testing import ModelTestCase 2 | from qa.squad.rnet import RNet 3 | 4 | class RNetDynamicTest(ModelTestCase): 5 | def setUp(self): 6 | super().setUp() 7 | self.set_up_model('tests/fixtures/rnet/experiment_dynamic.jsonnet', 8 | 'tests/fixtures/data/squad.json') 9 | 10 | def test_model_can_train_save_and_load(self): 11 | self.ensure_model_can_train_save_and_load(self.param_file) 12 | -------------------------------------------------------------------------------- /tests/models/r-net_test.py: -------------------------------------------------------------------------------- 1 | from allennlp.common.testing import ModelTestCase 2 | from qa.squad.rnet import RNet 3 | from modules.rnn.stacked_rnn import ConcatRNN 4 | class RNetTest(ModelTestCase): 5 | def setUp(self): 6 | super().setUp() 7 | self.set_up_model('tests/fixtures/rnet/experiment.jsonnet', 8 | 'tests/fixtures/data/squad.json') 9 | 10 | def test_model_can_train_save_and_load(self): 11 | self.ensure_model_can_train_save_and_load(self.param_file) 12 | 13 | --------------------------------------------------------------------------------
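The span search in `RNet.get_best_span` (qa/squad/rnet.py) picks the pair `(start, end)` with `start <= end` that maximizes the sum of the start and end scores. The following is a minimal sketch of that same single-pass idea for one example, written in plain Python without torch. The function name `best_span` and the list-based inputs are assumptions made for illustration; they are not part of the repository.

```python
from typing import List, Tuple


def best_span(start_scores: List[float], end_scores: List[float]) -> Tuple[int, int]:
    """Return (start, end), start <= end, maximizing start_scores[start] + end_scores[end].

    Hypothetical helper: mirrors the loop in RNet.get_best_span for a single
    example, using Python lists instead of tensors.
    """
    if len(start_scores) != len(end_scores):
        raise ValueError("start_scores and end_scores must have the same length")
    if not start_scores:
        raise ValueError("scores must be non-empty")

    best = (0, 0)
    best_score = float("-inf")
    start_argmax = 0  # index of the highest start score seen so far

    for j, end_score in enumerate(end_scores):
        # The best start for an end at position j is the best start in [0, j].
        if start_scores[j] > start_scores[start_argmax]:
            start_argmax = j
        score = start_scores[start_argmax] + end_score
        if score > best_score:
            best = (start_argmax, j)
            best_score = score
    return best


if __name__ == "__main__":
    print(best_span([0.1, 0.9, 0.2], [0.8, 0.1, 0.7]))
```

Because the running maximum over start positions is updated before each end position is scored, the search takes one pass over the passage and can never return an end index that precedes the start index.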