├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── ablation.json5
│   ├── data
│   │   ├── quora.json5
│   │   ├── scitail.json5
│   │   ├── snli.json5
│   │   └── wikiqa.json5
│   ├── debug.json5
│   ├── default.json5
│   └── main.json5
├── data
│   ├── prepare_quora.py
│   ├── prepare_scitail.py
│   ├── prepare_snli.py
│   └── prepare_wikiqa.py
├── evaluate.py
├── figure.png
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── evaluator.py
│   ├── interface.py
│   ├── model.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── alignment.py
│   │   ├── connection.py
│   │   ├── embedding.py
│   │   ├── encoder.py
│   │   ├── fusion.py
│   │   ├── pooling.py
│   │   └── prediction.py
│   ├── network.py
│   ├── trainer.py
│   └── utils
│       ├── __init__.py
│       ├── loader.py
│       ├── logger.py
│       ├── metrics.py
│       ├── params.py
│       ├── registry.py
│       └── vocab.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | /data/*
2 | !/data/*.py
3 | /models/
4 | /resources/
5 |
6 | /.idea
7 | __pycache__/
8 | .DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RE2
2 |
3 | This is a PyTorch implementation of the ACL 2019 paper "Simple and Effective Text Matching with Richer Alignment Features". The original TensorFlow implementation is available at https://github.com/alibaba-edu/simple-effective-text-matching.
4 |
5 | ## Quick Links
6 |
7 | - [About](#simple-and-effective-text-matching)
8 | - [Setup](#setup)
9 | - [Usage](#usage)
10 |
11 | ## Simple and Effective Text Matching
12 |
13 | RE2 is a fast and strong neural architecture for general-purpose text matching applications.
14 | In a text matching task, a model takes two text sequences as input and predicts their relationship.
15 | This method explores what is sufficient for strong performance in these tasks.
16 | It simplifies many slow components that were previously considered core building blocks in text matching,
17 | while keeping three key features directly available for inter-sequence alignment:
18 | original point-wise features, previous aligned features, and contextual features.
19 |
20 | RE2 achieves performance on par with the state of the art on four benchmark datasets: SNLI, SciTail, Quora and WikiQA,
21 | across the tasks of natural language inference, paraphrase identification and answer selection,
22 | with no or few task-specific adaptations. Its inference speed is at least six times faster than that of models with comparable performance.
23 |
24 | ![RE2 architecture](figure.png)
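For intuition, here is a minimal sketch of the augmented residual connection that keeps all three feature types available to each block (it mirrors `AugmentedResidual` in `src/modules/connection.py`; the tensor shapes below are toy values, not the actual configuration):

```python
import math
import torch

# Toy shapes: batch=2, seq_len=5, hidden_size=150, embedding_dim=300.
# `res` carries the previous block's output concatenated with the word embeddings.
x = torch.randn(2, 5, 150)    # contextual features from the current block's encoder
res = torch.randn(2, 5, 450)  # [previous aligned features (150); original embeddings (300)]

hidden_size = x.size(-1)
summed = (res[:, :, :hidden_size] + x) * math.sqrt(0.5)     # scaled residual sum
out = torch.cat([summed, res[:, :, hidden_size:]], dim=-1)  # re-attach the embeddings
# `out` now exposes contextual, previously aligned, and original point-wise
# features to the alignment layer of the current block.
```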
25 |
26 | The following table lists the major experiment results.
27 | The paper reports the average and standard deviation of 10 runs.
28 | Inference time (in seconds) is measured by processing a batch of 8 pairs of length 20 on Intel i7 CPUs.
29 | The computation time of the POS features used by CSRAN and DIIN is not included.
30 |
31 | |Model|SNLI|SciTail|Quora|WikiQA|Inference Time|
32 | |---|---|---|---|---|---|
33 | |[BiMPM](https://github.com/zhiguowang/BiMPM)|86.9|-|88.2|0.731|0.05|
34 | |[ESIM](https://github.com/lukecq1231/nli)|88.0|70.6|-|-|-|
35 | |[DIIN](https://github.com/YichenGong/Densely-Interactive-Inference-Network)|88.0|-|89.1|-|1.79|
36 | |[CSRAN](https://github.com/vanzytay/EMNLP2018_NLI)|88.7|86.7|89.2|-|0.28|
37 | |RE2|88.9±0.1|86.0±0.6|89.2±0.2|0.7618±0.0040|0.03~0.05|
38 |
39 | Refer to the paper for more details about the components and experiment results.
40 |
41 | ## Setup
42 |
43 | - Install Python >= 3.6 and pip
44 | - `pip install -r requirements.txt`
45 | - Install [PyTorch](https://pytorch.org)
46 | - Download [GloVe word vectors](https://nlp.stanford.edu/projects/glove/) (glove.840B.300d) to `resources/`
47 |
48 | The datasets used in the paper are prepared as follows:
49 |
50 | ### SNLI
51 |
52 | - Download and unzip [SNLI](https://www.dropbox.com/s/0r82spk628ksz70/SNLI.zip?dl=0)
53 | (pre-processed by [Tay et al.](https://github.com/vanzytay/EMNLP2018_NLI)) to `data/orig`.
54 | - Unzip all gzipped files in the `data/orig/SNLI` folder (`cd data/orig/SNLI && gunzip *.gz`).
55 | - `cd data && python prepare_snli.py`
56 |
57 | ### SciTail
58 |
59 | - Download and unzip the [SciTail](http://data.allenai.org.s3.amazonaws.com/downloads/SciTailV1.1.zip)
60 | dataset to `data/orig`.
61 | - `cd data && python prepare_scitail.py`
62 |
63 | ### Quora
64 |
65 | - Download and unzip the [Quora](https://drive.google.com/file/d/0B0PlTAo--BnaQWlsZl9FZ3l1c28/view?usp=sharing)
66 | dataset (pre-processed by [Wang et al.](https://github.com/zhiguowang/BiMPM)) to `data/orig`.
67 | - `cd data && python prepare_quora.py`
68 |
69 | ### WikiQA
70 |
71 | - Download and unzip [WikiQA](https://www.microsoft.com/en-us/download/details.aspx?id=52419)
72 | to `data/orig`.
73 | - `cd data && python prepare_wikiqa.py`
74 | - Download and unzip the [evaluation scripts](http://cs.stanford.edu/people/mengqiu/data/qg-emnlp07-data.tgz).
75 | Use the `make -B` command to compile the source files in `qg-emnlp07-data/eval/trec_eval-8.0`.
76 | Move the binary file "trec_eval" to `resources/`.
77 |
78 | ## Usage
79 |
80 | To train a new text matching model, run the following command:
81 |
82 | ```bash
83 | python train.py $config_file.json5
84 | ```
85 |
86 | Example configuration files are provided in `configs/`:
87 |
88 | - `configs/main.json5`: replicates the main experiment results in the paper.
89 | - `configs/robustness.json5`: robustness checks
90 | - `configs/ablation.json5`: ablation study
91 |
92 | Follow these instructions to write your own configuration files:
93 |
94 | ```json5
95 | [
96 |     {
97 |         name: 'exp1', // name of your experiment, can be the same across different data
98 |         __parents__: [
99 |             'default', // always put the default on top
100 |             'data/quora', // data-specific configurations in `configs/data`
101 |             // 'debug', // use "debug" to debug your code quickly
102 |         ],
103 |         __repeat__: 5, // how many repetitions you want
104 |         blocks: 3, // other configurations for this experiment
105 |     },
106 |     // multiple configurations are executed sequentially
107 |     {
108 |         name: 'exp2', // results under the same name will be overwritten
109 |         __parents__: [
110 |             'default',
111 |             'data/quora',
112 |         ],
113 |         __repeat__: 5,
114 |         blocks: 4,
115 |     }
116 | ]
117 | ```
118 |
119 | To check the configurations only, use
120 |
121 | ```bash
122 | python train.py $config_file.json5 --dry
123 | ```
124 |
125 | To evaluate an existing model, use `python evaluate.py $model_path $data_file`. For example:
126 |
127 | ```bash
128 | python evaluate.py models/snli/benchmark/best.pt data/snli/train.txt
129 | python evaluate.py models/snli/benchmark/best.pt data/snli/test.txt
130 | ```
131 |
132 | > Note that multi-GPU training is not yet supported in the PyTorch implementation. A single 16G GPU is sufficient for training when blocks < 5 with hidden size 200 and batch size 512. All the results reported in the paper except the robustness checks can be reproduced with a single 16G GPU.
133 |
134 | ## Citation
135 |
136 | Please cite the ACL paper if you use RE2 in your work:
137 |
138 | ```
139 | @inproceedings{yang2019simple,
140 |   title={Simple and Effective Text Matching with Richer Alignment Features},
141 |   author={Yang, Runqi and Zhang, Jianhai and Gao, Xing and Ji, Feng and Chen, Haiqing},
142 |   booktitle={Association for Computational Linguistics (ACL)},
143 |   year={2019}
144 | }
145 | ```
146 |
147 | ## License
148 | This project is licensed under the Apache License 2.0.
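## Programmatic Evaluation

`evaluate.py` is a thin wrapper around `src.evaluator.Evaluator`, so the same evaluation can also be run from Python. A minimal sketch (the checkpoint and data paths are illustrative):

```python
from src.evaluator import Evaluator

# Any saved checkpoint plus a tab-separated data file (text1<TAB>text2<TAB>label),
# as produced by the data/prepare_*.py scripts, works here.
evaluator = Evaluator('models/snli/benchmark/best.pt', 'data/snli/test.txt')
evaluator.evaluate()  # loads the checkpoint, batches the data, and pretty-prints the metrics
```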
149 | -------------------------------------------------------------------------------- /configs/ablation.json5: -------------------------------------------------------------------------------- 1 | [ 2 | // original version 3 | { 4 | name: 'original', 5 | __parents__: [ 6 | 'default', 7 | 'data/snli', 8 | ], 9 | __repeat__: 10, 10 | }, 11 | { 12 | name: 'original', 13 | __parents__: [ 14 | 'default', 15 | 'data/scitail', 16 | ], 17 | __repeat__: 10, 18 | }, 19 | { 20 | name: 'original', 21 | __parents__: [ 22 | 'default', 23 | 'data/quora', 24 | ], 25 | __repeat__: 10, 26 | }, 27 | { 28 | name: 'original', 29 | __parents__: [ 30 | 'default', 31 | 'data/wikiqa', 32 | ], 33 | __repeat__: 10, 34 | }, 35 | // alignment alternative 36 | { 37 | name: 'alignment-alt', 38 | __parents__: [ 39 | 'default', 40 | 'data/snli', 41 | ], 42 | alignment: 'identity', 43 | __repeat__: 10, 44 | }, 45 | { 46 | name: 'alignment-alt', 47 | __parents__: [ 48 | 'default', 49 | 'data/scitail', 50 | ], 51 | __repeat__: 10, 52 | alignment: 'linear' 53 | }, 54 | { 55 | name: 'alignment-alt', 56 | __parents__: [ 57 | 'default', 58 | 'data/quora', 59 | ], 60 | __repeat__: 10, 61 | alignment: 'identity' 62 | }, 63 | { 64 | name: 'alignment-alt', 65 | __parents__: [ 66 | 'default', 67 | 'data/wikiqa', 68 | ], 69 | __repeat__: 10, 70 | alignment: 'identity' 71 | }, 72 | // prediction alternative 73 | { 74 | name: 'prediction-alt', 75 | __parents__: [ 76 | 'default', 77 | 'data/snli', 78 | ], 79 | prediction: 'simple', 80 | __repeat__: 10, 81 | }, 82 | { 83 | name: 'prediction-alt', 84 | __parents__: [ 85 | 'default', 86 | 'data/scitail', 87 | ], 88 | __repeat__: 10, 89 | prediction: 'simple', 90 | }, 91 | { 92 | name: 'prediction-alt', 93 | __parents__: [ 94 | 'default', 95 | 'data/quora', 96 | ], 97 | __repeat__: 10, 98 | prediction: 'simple', 99 | }, 100 | { 101 | name: 'prediction-alt', 102 | __parents__: [ 103 | 'default', 104 | 'data/wikiqa', 105 | ], 106 | __repeat__: 10, 107 | prediction: 'full', 108 | }, 109 | // residual connection 110 | { 111 | name: 'residual-conn', 112 | __parents__: [ 113 | 'default', 114 | 'data/snli', 115 | ], 116 | connection: 'residual', 117 | __repeat__: 10, 118 | }, 119 | { 120 | name: 'residual-conn', 121 | __parents__: [ 122 | 'default', 123 | 'data/scitail', 124 | ], 125 | __repeat__: 10, 126 | connection: 'residual' 127 | }, 128 | { 129 | name: 'residual-conn', 130 | __parents__: [ 131 | 'default', 132 | 'data/quora', 133 | ], 134 | __repeat__: 10, 135 | connection: 'residual' 136 | }, 137 | { 138 | name: 'residual-conn', 139 | __parents__: [ 140 | 'default', 141 | 'data/wikiqa', 142 | ], 143 | __repeat__: 10, 144 | connection: 'residual' 145 | }, 146 | // simple fusion 147 | { 148 | name: 'simple-fusion', 149 | __parents__: [ 150 | 'default', 151 | 'data/snli', 152 | ], 153 | __repeat__: 10, 154 | fusion: 'simple' 155 | }, 156 | { 157 | name: 'simple-fusion', 158 | __parents__: [ 159 | 'default', 160 | 'data/scitail', 161 | ], 162 | __repeat__: 10, 163 | fusion: 'simple' 164 | }, 165 | { 166 | name: 'simple-fusion', 167 | __parents__: [ 168 | 'default', 169 | 'data/quora', 170 | ], 171 | __repeat__: 10, 172 | fusion: 'simple' 173 | }, 174 | { 175 | name: 'simple-fusion', 176 | __parents__: [ 177 | 'default', 178 | 'data/wikiqa', 179 | ], 180 | __repeat__: 10, 181 | fusion: 'simple' 182 | } 183 | ] -------------------------------------------------------------------------------- /configs/data/quora.json5: -------------------------------------------------------------------------------- 
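// Per-dataset configs like this one are merged on top of `configs/default.json5`
// via the `__parents__` mechanism described in the README; only keys that differ
// from the defaults need to appear here.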
1 | { 2 | data_dir: 'data/quora', 3 | output_dir: 'quora', 4 | metric: 'acc', 5 | 6 | model: { 7 | enc_layers: 2, 8 | blocks: 2, 9 | prediction: 'symmetric', 10 | hidden_size: 200, 11 | max_len: 100, 12 | }, 13 | 14 | routine: { 15 | eval_per_samples: 12800, 16 | eval_warmup_samples: 3584000, 17 | eval_per_samples_warmup: 512000, 18 | min_samples: 5120000, 19 | tolerance_samples: 2560000, 20 | }, 21 | 22 | optim: { 23 | lr: 0.0012, 24 | min_lr: 6e-5, 25 | lr_decay_samples: 256000, 26 | batch_size: 512, 27 | lr_warmup_samples: 0, 28 | }, 29 | } -------------------------------------------------------------------------------- /configs/data/scitail.json5: -------------------------------------------------------------------------------- 1 | { 2 | data_dir: 'data/scitail', 3 | output_dir: 'scitail', 4 | metric: 'acc', 5 | 6 | model: { 7 | alignment: 'identity', 8 | enc_layers: 3, 9 | blocks: 2, 10 | hidden_size: 200, 11 | }, 12 | 13 | routine: { 14 | epochs: 80, 15 | log_per_samples: 1280, 16 | eval_per_samples: 6400, 17 | }, 18 | 19 | optim: { // fixed learning rate 20 | lr: 0.001, 21 | lr_warmup_samples: 0, 22 | lr_decay_rate: 1.0, 23 | }, 24 | } -------------------------------------------------------------------------------- /configs/data/snli.json5: -------------------------------------------------------------------------------- 1 | { 2 | data_dir: 'data/snli', 3 | output_dir: 'snli', 4 | metric: 'acc', 5 | watch_metrics: [], 6 | 7 | model: { 8 | enc_layers: 2, 9 | blocks: 3, 10 | }, 11 | 12 | routine: { 13 | eval_per_samples: 12800, 14 | eval_warmup_samples: 5120000, 15 | eval_per_samples_warmup: 512000, 16 | min_samples: 5120000, 17 | tolerance_samples: 2560000, 18 | }, 19 | 20 | optim: { 21 | lr: 0.002, 22 | min_lr: 1e-4, 23 | lr_decay_samples: 256000, 24 | lr_decay_rate: 0.94, 25 | batch_size: 512, 26 | lr_warmup_samples: 2048000, 27 | }, 28 | } -------------------------------------------------------------------------------- /configs/data/wikiqa.json5: -------------------------------------------------------------------------------- 1 | { 2 | data_dir: 'data/wikiqa', 3 | output_dir: 'wikiqa', 4 | metric: 'mrr', 5 | watch_metrics: ['map'], 6 | 7 | model: { 8 | enc_layers: 3, 9 | blocks: 2, 10 | hidden_size: 200, 11 | prediction: 'simple', 12 | }, 13 | 14 | routine: { 15 | log_per_samples: 256, 16 | eval_per_samples: 1280, 17 | tolerance_samples: 256000, 18 | eval_epoch: false, 19 | }, 20 | 21 | optim: { 22 | lr: 0.001, 23 | lr_decay_rate: 1.0, 24 | batch_size: 128, 25 | }, 26 | } -------------------------------------------------------------------------------- /configs/debug.json5: -------------------------------------------------------------------------------- 1 | { 2 | batch_size: 8, 3 | blocks: 2, 4 | tensorboard: true, 5 | log_per_updates: 2, 6 | summary_per_logs: 1, 7 | eval_subset: 100, 8 | eval_per_updates: 50, 9 | eval_warmup_samples: 0, 10 | save_all: true, 11 | sort_by_len: true, 12 | seed: 123, 13 | pretrained_embeddings: 'resources/glove.6B.300d.txt', 14 | } -------------------------------------------------------------------------------- /configs/default.json5: -------------------------------------------------------------------------------- 1 | { 2 | basic: { 3 | output_dir: 'default', 4 | seed: null, 5 | cuda: true, 6 | multi_gpu: false, 7 | deterministic: true, // GPU deterministic mode, will slow down training 8 | }, 9 | 10 | data: { 11 | data_dir: null, 12 | min_df: 5, 13 | max_vocab: 999999, // capacity for words including out of embedding words 14 | max_len: 999, // large 
enough number, treated as unlimited 15 | min_len: 1, 16 | lower_case: true, // whether to treat the data and embedding as lowercase. 17 | sort_by_len: false, 18 | pretrained_embeddings: 'resources/glove.840B.300d.txt', 19 | embedding_dim: 300, 20 | embedding_mode: 'freq', // (options: 'freq', 'last', 'avg', 'strict') what to do when duplicated embedding tokens (after normalization) are found. 21 | }, 22 | 23 | model: { 24 | hidden_size: 150, 25 | dropout: 0.2, 26 | blocks: 2, 27 | fix_embeddings: true, 28 | encoder: { 29 | encoder: 'cnn', // cnn, lstm 30 | enc_layers: 2, 31 | kernel_sizes: [3], 32 | }, 33 | alignment: 'linear', // linear, identity 34 | fusion: 'full', // full, simple 35 | connection: 'aug', // aug, residual 36 | prediction: 'full', // full, symmetric, simple 37 | 38 | }, 39 | 40 | logging: { 41 | log_file: 'log.txt', 42 | log_per_samples: 5120, 43 | summary_per_logs: 20, 44 | tensorboard: true, 45 | }, 46 | 47 | training: { 48 | epochs: 30, 49 | batch_size: 128, 50 | grad_clipping: 5, 51 | weight_decay: 0, 52 | lr: 1e-3, 53 | beta1: 0.9, 54 | beta2: 0.999, 55 | max_loss: 999., // tolerance for unstable training 56 | lr_decay_rate: 0.95, // exp decay rate for lr 57 | lr_decay_samples: 128000, 58 | min_lr: 6e-5, 59 | lr_warmup_samples: 0, // linear warmup steps for lr 60 | }, 61 | 62 | evaluation: { 63 | // available metrics: acc, auc, f1, map, mrr 64 | metric: 'acc', // for early stopping 65 | watch_metrics: ['auc', 'f1'], // shown in logs 66 | eval_file: 'dev', 67 | eval_per_samples: 6400, 68 | eval_per_samples_warmup: 40000, 69 | eval_warmup_samples: 0, // after this many steps warmup mode for eval ends 70 | min_samples: 0, // train at least these many steps, not affected by early stopping 71 | tolerance_samples: 400000, // early stopping 72 | eval_epoch: true, // eval after epoch 73 | eval_subset: null, 74 | }, 75 | 76 | persistence: { 77 | resume: null, 78 | save: true, 79 | save_all: false, 80 | }, 81 | } -------------------------------------------------------------------------------- /configs/main.json5: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | name: 'benchmark', 4 | __parents__: [ 5 | 'default', 6 | 'data/snli', 7 | ], 8 | __repeat__: 10, 9 | eval_file: 'test', 10 | }, 11 | { 12 | name: 'benchmark', 13 | __parents__: [ 14 | 'default', 15 | 'data/scitail', 16 | ], 17 | __repeat__: 10, 18 | eval_file: 'test', 19 | }, 20 | { 21 | name: 'benchmark', 22 | __parents__: [ 23 | 'default', 24 | 'data/quora', 25 | ], 26 | __repeat__: 10, 27 | eval_file: 'test', 28 | }, 29 | { 30 | name: 'benchmark', 31 | __parents__: [ 32 | 'default', 33 | 'data/wikiqa', 34 | ], 35 | __repeat__: 10, 36 | eval_file: 'test', 37 | }, 38 | ] -------------------------------------------------------------------------------- /data/prepare_quora.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import os 18 | from tqdm import tqdm 19 | 20 | 21 | print('processing quora') 22 | os.makedirs('quora', exist_ok=True) 23 | # use the partition on https://zhiguowang.github.io 24 | for split in ('train', 'dev', 'test'): 25 | with open('orig/Quora_question_pair_partition/{}.tsv'.format(split)) as f, \ 26 | open('quora/{}.txt'.format(split), 'w') as fout: 27 | n_lines = 0 28 | for _ in f: 29 | n_lines += 1 30 | f.seek(0) 31 | for line in tqdm(f, total=n_lines, leave=False): 32 | elements = line.rstrip().split('\t') 33 | fout.write('{}\t{}\t{}\n'.format(elements[1], elements[2], int(elements[0]))) 34 | -------------------------------------------------------------------------------- /data/prepare_scitail.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import re 18 | import os 19 | import json 20 | from tqdm import tqdm 21 | from nltk.tokenize import TweetTokenizer 22 | 23 | 24 | tokenizer = TweetTokenizer() 25 | label_map = { 26 | 'entailment': 0, 27 | 'neutral': 1, 28 | 'contradiction': 2, 29 | } 30 | 31 | 32 | def tokenize(string): 33 | string = ' '.join(tokenizer.tokenize(string)) 34 | string = re.sub(r"[-.#\"/]", " ", string) 35 | string = re.sub(r"\'(?!(s|m|ve|t|re|d|ll)( |$))", " ", string) 36 | string = re.sub(r"\'s", " \'s", string) 37 | string = re.sub(r"\'m", " \'m", string) 38 | string = re.sub(r"\'ve", " \'ve", string) 39 | string = re.sub(r"n\'t", " n\'t", string) 40 | string = re.sub(r"\'re", " \'re", string) 41 | string = re.sub(r"\'d", " \'d", string) 42 | string = re.sub(r"\'ll", " \'ll", string) 43 | string = re.sub(r"\s{2,}", " ", string) 44 | return string.strip() 45 | 46 | 47 | os.makedirs('scitail', exist_ok=True) 48 | 49 | 50 | for split in ['train', 'dev', 'test']: 51 | print('processing SciTail', split) 52 | with open('orig/SciTailV1.1/snli_format/scitail_1.0_{}.txt'.format(split)) as f, \ 53 | open('scitail/{}.txt'.format(split), 'w', encoding='utf8') as fout: 54 | n_lines = 0 55 | for _ in f: 56 | n_lines += 1 57 | f.seek(0) 58 | for line in tqdm(f, total=n_lines, desc=split, leave=False): 59 | sample = json.loads(line) 60 | sentence1 = tokenize(sample['sentence1']) 61 | sentence2 = tokenize(sample['sentence2']) 62 | label = sample["gold_label"] 63 | assert label in label_map 64 | label = label_map[label] 65 | fout.write('{}\t{}\t{}\n'.format(sentence1, sentence2, label)) 66 | -------------------------------------------------------------------------------- /data/prepare_snli.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this 
file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import os 18 | import json 19 | import string 20 | import numpy as np 21 | import msgpack 22 | from collections import Counter 23 | 24 | in_dir = 'orig/SNLI' 25 | out_dir = '../models/snli/' 26 | data_dir = 'snli' 27 | label_map = {2: '0', 1: '1', 0: '2'} 28 | 29 | os.makedirs(out_dir, exist_ok=True) 30 | os.makedirs(data_dir, exist_ok=True) 31 | with open(os.path.join(in_dir, 'env')) as f: 32 | env = json.load(f) 33 | 34 | print('convert embeddings ...') 35 | emb = np.load(os.path.join(in_dir, 'emb_glove_300.npy')) 36 | print(len(emb)) 37 | with open(os.path.join(out_dir, 'embedding.msgpack'), 'wb') as f: 38 | msgpack.dump(emb.tolist(), f) 39 | 40 | print('convert_vocab ...') 41 | w2idx = env['word_index'] 42 | print(len(w2idx)) 43 | idx2w = {i: w for w, i in w2idx.items()} 44 | with open(os.path.join(out_dir, 'vocab.txt'), 'w') as f: 45 | for index in range(len(idx2w)): 46 | if index >= 2: 47 | f.write('{}\n'.format(idx2w[index])) 48 | with open(os.path.join(out_dir, 'target_map.txt'), 'w') as f: 49 | for label in (0, 1, 2): 50 | f.write('{}\n'.format(label)) 51 | 52 | # save data files 53 | punctuactions = set(string.punctuation) 54 | for split in ['train', 'dev', 'test']: 55 | labels = Counter() 56 | print('convert', split, '...') 57 | data = env[split] 58 | with open(os.path.join(data_dir, f'{split}.txt'), 'w') as f_out: 59 | for sample in data: 60 | a, b, label = sample 61 | a = a[1:-1] 62 | b = b[1:-1] 63 | a = [w.lower() for w in a if w and w not in punctuactions] 64 | b = [w.lower() for w in b if w and w not in punctuactions] 65 | assert all(w in w2idx for w in a) and all(w in w2idx for w in b) 66 | a = ' '.join(a) 67 | b = ' '.join(b) 68 | assert len(a) != 0 and len(b) != 0 69 | labels.update({label: 1}) 70 | assert label in label_map 71 | label = label_map[label] 72 | f_out.write(f'{a}\t{b}\t{label}\n') 73 | print('labels:', labels) 74 | -------------------------------------------------------------------------------- /data/prepare_wikiqa.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | import os 18 | from shutil import copyfile 19 | 20 | 21 | def copy(src, tgt): 22 | copyfile(os.path.abspath(src), os.path.abspath(tgt)) 23 | 24 | 25 | os.makedirs('wikiqa', exist_ok=True) 26 | 27 | 28 | copy('orig/WikiQACorpus/WikiQA-dev-filtered.ref', 'wikiqa/dev.ref') 29 | copy('orig/WikiQACorpus/WikiQA-test-filtered.ref', 'wikiqa/test.ref') 30 | copy('orig/WikiQACorpus/emnlp-table/WikiQA.CNN.dev.rank', 'wikiqa/dev.rank') 31 | copy('orig/WikiQACorpus/emnlp-table/WikiQA.CNN.test.rank', 'wikiqa/test.rank') 32 | for split in ['train', 'dev', 'test']: 33 | print('processing WikiQA', split) 34 | copy('orig/WikiQACorpus/WikiQA-{}.txt'.format(split), 'wikiqa/{}.txt'.format(split)) 35 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import sys 18 | from src.evaluator import Evaluator 19 | 20 | 21 | def main(): 22 | argv = sys.argv 23 | if len(argv) == 3: 24 | model_path, data_file = argv[1:] 25 | evaluator = Evaluator(model_path, data_file) 26 | evaluator.evaluate() 27 | else: 28 | print('Usage: "python evaluate.py $model_path $data_file"') 29 | 30 | 31 | if __name__ == '__main__': 32 | main() 33 | -------------------------------------------------------------------------------- /figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba-edu/simple-effective-text-matching-pytorch/05d572e30801b235e989c78c95dd24d5f5d35f74/figure.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | nltk 3 | numpy 4 | scikit-learn 5 | msgpack-python 6 | tensorboardX 7 | json5 -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba-edu/simple-effective-text-matching-pytorch/05d572e30801b235e989c78c95dd24d5f5d35f74/src/__init__.py -------------------------------------------------------------------------------- /src/evaluator.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import os 18 | from pprint import pprint 19 | from .model import Model 20 | from .interface import Interface 21 | from .utils.loader import load_data 22 | 23 | 24 | class Evaluator: 25 | def __init__(self, model_path, data_file): 26 | self.model_path = model_path 27 | self.data_file = data_file 28 | 29 | def evaluate(self): 30 | data = load_data(*os.path.split(self.data_file)) 31 | model, checkpoint = Model.load(self.model_path) 32 | args = checkpoint['args'] 33 | interface = Interface(args) 34 | batches = interface.pre_process(data, training=False) 35 | _, stats = model.evaluate(batches) 36 | pprint(stats) 37 | -------------------------------------------------------------------------------- /src/interface.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 |
16 |
17 | import os
18 | import random
19 | import msgpack
20 | from .utils.vocab import Vocab, Indexer
21 | from .utils.loader import load_data, load_embeddings
22 |
23 |
24 | class Interface:
25 |     def __init__(self, args, log=None):
26 |         self.args = args
27 |         # build/load vocab and target map
28 |         vocab_file = os.path.join(args.output_dir, 'vocab.txt')
29 |         target_map_file = os.path.join(args.output_dir, 'target_map.txt')
30 |         if not os.path.exists(vocab_file):
31 |             data = load_data(self.args.data_dir)
32 |             self.target_map = Indexer.build((sample['target'] for sample in data), log=log)
33 |             self.target_map.save(target_map_file)
34 |             self.vocab = Vocab.build((word for sample in data
35 |                                       for text in (sample['text1'], sample['text2'])
36 |                                       for word in text.split()[:self.args.max_len]),
37 |                                      lower=args.lower_case, min_df=self.args.min_df, log=log,
38 |                                      pretrained_embeddings=args.pretrained_embeddings,
39 |                                      dump_filtered=os.path.join(args.output_dir, 'filtered_words.txt'))
40 |             self.vocab.save(vocab_file)
41 |
42 |         else:
43 |             self.target_map = Indexer.load(target_map_file)
44 |             self.vocab = Vocab.load(vocab_file)
45 |         args.num_classes = len(self.target_map)
46 |         args.num_vocab = len(self.vocab)
47 |         args.padding = Vocab.pad()
48 |
49 |     def load_embeddings(self):
50 |         """generate embeddings suited for the current vocab or load previously cached ones."""
51 |         assert self.args.pretrained_embeddings
52 |         embedding_file = os.path.join(self.args.output_dir, 'embedding.msgpack')
53 |         if not os.path.exists(embedding_file):
54 |             embeddings = load_embeddings(self.args.pretrained_embeddings, self.vocab,
55 |                                          self.args.embedding_dim, mode=self.args.embedding_mode,
56 |                                          lower=self.args.lower_case)
57 |             with open(embedding_file, 'wb') as f:
58 |                 msgpack.dump(embeddings, f)
59 |         else:
60 |             with open(embedding_file, 'rb') as f:
61 |                 embeddings = msgpack.load(f)
62 |         return embeddings
63 |
64 |     def pre_process(self, data, training=True):
65 |         result = [self.process_sample(sample) for sample in data]
66 |         if training:
67 |             result = list(filter(lambda x: len(x['text1']) < self.args.max_len and len(x['text2']) < self.args.max_len,
68 |                                  result))
69 |         if not self.args.sort_by_len:
70 |             return result
71 |         result = sorted(result, key=lambda x: (len(x['text1']), len(x['text2']), x['text1']))
72 |         batch_size = self.args.batch_size
73 |         return [self.make_batch(result[i:i + batch_size]) for i in range(0, len(result), batch_size)]  # iterate over the filtered list, not the raw input, so no empty trailing batch is produced
74 |
75 |     def process_sample(self, sample, with_target=True):
76 |         text1 = sample['text1']
77 |         text2 = sample['text2']
78 |         if self.args.lower_case:
79 |             text1 = text1.lower()
80 |             text2 = text2.lower()
81 |         processed = {
82 |             'text1': [self.vocab.index(w) for w in text1.split()[:self.args.max_len]],
83 |             'text2': [self.vocab.index(w) for w in text2.split()[:self.args.max_len]],
84 |         }
85 |         if 'target' in sample and with_target:
86 |             target = sample['target']
87 |             assert target in self.target_map
88 |             processed['target'] = self.target_map.index(target)
89 |         return processed
90 |
91 |     def shuffle_batch(self, data):
92 |         data = random.sample(data, len(data))
93 |         if self.args.sort_by_len:
94 |             return data
95 |         batch_size = self.args.batch_size
96 |         batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
97 |         return list(map(self.make_batch, batches))
98 |
99 |     def make_batch(self, batch, with_target=True):
100 |         batch = {key: [sample[key] for sample in batch] for key in batch[0].keys()}
101 |         if 'target' in batch and not with_target:
102 |             del batch['target']
103 |         batch
= {key: self.padding(value, min_len=self.args.min_len) if key.startswith('text') else value 104 | for key, value in batch.items()} 105 | return batch 106 | 107 | @staticmethod 108 | def padding(samples, min_len=1): 109 | max_len = max(max(map(len, samples)), min_len) 110 | batch = [sample + [Vocab.pad()] * (max_len - len(sample)) for sample in samples] 111 | return batch 112 | 113 | def post_process(self, output): 114 | final_prediction = [] 115 | for prob in output: 116 | idx = max(range(len(prob)), key=prob.__getitem__) 117 | target = self.target_map[idx] 118 | final_prediction.append(target) 119 | return final_prediction 120 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import os 18 | import math 19 | import random 20 | import torch 21 | import torch.nn.functional as f 22 | from tqdm import tqdm 23 | from .network import Network 24 | from .utils.metrics import registry as metrics 25 | 26 | 27 | class Model: 28 | prefix = 'checkpoint' 29 | best_model_name = 'best.pt' 30 | 31 | def __init__(self, args, state_dict=None): 32 | self.args = args 33 | 34 | # network 35 | self.network = Network(args) 36 | self.device = torch.cuda.current_device() if args.cuda else torch.device('cpu') 37 | self.network.to(self.device) 38 | # optimizer 39 | self.params = list(filter(lambda x: x.requires_grad, self.network.parameters())) 40 | self.opt = torch.optim.Adam(self.params, args.lr, betas=(args.beta1, args.beta2), 41 | weight_decay=args.weight_decay) 42 | # updates 43 | self.updates = state_dict['updates'] if state_dict else 0 44 | 45 | if state_dict: 46 | new_state = set(self.network.state_dict().keys()) 47 | for k in list(state_dict['model'].keys()): 48 | if k not in new_state: 49 | del state_dict['model'][k] 50 | self.network.load_state_dict(state_dict['model']) 51 | self.opt.load_state_dict(state_dict['opt']) 52 | 53 | def _update_schedule(self): 54 | if self.args.lr_decay_rate < 1.: 55 | args = self.args 56 | t = self.updates 57 | base_ratio = args.min_lr / args.lr 58 | if t < args.lr_warmup_steps: 59 | ratio = base_ratio + (1. - base_ratio) / max(1., args.lr_warmup_steps) * t 60 | else: 61 | ratio = max(base_ratio, args.lr_decay_rate ** math.floor((t - args.lr_warmup_steps) / 62 | args.lr_decay_steps)) 63 | self.opt.param_groups[0]['lr'] = args.lr * ratio 64 | 65 | def update(self, batch): 66 | self.network.train() 67 | self.opt.zero_grad() 68 | inputs, target = self.process_data(batch) 69 | output = self.network(inputs) 70 | summary = self.network.get_summary() 71 | loss = self.get_loss(output, target) 72 | loss.backward() 73 | grad_norm = torch.nn.utils.clip_grad_norm_(self.params, self.args.grad_clipping) 74 | assert grad_norm >= 0, 'encounter nan in gradients.' 
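# clip_grad_norm_ returns the total gradient norm computed *before* clipping;
# a NaN norm makes `grad_norm >= 0` evaluate to False, so the assert above
# fails fast when training diverges instead of silently corrupting the weights.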
75 | if isinstance(grad_norm, torch.Tensor): 76 | grad_norm = grad_norm.item() 77 | self.opt.step() 78 | self._update_schedule() 79 | self.updates += 1 80 | stats = { 81 | 'updates': self.updates, 82 | 'loss': loss.item(), 83 | 'lr': self.opt.param_groups[0]['lr'], 84 | 'gnorm': grad_norm, 85 | 'summary': summary, 86 | } 87 | return stats 88 | 89 | def evaluate(self, data): 90 | self.network.eval() 91 | targets = [] 92 | probabilities = [] 93 | predictions = [] 94 | losses = [] 95 | for batch in tqdm(data[:self.args.eval_subset], desc='evaluating', leave=False): 96 | inputs, target = self.process_data(batch) 97 | with torch.no_grad(): 98 | output = self.network(inputs) 99 | loss = self.get_loss(output, target) 100 | pred = torch.argmax(output, dim=1) 101 | prob = torch.nn.functional.softmax(output, dim=1) 102 | losses.append(loss.item()) 103 | targets.extend(target.tolist()) 104 | probabilities.extend(prob.tolist()) 105 | predictions.extend(pred.tolist()) 106 | outputs = { 107 | 'target': targets, 108 | 'prob': probabilities, 109 | 'pred': predictions, 110 | 'args': self.args, 111 | } 112 | stats = { 113 | 'updates': self.updates, 114 | 'loss': sum(losses[:-1]) / (len(losses) - 1) if len(losses) > 1 else sum(losses), 115 | } 116 | for metric in self.args.watch_metrics: 117 | if metric not in stats: # multiple metrics could be computed by the same function 118 | stats.update(metrics[metric](outputs)) 119 | assert 'score' not in stats, 'metric name collides with "score"' 120 | eval_score = stats[self.args.metric] 121 | stats['score'] = eval_score 122 | return eval_score, stats # first value is for early stopping 123 | 124 | def predict(self, batch): 125 | self.network.eval() 126 | inputs, _ = self.process_data(batch) 127 | with torch.no_grad(): 128 | output = self.network(inputs) 129 | output = torch.nn.functional.softmax(output, dim=1) 130 | return output.tolist() 131 | 132 | def process_data(self, batch): 133 | text1 = torch.LongTensor(batch['text1']).to(self.device) 134 | text2 = torch.LongTensor(batch['text2']).to(self.device) 135 | mask1 = torch.ne(text1, self.args.padding).unsqueeze(2) 136 | mask2 = torch.ne(text2, self.args.padding).unsqueeze(2) 137 | inputs = { 138 | 'text1': text1, 139 | 'text2': text2, 140 | 'mask1': mask1, 141 | 'mask2': mask2, 142 | } 143 | if 'target' in batch: 144 | target = torch.LongTensor(batch['target']).to(self.device) 145 | return inputs, target 146 | return inputs, None 147 | 148 | @staticmethod 149 | def get_loss(logits, target): 150 | return f.cross_entropy(logits, target) 151 | 152 | def save(self, states, name=None): 153 | if name: 154 | filename = os.path.join(self.args.summary_dir, name) 155 | else: 156 | filename = os.path.join(self.args.summary_dir, f'{self.prefix}_{self.updates}.pt') 157 | params = { 158 | 'state_dict': { 159 | 'model': self.network.state_dict(), 160 | 'opt': self.opt.state_dict(), 161 | 'updates': self.updates, 162 | }, 163 | 'args': self.args, 164 | 'random_state': random.getstate(), 165 | 'torch_state': torch.random.get_rng_state() 166 | } 167 | params.update(states) 168 | if self.args.cuda: 169 | params['torch_cuda_state'] = torch.cuda.get_rng_state() 170 | torch.save(params, filename) 171 | 172 | @classmethod 173 | def load(cls, file): 174 | checkpoint = torch.load(file, map_location=( 175 | lambda s, _: torch.serialization.default_restore_location(s, 'cpu') 176 | )) 177 | prev_args = checkpoint['args'] 178 | # update args 179 | prev_args.output_dir = os.path.dirname(os.path.dirname(file)) 180 | prev_args.summary_dir = 
os.path.join(prev_args.output_dir, prev_args.name) 181 | prev_args.cuda = prev_args.cuda and torch.cuda.is_available() 182 | return cls(prev_args, state_dict=checkpoint['state_dict']), checkpoint 183 | 184 | def num_parameters(self, exclude_embed=False): 185 | num_params = sum(p.numel() for p in self.network.parameters() if p.requires_grad) 186 | if exclude_embed: 187 | num_params -= 0 if self.args.fix_embeddings else next(self.network.embedding.parameters()).numel() 188 | return num_params 189 | 190 | def set_embeddings(self, embeddings): 191 | self.network.embedding.set_(embeddings) 192 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from typing import Collection 18 | import math 19 | import torch 20 | import torch.nn as nn 21 | 22 | 23 | class Module(nn.Module): 24 | def __init__(self): 25 | super().__init__() 26 | self.summary = {} 27 | 28 | def add_summary(self, name, val): 29 | if self.training: 30 | self.summary[name] = val.clone().detach().cpu().numpy() 31 | 32 | def get_summary(self, base_name=''): 33 | summary = {} 34 | if base_name: 35 | base_name += '/' 36 | if self.summary: 37 | summary.update({base_name + name: val for name, val in self.summary.items()}) 38 | for name, child in self.named_children(): 39 | if hasattr(child, 'get_summary'): 40 | name = base_name + name 41 | summary.update(child.get_summary(name)) 42 | return summary 43 | 44 | 45 | class ModuleList(nn.ModuleList): 46 | def get_summary(self, base_name=''): 47 | summary = {} 48 | if base_name: 49 | base_name += '/' 50 | for i, module in enumerate(self): 51 | if hasattr(module, 'get_summary'): 52 | name = base_name + str(i) 53 | summary.update(module.get_summary(name)) 54 | return summary 55 | 56 | 57 | class ModuleDict(nn.ModuleDict): 58 | def get_summary(self, base_name=''): 59 | summary = {} 60 | if base_name: 61 | base_name += '/' 62 | for key, module in self.items(): 63 | if hasattr(module, 'get_summary'): 64 | name = base_name + key 65 | summary.update(module.get_summary(name)) 66 | return summary 67 | 68 | 69 | class GeLU(nn.Module): 70 | def forward(self, x): 71 | return 0.5 * x * (1. + torch.tanh(x * 0.7978845608 * (1. + 0.044715 * x * x))) 72 | 73 | 74 | class Linear(nn.Module): 75 | def __init__(self, in_features, out_features, activations=False): 76 | super().__init__() 77 | linear = nn.Linear(in_features, out_features) 78 | nn.init.normal_(linear.weight, std=math.sqrt((2. if activations else 1.) 
/ in_features))
79 |         nn.init.zeros_(linear.bias)
80 |         modules = [nn.utils.weight_norm(linear)]
81 |         if activations:
82 |             modules.append(GeLU())
83 |         self.model = nn.Sequential(*modules)
84 |
85 |     def forward(self, x):
86 |         return self.model(x)
87 |
88 |
89 | class Conv1d(Module):
90 |     def __init__(self, in_channels, out_channels, kernel_sizes: Collection[int]):
91 |         super().__init__()
92 |         assert all(k % 2 == 1 for k in kernel_sizes), 'only support odd kernel sizes'
93 |         assert out_channels % len(kernel_sizes) == 0, 'out channels must be divisible by the number of kernels'
94 |         out_channels = out_channels // len(kernel_sizes)
95 |         convs = []
96 |         for kernel_size in kernel_sizes:
97 |             conv = nn.Conv1d(in_channels, out_channels, kernel_size,
98 |                              padding=(kernel_size - 1) // 2)
99 |             nn.init.normal_(conv.weight, std=math.sqrt(2. / (in_channels * kernel_size)))
100 |             nn.init.zeros_(conv.bias)
101 |             convs.append(nn.Sequential(nn.utils.weight_norm(conv), GeLU()))
102 |         self.model = nn.ModuleList(convs)
103 |
104 |     def forward(self, x):
105 |         return torch.cat([encoder(x) for encoder in self.model], dim=-1)
106 |
--------------------------------------------------------------------------------
/src/modules/alignment.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import math
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as f
21 | from functools import partial
22 | from src.utils.registry import register
23 | from . import Linear, Module
24 | 
25 | registry = {}
26 | register = partial(register, registry=registry)
27 | 
28 | 
29 | @register('identity')
30 | class Alignment(Module):
31 |     def __init__(self, args, __):
32 |         super().__init__()
33 |         self.temperature = nn.Parameter(torch.tensor(1 / math.sqrt(args.hidden_size)))
34 | 
35 |     def _attention(self, a, b):
36 |         return torch.matmul(a, b.transpose(1, 2)) * self.temperature
37 | 
38 |     def forward(self, a, b, mask_a, mask_b):
39 |         attn = self._attention(a, b)
40 |         mask = torch.matmul(mask_a.float(), mask_b.transpose(1, 2).float())
41 |         if tuple(int(v) for v in torch.__version__.split('.')[:2]) < (1, 2):  # compare numerically: '1.10' sorts below '1.2' as strings
42 |             mask = mask.byte()
43 |         else:
44 |             mask = mask.bool()
45 |         attn.masked_fill_(~mask, -1e7)
46 |         attn_a = f.softmax(attn, dim=1)  # weights over a's positions
47 |         attn_b = f.softmax(attn, dim=2)  # weights over b's positions
48 |         feature_b = torch.matmul(attn_a.transpose(1, 2), a)  # each b token attends to a
49 |         feature_a = torch.matmul(attn_b, b)  # each a token attends to b
50 |         self.add_summary('temperature', self.temperature)
51 |         self.add_summary('attention_a', attn_a)
52 |         self.add_summary('attention_b', attn_b)
53 |         return feature_a, feature_b
54 | 
55 | 
56 | @register('linear')
57 | class MappedAlignment(Alignment):
58 |     def __init__(self, args, input_size):
59 |         super().__init__(args, input_size)
60 |         self.projection = nn.Sequential(
61 |             nn.Dropout(args.dropout),
62 |             Linear(input_size, args.hidden_size, activations=True),
63 |         )
64 | 
65 |     def _attention(self, a, b):
66 |         a = self.projection(a)
67 |         b = self.projection(b)
68 |         return super()._attention(a, b)
69 | 
--------------------------------------------------------------------------------
/src/modules/connection.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | 
17 | import math
18 | import torch
19 | import torch.nn as nn
20 | from . 
import Linear 21 | from functools import partial 22 | from src.utils.registry import register 23 | registry = {} 24 | register = partial(register, registry=registry) 25 | 26 | 27 | @register('none') 28 | class NullConnection(nn.Module): 29 | def __init__(self, _): 30 | super().__init__() 31 | 32 | def forward(self, x, _, __): 33 | return x 34 | 35 | 36 | @register('residual') 37 | class Residual(nn.Module): 38 | def __init__(self, args): 39 | super().__init__() 40 | self.linear = Linear(args.embedding_dim, args.hidden_size) 41 | 42 | def forward(self, x, res, i): 43 | if i == 1: 44 | res = self.linear(res) 45 | return (x + res) * math.sqrt(0.5) 46 | 47 | 48 | @register('aug') 49 | class AugmentedResidual(nn.Module): 50 | def __init__(self, _): 51 | super().__init__() 52 | 53 | def forward(self, x, res, i): 54 | if i == 1: 55 | return torch.cat([x, res], dim=-1) # res is embedding 56 | hidden_size = x.size(-1) 57 | x = (res[:, :, :hidden_size] + x) * math.sqrt(0.5) 58 | return torch.cat([x, res[:, :, hidden_size:]], dim=-1) # latter half of res is embedding 59 | -------------------------------------------------------------------------------- /src/modules/embedding.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as f 20 | 21 | 22 | class Embedding(nn.Module): 23 | def __init__(self, args): 24 | super().__init__() 25 | self.fix_embeddings = args.fix_embeddings 26 | self.embedding = nn.Embedding(args.num_vocab, args.embedding_dim, padding_idx=0) 27 | self.dropout = args.dropout 28 | 29 | def set_(self, value): 30 | self.embedding.weight.requires_grad = not self.fix_embeddings 31 | self.embedding.load_state_dict({'weight': torch.tensor(value)}) 32 | 33 | def forward(self, x): 34 | x = self.embedding(x) 35 | x = f.dropout(x, self.dropout, self.training) 36 | return x 37 | -------------------------------------------------------------------------------- /src/modules/encoder.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch.nn as nn 18 | import torch.nn.functional as f 19 | from . 
import Conv1d 20 | 21 | 22 | class Encoder(nn.Module): 23 | def __init__(self, args, input_size): 24 | super().__init__() 25 | self.dropout = args.dropout 26 | self.encoders = nn.ModuleList([Conv1d( 27 | in_channels=input_size if i == 0 else args.hidden_size, 28 | out_channels=args.hidden_size, 29 | kernel_sizes=args.kernel_sizes) for i in range(args.enc_layers)]) 30 | 31 | def forward(self, x, mask): 32 | x = x.transpose(1, 2) # B x C x L 33 | mask = mask.transpose(1, 2) 34 | for i, encoder in enumerate(self.encoders): 35 | x.masked_fill_(~mask, 0.) 36 | if i > 0: 37 | x = f.dropout(x, self.dropout, self.training) 38 | x = encoder(x) 39 | x = f.dropout(x, self.dropout, self.training) 40 | return x.transpose(1, 2) # B x L x C 41 | -------------------------------------------------------------------------------- /src/modules/fusion.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as f 20 | from functools import partial 21 | from src.utils.registry import register 22 | from . import Linear 23 | 24 | registry = {} 25 | register = partial(register, registry=registry) 26 | 27 | 28 | @register('simple') 29 | class Fusion(nn.Module): 30 | def __init__(self, args, input_size): 31 | super().__init__() 32 | self.fusion = Linear(input_size * 2, args.hidden_size, activations=True) 33 | 34 | def forward(self, x, align): 35 | return self.fusion(torch.cat([x, align], dim=-1)) 36 | 37 | 38 | @register('full') 39 | class FullFusion(nn.Module): 40 | def __init__(self, args, input_size): 41 | super().__init__() 42 | self.dropout = args.dropout 43 | self.fusion1 = Linear(input_size * 2, args.hidden_size, activations=True) 44 | self.fusion2 = Linear(input_size * 2, args.hidden_size, activations=True) 45 | self.fusion3 = Linear(input_size * 2, args.hidden_size, activations=True) 46 | self.fusion = Linear(args.hidden_size * 3, args.hidden_size, activations=True) 47 | 48 | def forward(self, x, align): 49 | x1 = self.fusion1(torch.cat([x, align], dim=-1)) 50 | x2 = self.fusion2(torch.cat([x, x - align], dim=-1)) 51 | x3 = self.fusion3(torch.cat([x, x * align], dim=-1)) 52 | x = torch.cat([x1, x2, x3], dim=-1) 53 | x = f.dropout(x, self.dropout, self.training) 54 | return self.fusion(x) 55 | -------------------------------------------------------------------------------- /src/modules/pooling.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch.nn as nn 18 | 19 | 20 | class Pooling(nn.Module): 21 | def forward(self, x, mask): 22 | return x.masked_fill_(~mask, -float('inf')).max(dim=1)[0] 23 | -------------------------------------------------------------------------------- /src/modules/prediction.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | from functools import partial 20 | from src.utils.registry import register 21 | from . import Linear 22 | 23 | registry = {} 24 | register = partial(register, registry=registry) 25 | 26 | 27 | @register('simple') 28 | class Prediction(nn.Module): 29 | def __init__(self, args, inp_features=2): 30 | super().__init__() 31 | self.dense = nn.Sequential( 32 | nn.Dropout(args.dropout), 33 | Linear(args.hidden_size * inp_features, args.hidden_size, activations=True), 34 | nn.Dropout(args.dropout), 35 | Linear(args.hidden_size, args.num_classes), 36 | ) 37 | 38 | def forward(self, a, b): 39 | return self.dense(torch.cat([a, b], dim=-1)) 40 | 41 | 42 | @register('full') 43 | class AdvancedPrediction(Prediction): 44 | def __init__(self, args): 45 | super().__init__(args, inp_features=4) 46 | 47 | def forward(self, a, b): 48 | return self.dense(torch.cat([a, b, a - b, a * b], dim=-1)) 49 | 50 | 51 | @register('symmetric') 52 | class SymmetricPrediction(AdvancedPrediction): 53 | def forward(self, a, b): 54 | return self.dense(torch.cat([a, b, (a - b).abs(), a * b], dim=-1)) 55 | -------------------------------------------------------------------------------- /src/network.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
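# Tensor conventions in this module (inferred from the forward pass below;
# the batch/length names are illustrative, not part of the API):
#   inputs['text1'], inputs['text2']: LongTensor token ids, (batch, seq_len)
#   inputs['mask1'], inputs['mask2']: padding masks, (batch, seq_len, 1)
# Each block encodes both sequences, concatenates the encoder output with its
# input, aligns the two sides, and fuses; the connection module wires block i's
# output into block i+1 (residual or augmented residual).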
15 | 16 | 17 | import torch 18 | from .modules import Module, ModuleList, ModuleDict 19 | from .modules.embedding import Embedding 20 | from .modules.encoder import Encoder 21 | from .modules.alignment import registry as alignment 22 | from .modules.fusion import registry as fusion 23 | from .modules.connection import registry as connection 24 | from .modules.pooling import Pooling 25 | from .modules.prediction import registry as prediction 26 | 27 | 28 | class Network(Module): 29 | def __init__(self, args): 30 | super().__init__() 31 | self.dropout = args.dropout 32 | self.embedding = Embedding(args) 33 | input_emb_size = args.embedding_dim if args.connection == 'aug' else 0 34 | self.blocks = ModuleList([ModuleDict({ 35 | 'encoder': Encoder(args, args.embedding_dim if i == 0 else input_emb_size + args.hidden_size), 36 | 'alignment': alignment[args.alignment]( 37 | args, args.embedding_dim + args.hidden_size if i == 0 else input_emb_size + args.hidden_size * 2), 38 | 'fusion': fusion[args.fusion]( 39 | args, args.embedding_dim + args.hidden_size if i == 0 else input_emb_size + args.hidden_size * 2), 40 | }) for i in range(args.blocks)]) 41 | 42 | self.connection = connection[args.connection](args) 43 | self.pooling = Pooling() 44 | self.prediction = prediction[args.prediction](args) 45 | 46 | def forward(self, inputs): 47 | a = inputs['text1'] 48 | b = inputs['text2'] 49 | mask_a = inputs['mask1'] 50 | mask_b = inputs['mask2'] 51 | 52 | a = self.embedding(a) 53 | b = self.embedding(b) 54 | res_a, res_b = a, b 55 | 56 | for i, block in enumerate(self.blocks): 57 | if i > 0: 58 | a = self.connection(a, res_a, i) 59 | b = self.connection(b, res_b, i) 60 | res_a, res_b = a, b 61 | a_enc = block['encoder'](a, mask_a) 62 | b_enc = block['encoder'](b, mask_b) 63 | a = torch.cat([a, a_enc], dim=-1) 64 | b = torch.cat([b, b_enc], dim=-1) 65 | align_a, align_b = block['alignment'](a, b, mask_a, mask_b) 66 | a = block['fusion'](a, align_a) 67 | b = block['fusion'](b, align_b) 68 | a = self.pooling(a, mask_a) 69 | b = self.pooling(b, mask_b) 70 | return self.prediction(a, b) 71 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
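# Driven from train.py: each parsed config group yields an args object and one
# Trainer run. A minimal sketch (the config path is illustrative):
#
#     from src.utils import params
#     for args, config in params.parse('configs/main.json5'):
#         states = Trainer(args).train()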
15 | 16 | 17 | import os 18 | import random 19 | import json5 20 | import torch 21 | from datetime import datetime 22 | from pprint import pformat 23 | from .utils.loader import load_data 24 | from .utils.logger import Logger 25 | from .utils.params import validate_params 26 | from .model import Model 27 | from .interface import Interface 28 | 29 | 30 | class Trainer: 31 | def __init__(self, args): 32 | self.args = args 33 | self.log = Logger(self.args) 34 | 35 | def train(self): 36 | start_time = datetime.now() 37 | model, interface, states = self.build_model() 38 | train = load_data(self.args.data_dir, 'train') 39 | dev = load_data(self.args.data_dir, self.args.eval_file) 40 | self.log(f'train ({len(train)}) | {self.args.eval_file} ({len(dev)})') 41 | train_batches = interface.pre_process(train) 42 | dev_batches = interface.pre_process(dev, training=False) 43 | self.log('setup complete: {}s.'.format(str(datetime.now() - start_time).split(".")[0])) 44 | 45 | try: 46 | for epoch in range(states['start_epoch'], self.args.epochs + 1): 47 | states['epoch'] = epoch 48 | self.log.set_epoch(epoch) 49 | 50 | batches = interface.shuffle_batch(train_batches) 51 | for batch_id, batch in enumerate(batches): 52 | stats = model.update(batch) 53 | self.log.update(stats) 54 | eval_per_updates = self.args.eval_per_updates \ 55 | if model.updates > self.args.eval_warmup_steps else self.args.eval_per_updates_warmup 56 | if model.updates % eval_per_updates == 0 or (self.args.eval_epoch and batch_id + 1 == len(batches)): 57 | self.log.newline() 58 | score, dev_stats = model.evaluate(dev_batches) 59 | if score > states['best_eval']: 60 | states['best_eval'], states['best_epoch'], states['best_step'] = score, epoch, model.updates 61 | if self.args.save: 62 | model.save(states, name=model.best_model_name) 63 | self.log.log_eval(dev_stats) 64 | if self.args.save_all: 65 | model.save(states) 66 | model.save(states, name='last') 67 | if model.updates - states['best_step'] > self.args.early_stopping \ 68 | and model.updates > self.args.min_steps: 69 | self.log('[Tolerance reached. Training is stopped early.]') 70 | raise EarlyStop('[Tolerance reached. Training is stopped early.]') 71 | if stats['loss'] > self.args.max_loss: 72 | raise EarlyStop('[Loss exceeds tolerance. Unstable training is stopped early.]') 73 | if stats['lr'] < self.args.min_lr - 1e-6: 74 | raise EarlyStop('[Learning rate has decayed below min_lr. Training is stopped early.]') 75 | self.log.newline() 76 | self.log('Training complete.') 77 | except KeyboardInterrupt: 78 | self.log.newline() 79 | self.log(f'Training interrupted. 
Stopped early.') 80 | except EarlyStop as e: 81 | self.log.newline() 82 | self.log(str(e)) 83 | self.log(f'best dev score {states["best_eval"]} at step {states["best_step"]} ' 84 | f'(epoch {states["best_epoch"]}).') 85 | self.log(f'best eval stats [{self.log.best_eval_str}]') 86 | training_time = str(datetime.now() - start_time).split('.')[0] 87 | self.log(f'Training time: {training_time}.') 88 | states['start_time'] = str(start_time).split('.')[0] 89 | states['training_time'] = training_time 90 | return states 91 | 92 | def build_model(self): 93 | states = {} 94 | interface = Interface(self.args, self.log) 95 | self.log(f'#classes: {self.args.num_classes}; #vocab: {self.args.num_vocab}') 96 | if self.args.seed: 97 | random.seed(self.args.seed) 98 | torch.manual_seed(self.args.seed) 99 | if self.args.cuda: 100 | torch.cuda.manual_seed(self.args.seed) 101 | if self.args.deterministic: 102 | torch.backends.cudnn.deterministic = True 103 | 104 | model = Model(self.args) 105 | if self.args.pretrained_embeddings: 106 | embeddings = interface.load_embeddings() 107 | model.set_embeddings(embeddings) 108 | 109 | # set initial states 110 | states['start_epoch'] = 1 111 | states['best_eval'] = 0. 112 | states['best_epoch'] = 0 113 | states['best_step'] = 0 114 | 115 | self.log(f'trainable params: {model.num_parameters():,d}') 116 | self.log(f'trainable params (exclude embeddings): {model.num_parameters(exclude_embed=True):,d}') 117 | validate_params(self.args) 118 | with open(os.path.join(self.args.summary_dir, 'args.json5'), 'w') as f: 119 | json5.dump(self.args.__dict__, f, indent=2) 120 | self.log(pformat(vars(self.args), indent=2, width=120)) 121 | return model, interface, states 122 | 123 | 124 | class EarlyStop(Exception): 125 | pass 126 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/utils/loader.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
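# Data files are plain text with one tab-separated example per line:
#
#     text1<TAB>text2<TAB>label
#
# e.g. (hypothetical pair) "how do i learn python<TAB>best way to learn python<TAB>1".
# load_data() reads one named split (e.g. train.txt, or whatever eval_file names)
# or every .txt file in data_dir; load_embeddings() expects word2vec/GloVe-style
# text vectors, i.e. a token followed by exactly `dim` floats per line.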
15 | 16 | 17 | import os 18 | import numpy as np 19 | 20 | 21 | def load_data(data_dir, split=None): 22 | data = [] 23 | if split is None: 24 | files = [os.path.join(data_dir, file) for file in os.listdir(data_dir) if file.endswith('.txt')] 25 | else: 26 | if not split.endswith('.txt'): 27 | split += '.txt' 28 | files = [os.path.join(data_dir, f'{split}')] 29 | for file in files: 30 | with open(file) as f: 31 | for line in f: 32 | text1, text2, label = line.rstrip().split('\t') 33 | data.append({ 34 | 'text1': text1, 35 | 'text2': text2, 36 | 'target': label, 37 | }) 38 | return data 39 | 40 | 41 | def load_embeddings(file, vocab, dim, lower, mode='freq'): 42 | embedding = np.zeros((len(vocab), dim)) 43 | count = np.zeros((len(vocab), 1)) 44 | with open(file) as f: 45 | for line in f: 46 | elems = line.rstrip().split() 47 | if len(elems) != dim + 1: 48 | continue 49 | token = elems[0] 50 | if lower and mode != 'strict': 51 | token = token.lower() 52 | if token in vocab: 53 | index = vocab.index(token) 54 | vector = [float(x) for x in elems[1:]] 55 | if mode == 'freq' or mode == 'strict': 56 | if not count[index]: 57 | embedding[index] = vector 58 | count[index] = 1. 59 | elif mode == 'last': 60 | embedding[index] = vector 61 | count[index] = 1. 62 | elif mode == 'avg': 63 | embedding[index] += vector 64 | count[index] += 1. 65 | else: 66 | raise NotImplementedError('Unknown embedding loading mode: ' + mode) 67 | if mode == 'avg': 68 | inverse_mask = np.where(count == 0, 1., 0.) 69 | embedding /= count + inverse_mask 70 | return embedding.tolist() 71 | -------------------------------------------------------------------------------- /src/utils/logger.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import os 18 | import sys 19 | import logging 20 | 21 | 22 | class Logger: 23 | def __init__(self, args): 24 | log = logging.getLogger(args.summary_dir) 25 | if not log.handlers: 26 | log.setLevel(logging.DEBUG) 27 | fh = logging.FileHandler(os.path.join(args.summary_dir, args.log_file)) 28 | fh.setLevel(logging.INFO) 29 | ch = ProgressHandler() 30 | ch.setLevel(logging.DEBUG) 31 | formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S') 32 | fh.setFormatter(formatter) 33 | ch.setFormatter(formatter) 34 | log.addHandler(fh) 35 | log.addHandler(ch) 36 | self.log = log 37 | # setup TensorBoard 38 | if args.tensorboard: 39 | from tensorboardX import SummaryWriter 40 | self.writer = SummaryWriter(os.path.join(args.summary_dir, 'viz')) 41 | self.log.info(f'TensorBoard activated.') 42 | else: 43 | self.writer = None 44 | self.log_per_updates = args.log_per_updates 45 | self.summary_per_updates = args.summary_per_updates 46 | self.grad_clipping = args.grad_clipping 47 | self.clips = 0 48 | self.train_meters = {} 49 | self.epoch = None 50 | self.best_eval = 0. 
51 | self.best_eval_str = '' 52 | 53 | def set_epoch(self, epoch): 54 | self(f'Epoch: {epoch}') 55 | self.epoch = epoch 56 | 57 | @staticmethod 58 | def _format_number(x): 59 | return f'{x:.4f}' if float(x) > 1e-3 else f'{x:.4e}' 60 | 61 | def update(self, stats): 62 | updates = stats.pop('updates') 63 | summary = stats.pop('summary') 64 | if updates % self.log_per_updates == 0: 65 | self.clips += int(stats['gnorm'] > self.grad_clipping) 66 | stats_str = ' '.join(f'{key}: ' + self._format_number(val) for key, val in stats.items()) 67 | for key, val in stats.items(): 68 | if key not in self.train_meters: 69 | self.train_meters[key] = AverageMeter() 70 | self.train_meters[key].update(val) 71 | msg = f'epoch {self.epoch} updates {updates} {stats_str}' 72 | if self.log_per_updates != 1: 73 | msg = '> ' + msg 74 | self.log.info(msg) 75 | if self.writer and updates % self.summary_per_updates == 0: 76 | for key, val in stats.items(): 77 | self.writer.add_scalar(f'train/{key}', val, updates) 78 | for key, val in summary.items(): 79 | self.writer.add_histogram(key, val, updates) 80 | 81 | def newline(self): 82 | self.log.debug('') 83 | 84 | def log_eval(self, valid_stats): 85 | self.newline() 86 | updates = valid_stats.pop('updates') 87 | eval_score = valid_stats.pop('score') 88 | # report the exponential averaged training stats, while reporting the full dev set stats 89 | if self.train_meters: 90 | train_stats_str = ' '.join(f'{key}: ' + self._format_number(val) for key, val in self.train_meters.items()) 91 | train_stats_str += ' ' + f'clip: {self.clips}' 92 | self.log.info(f'train {train_stats_str}') 93 | valid_stats_str = ' '.join(f'{key}: ' + self._format_number(val) for key, val in valid_stats.items()) 94 | if eval_score > self.best_eval: 95 | self.best_eval_str = valid_stats_str 96 | self.best_eval = eval_score 97 | valid_stats_str += ' [NEW BEST]' 98 | else: 99 | valid_stats_str += f' [BEST: {self._format_number(self.best_eval)}]' 100 | self.log.info(f'valid {valid_stats_str}') 101 | if self.writer: 102 | for key in valid_stats.keys(): 103 | group = {'valid': valid_stats[key]} 104 | if self.train_meters and key in self.train_meters: 105 | group['train'] = float(self.train_meters[key]) 106 | self.writer.add_scalars(f'valid/{key}', group, updates) 107 | self.train_meters = {} 108 | self.clips = 0 109 | 110 | def __call__(self, msg): 111 | self.log.info(msg) 112 | 113 | 114 | class ProgressHandler(logging.Handler): 115 | def __init__(self, level=logging.NOTSET): 116 | super().__init__(level) 117 | 118 | def emit(self, record): 119 | log_entry = self.format(record) 120 | if record.message.startswith('> '): 121 | sys.stdout.write('{}\r'.format(log_entry.rstrip())) 122 | sys.stdout.flush() 123 | else: 124 | sys.stdout.write('{}\n'.format(log_entry)) 125 | 126 | 127 | class AverageMeter(object): 128 | """Keep exponential weighted averages.""" 129 | def __init__(self, beta=0.99): 130 | self.beta = beta 131 | self.moment = 0. 132 | self.value = 0. 133 | self.t = 0. 
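        # The meter keeps a bias-corrected exponential moving average
        # (the same correction Adam applies to its moment estimates):
        #     moment_t = beta * moment_{t-1} + (1 - beta) * x_t
        #     value_t  = moment_t / (1 - beta ** t)
        # e.g. after the first update, value_1 = (1 - beta) * x_1 / (1 - beta) = x_1,
        # so early readings are not biased toward the zero initialization.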
134 | 135 | def update(self, val): 136 | self.t += 1 137 | self.moment = self.beta * self.moment + (1 - self.beta) * val 138 | # bias correction 139 | self.value = self.moment / (1 - self.beta ** self.t) 140 | 141 | def __format__(self, spec): 142 | return format(self.value, spec) 143 | 144 | def __float__(self): 145 | return self.value 146 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import os 18 | import subprocess 19 | from functools import partial 20 | import numpy as np 21 | from sklearn import metrics 22 | 23 | from .registry import register 24 | 25 | registry = {} 26 | register = partial(register, registry=registry) 27 | 28 | 29 | @register('acc') 30 | def acc(outputs): 31 | target = outputs['target'] 32 | pred = outputs['pred'] 33 | return { 34 | 'acc': metrics.accuracy_score(target, pred).item(), 35 | } 36 | 37 | 38 | @register('f1') 39 | def f1(outputs): 40 | target = outputs['target'] 41 | pred = outputs['pred'] 42 | return { 43 | 'f1': metrics.f1_score(target, pred).item(), 44 | } 45 | 46 | 47 | @register('auc') 48 | def auc(outputs): 49 | target = outputs['target'] 50 | prob = np.array(outputs['prob']) 51 | return { 52 | 'auc': metrics.roc_auc_score(target, prob[:, 1]).item(), 53 | } 54 | 55 | 56 | @register('map') 57 | @register('mrr') 58 | def ranking(outputs): 59 | args = outputs['args'] 60 | prediction = [o[1] for o in outputs['prob']] 61 | ref_file = os.path.join(args.data_dir, '{}.ref'.format(args.eval_file)) 62 | rank_file = os.path.join(args.data_dir, '{}.rank'.format(args.eval_file)) 63 | tmp_file = os.path.join(args.summary_dir, 'tmp-pred.txt') 64 | with open(rank_file) as f: 65 | prefix = [] 66 | for line in f: 67 | prefix.append(line.strip().split()) 68 | assert len(prefix) == len(prediction), \ 69 | 'prefix {}, while prediction {}'.format(len(prefix), len(prediction)) 70 | with open(tmp_file, 'w') as f: 71 | for prefix, pred in zip(prefix, prediction): 72 | prefix[-2] = str(pred) 73 | f.write(' '.join(prefix) + '\n') 74 | sp = subprocess.Popen('./resources/trec_eval {} {} | egrep "map|recip_rank"'.format(ref_file, tmp_file), 75 | shell=True, 76 | stdout=subprocess.PIPE, stderr=subprocess.PIPE) 77 | stdout, stderr = sp.communicate() 78 | stdout, stderr = stdout.decode(), stderr.decode() 79 | os.remove(tmp_file) 80 | map_, mrr = [float(s[-6:]) for s in stdout.strip().split('\n')] 81 | return { 82 | 'map': map_, 83 | 'mrr': mrr, 84 | } 85 | -------------------------------------------------------------------------------- /src/utils/params.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the 
"License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | import os 18 | import math 19 | import shutil 20 | from datetime import datetime 21 | import torch 22 | import json5 23 | 24 | 25 | class Object: 26 | """ 27 | @DynamicAttrs 28 | """ 29 | pass 30 | 31 | 32 | def parse(config_file): 33 | root = os.path.dirname(config_file) # __parent__ in config is a relative path 34 | config_group = _load_param('', config_file) 35 | if type(config_group) is dict: 36 | config_group = [config_group] 37 | configs = [] 38 | for config in config_group: 39 | try: 40 | choice = config.pop('__iter__') 41 | assert len(choice) == 1, 'only support iterating over 1 variable' 42 | key, values = next(iter(choice.items())) 43 | except KeyError: 44 | key, value = config.popitem() 45 | values = [value] 46 | for value in values: 47 | config[key] = value 48 | repeat = config.get('__repeat__', 1) 49 | for index in range(repeat): 50 | config_ = config.copy() 51 | config_['__index__'] = index 52 | if repeat > 1: 53 | config_['name'] += '-' + str(index) 54 | args = _parse_args(root, config_) 55 | configs.append((args, config_)) 56 | return configs 57 | 58 | 59 | def _parse_args(root, config): 60 | args = Object() 61 | assert type(config) is dict 62 | parents = config.get('__parents__', []) 63 | for parent in parents: 64 | parent = _load_param(root, parent) 65 | assert type(parent) is dict, 'only top-level configs can be a sequence' 66 | _add_param(args, parent) 67 | _add_param(args, config) 68 | _post_process(args) 69 | return args 70 | 71 | 72 | def _add_param(args, x: dict): 73 | for k, v in x.items(): 74 | if type(v) is dict: 75 | _add_param(args, v) 76 | else: 77 | k = _validate_param(k) 78 | if hasattr(args, k): 79 | previous_type = type(getattr(args, k)) 80 | current_type = type(v) 81 | assert previous_type is current_type or \ 82 | isinstance(None, previous_type) or \ 83 | isinstance(None, current_type) or \ 84 | (previous_type is float and current_type is int), \ 85 | f'param "{k}" of type {previous_type} can not be overwritten by type {current_type}' 86 | setattr(args, k, v) 87 | 88 | 89 | def _load_param(root, file: str): 90 | file = os.path.join(root, file) 91 | if not file.endswith('.json5'): 92 | file += '.json5' 93 | with open(file) as f: 94 | config = json5.load(f) 95 | return config 96 | 97 | 98 | def _post_process(args: Object): 99 | if not args.output_dir.startswith('models'): 100 | args.output_dir = os.path.join('models', args.output_dir) 101 | os.makedirs(args.output_dir, exist_ok=True) 102 | if not args.name: 103 | args.name = str(datetime.now()) 104 | args.summary_dir = os.path.join(args.output_dir, args.name) 105 | if os.path.exists(args.summary_dir): 106 | shutil.rmtree(args.summary_dir) 107 | os.makedirs(args.summary_dir) 108 | data_config_file = os.path.join(args.output_dir, 'data_config.json5') 109 | if os.path.exists(data_config_file): 110 | with open(data_config_file) as f: 111 | config = json5.load(f) 112 | for k, v in config.items(): 113 | if not hasattr(args, k) or getattr(args, k) != v: 114 | print('ERROR: Data 
configurations are different. Please use another output_dir or '
115 |                   'remove the older one manually.')
116 |             exit()
117 |     else:
118 |         with open(data_config_file, 'w') as f:
119 |             keys = ['data_dir', 'min_df', 'max_vocab', 'max_len', 'min_len', 'lower_case',
120 |                     'pretrained_embeddings', 'embedding_mode']
121 |             json5.dump({k: getattr(args, k) for k in keys}, f)
122 |     args.metric = args.metric.lower()
123 |     args.watch_metrics = [m.lower() for m in args.watch_metrics]
124 |     if args.metric not in args.watch_metrics:
125 |         args.watch_metrics.append(args.metric)
126 |     args.cuda = args.cuda and torch.cuda.is_available()
127 |     args.fix_embeddings = args.pretrained_embeddings and args.fix_embeddings
128 | 
129 |     def samples2steps(n):
130 |         return int(math.ceil(n / args.batch_size))  # convert sample-count settings into update steps
131 | 
132 |     if not hasattr(args, 'log_per_updates'):
133 |         args.log_per_updates = samples2steps(args.log_per_samples)
134 |     if not hasattr(args, 'eval_per_updates'):
135 |         args.eval_per_updates = samples2steps(args.eval_per_samples)
136 |     if not hasattr(args, 'eval_per_updates_warmup'):
137 |         args.eval_per_updates_warmup = samples2steps(args.eval_per_samples_warmup)
138 |     if not hasattr(args, 'eval_warmup_steps'):
139 |         args.eval_warmup_steps = samples2steps(args.eval_warmup_samples)
140 |     if not hasattr(args, 'min_steps'):
141 |         args.min_steps = samples2steps(args.min_samples)
142 |     if not hasattr(args, 'early_stopping'):
143 |         args.early_stopping = samples2steps(args.tolerance_samples)
144 |     if not hasattr(args, 'lr_warmup_steps'):
145 |         args.lr_warmup_steps = samples2steps(args.lr_warmup_samples)
146 |     if not hasattr(args, 'lr_decay_steps'):
147 |         args.lr_decay_steps = samples2steps(args.lr_decay_samples)
148 |     if not hasattr(args, 'summary_per_updates'):
149 |         args.summary_per_updates = args.summary_per_logs * args.log_per_updates
150 |     assert args.lr >= args.min_lr, 'initial learning rate must not be smaller than min_lr'
151 | 
152 | 
153 | def validate_params(args):
154 |     """validate params after interface initialization"""
155 |     assert args.num_classes == 2 or ('f1' not in args.watch_metrics and 'auc' not in args.watch_metrics), \
156 |         'F1 and AUC are only valid for 2 classes.'
157 |     assert args.num_classes == 2 or ('map' not in args.watch_metrics and 'mrr' not in args.watch_metrics), \
158 |         'ranking metrics (map/mrr) are only valid for 2 classes.'
159 |     assert args.num_vocab > 0
160 | 
161 | 
162 | def _validate_param(name):
163 |     name = name.replace('-', '_')
164 |     if not str.isidentifier(name):
165 |         raise ValueError(f'Invalid param name: {name}')
166 |     return name
167 | 
--------------------------------------------------------------------------------
/src/utils/registry.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
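# Usage pattern (as in src/modules/alignment.py and src/utils/metrics.py):
# each module keeps its own registry dict and partially applies `register`
# to it, then decorates classes or functions with a lookup name:
#
#     registry = {}
#     register = partial(register, registry=registry)
#
#     @register('identity')
#     class Alignment(Module): ...
#
# Config strings such as args.alignment then resolve via registry[name].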
15 | 16 | 17 | def register(name=None, registry=None): 18 | def decorator(fn, registration_name=None): 19 | module_name = registration_name or _default_name(fn) 20 | if module_name in registry: 21 | raise LookupError(f"module {module_name} already registered.") 22 | registry[module_name] = fn 23 | return fn 24 | return lambda fn: decorator(fn, name) 25 | 26 | 27 | def _default_name(obj_class): 28 | return obj_class.__name__ 29 | -------------------------------------------------------------------------------- /src/utils/vocab.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (C) 2019 Alibaba Group Holding Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from collections import Counter 18 | 19 | 20 | class Indexer: 21 | def __init__(self): 22 | self.w2id = {} 23 | self.id2w = {} 24 | 25 | @property 26 | def n_spec(self): 27 | return 0 28 | 29 | def __len__(self): 30 | return len(self.w2id) 31 | 32 | def __getitem__(self, index): 33 | if index not in self.id2w: 34 | raise IndexError(f'invalid index {index} in indices.') 35 | return self.id2w[index] 36 | 37 | def __contains__(self, item): 38 | return item in self.w2id 39 | 40 | def index(self, symbol): 41 | if symbol in self.w2id: 42 | return self.w2id[symbol] 43 | raise IndexError(f'Unknown symbol {symbol}') 44 | 45 | def keys(self): 46 | return self.w2id.keys() 47 | 48 | def indices(self): 49 | return self.id2w.keys() 50 | 51 | def add_symbol(self, symbol): 52 | if symbol not in self.w2id: 53 | self.id2w[len(self.id2w)] = symbol 54 | self.w2id[symbol] = len(self.w2id) 55 | 56 | @classmethod 57 | def build(cls, symbols, min_counts=1, dump_filtered=None, log=print): 58 | counter = Counter(symbols) 59 | symbols = sorted([t for t, c in counter.items() if c >= min_counts], 60 | key=counter.get, reverse=True) 61 | log(f'''{len(symbols)} symbols found: {' '.join(symbols[:15]) + ('...' if len(symbols) > 15 else '')}''') 62 | filtered = sorted(list(counter.keys() - set(symbols)), key=counter.get, reverse=True) 63 | if filtered: 64 | log('filtered classes:') 65 | if len(filtered) > 20: 66 | log('{} ... 
{}'.format(' '.join(filtered[:10]), ' '.join(filtered[-10:])))
67 |             else:
68 |                 log(' '.join(filtered))
69 |             if dump_filtered:
70 |                 with open(dump_filtered, 'w') as f:
71 |                     for name in filtered:
72 |                         f.write(f'{name} {counter.get(name)}\n')
73 |         indexer = cls()
74 |         try:  # restore numeric order if labels are represented by integers already
75 |             symbols = list(map(int, symbols))
76 |             symbols.sort()
77 |             symbols = list(map(str, symbols))
78 |         except ValueError:
79 |             pass
80 |         for symbol in symbols:
81 |             if symbol:
82 |                 indexer.add_symbol(symbol)
83 |         return indexer
84 | 
85 |     def save(self, file):
86 |         with open(file, 'w') as f:
87 |             for symbol, index in self.w2id.items():
88 |                 if index < self.n_spec:
89 |                     continue
90 |                 f.write('{}\n'.format(symbol))
91 | 
92 |     @classmethod
93 |     def load(cls, file):
94 |         indexer = cls()
95 |         with open(file) as f:
96 |             for line in f:
97 |                 symbol = line.rstrip()
98 |                 assert len(symbol) > 0, 'Empty symbol encountered.'
99 |                 indexer.add_symbol(symbol)
100 |         return indexer
101 | 
102 | 
103 | class RobustIndexer(Indexer):
104 |     def __init__(self, validate=True):
105 |         super().__init__()
106 |         self.w2id.update({self.unk_symbol(): self.unk()})
107 |         self.id2w = {i: w for w, i in self.w2id.items()}
108 |         if validate:
109 |             self.validate_spec()
110 | 
111 |     @property
112 |     def n_spec(self):
113 |         return 1
114 | 
115 |     def index(self, symbol):
116 |         return self.w2id[symbol] if symbol in self.w2id else self.unk()
117 | 
118 |     @staticmethod
119 |     def unk():
120 |         return 0
121 | 
122 |     @staticmethod
123 |     def unk_symbol():
124 |         return '<unk>'
125 | 
126 |     def validate_spec(self):
127 |         assert self.n_spec == len(self.w2id), f'{self.n_spec}, {len(self.w2id)}'
128 |         assert len(self.w2id) == max(self.id2w.keys()) + 1, "empty indices found in special tokens"
129 |         assert len(self.w2id) == len(self.id2w), "index conflict in special tokens"
130 | 
131 | 
132 | class Vocab(RobustIndexer):
133 |     def __init__(self):
134 |         super().__init__(validate=False)
135 |         self.w2id.update({
136 |             self.pad_symbol(): self.pad(),
137 |         })
138 |         self.id2w = {i: w for w, i in self.w2id.items()}
139 |         self.validate_spec()
140 | 
141 |     @classmethod
142 |     def build(cls, words, lower=False, min_df=1, max_tokens=float('inf'), pretrained_embeddings=None,
143 |               dump_filtered=None, log=print):
144 |         if pretrained_embeddings:
145 |             wv_vocab = cls.load_embedding_vocab(pretrained_embeddings, lower)
146 |         else:
147 |             wv_vocab = set()
148 |         if lower:
149 |             words = (word.lower() for word in words)
150 |         counter = Counter(words)
151 |         candidate_tokens = sorted([t for t, c in counter.items() if t in wv_vocab or c >= min_df],
152 |                                   key=counter.get, reverse=True)
153 |         if len(candidate_tokens) > max_tokens:
154 |             tokens = []
155 |             for i, token in enumerate(candidate_tokens):
156 |                 if i < max_tokens:
157 |                     tokens.append(token)
158 |                 elif token in wv_vocab:
159 |                     tokens.append(token)
160 |         else:
161 |             tokens = candidate_tokens
162 |         total = sum(counter.values())
163 |         matched = sum(counter[t] for t in tokens)
164 |         stats = (len(tokens), len(counter), total - matched, total, (total - matched) / total * 100)
165 |         log('vocab coverage {}/{} | OOV occurrences {}/{} ({:.4f}%)'.format(*stats))
166 |         tokens_set = set(tokens)
167 |         if pretrained_embeddings:
168 |             oop_samples = sorted(list(tokens_set - wv_vocab), key=counter.get, reverse=True)
169 |             log('Covered by pretrained vectors {:.4f}%. '.format(len(tokens_set & wv_vocab) / len(tokens) * 100) +
170 |                 ('outside pretrained: ' + ' '.join(oop_samples[:10]) + (' ...' if len(oop_samples) > 10 else '')
171 |                  if oop_samples else ''))
172 |         log('top words:\n{}'.format(' '.join(tokens[:10])))
173 |         filtered = sorted(list(counter.keys() - set(tokens)), key=counter.get, reverse=True)
174 |         if filtered:
175 |             if len(filtered) > 20:
176 |                 log('filtered words:\n{} ... {}'.format(' '.join(filtered[:10]), ' '.join(filtered[-10:])))
177 |             else:
178 |                 log('filtered words:\n' + ' '.join(filtered))
179 |             if dump_filtered:
180 |                 with open(dump_filtered, 'w') as f:
181 |                     for name in filtered:
182 |                         f.write(f'{name} {counter.get(name)}\n')
183 | 
184 |         vocab = cls()
185 |         for token in tokens:
186 |             vocab.add_symbol(token)
187 |         return vocab
188 | 
189 |     @staticmethod
190 |     def load_embedding_vocab(file, lower):
191 |         wv_vocab = set()
192 |         with open(file) as f:
193 |             for line in f:
194 |                 token = line.rstrip().split(' ')[0]
195 |                 if lower:
196 |                     token = token.lower()
197 |                 wv_vocab.add(token)
198 |         return wv_vocab
199 | 
200 |     @staticmethod
201 |     def pad():
202 |         return 0
203 | 
204 |     @staticmethod
205 |     def unk():
206 |         return 1
207 | 
208 |     @property
209 |     def n_spec(self):
210 |         return 2
211 | 
212 |     @staticmethod
213 |     def pad_symbol():
214 |         return '<pad>'
215 | 
216 |     char_map = {  # escape special characters for safe serialization
217 |         '\n': '<newline>',
218 |     }
219 | 
220 |     def save(self, file):
221 |         with open(file, 'w') as f:
222 |             for symbol, index in self.w2id.items():
223 |                 if index < self.n_spec:
224 |                     continue
225 |                 symbol = self.char_map.get(symbol, symbol)
226 |                 f.write(f'{symbol}\n')
227 | 
228 |     @classmethod
229 |     def load(cls, file):
230 |         vocab = cls()
231 |         reverse_char_map = {v: k for k, v in cls.char_map.items()}
232 |         with open(file) as f:
233 |             for line in f:
234 |                 symbol = line.rstrip('\n')
235 |                 symbol = reverse_char_map.get(symbol, symbol)
236 |                 vocab.add_symbol(symbol)
237 |         return vocab
238 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
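# Entry point. Typical invocations (the config path is illustrative; any json5
# file under configs/ works):
#
#     python train.py configs/main.json5          # train every config in the group
#     python train.py configs/main.json5 --dry    # only print the parsed arg groups
#
# Each completed run appends a record of its params and final states to
# models/log.jsonl (see main() below).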
15 | 16 | 17 | import os 18 | import sys 19 | import json5 20 | from pprint import pprint 21 | from src.utils import params 22 | from src.trainer import Trainer 23 | 24 | 25 | def main(): 26 | argv = sys.argv 27 | if len(argv) == 2: 28 | arg_groups = params.parse(sys.argv[1]) 29 | for args, config in arg_groups: 30 | trainer = Trainer(args) 31 | states = trainer.train() 32 | with open('models/log.jsonl', 'a') as f: 33 | f.write(json5.dumps({ 34 | 'data': os.path.basename(args.data_dir), 35 | 'params': config, 36 | 'state': states, 37 | })) 38 | f.write('\n') 39 | elif len(argv) == 3 and '--dry' in argv: 40 | argv.remove('--dry') 41 | arg_groups = params.parse(sys.argv[1]) 42 | pprint([args.__dict__ for args, _ in arg_groups]) 43 | else: 44 | print('Usage: "python train.py configs/xxx.json5"') 45 | 46 | 47 | if __name__ == '__main__': 48 | main() 49 | --------------------------------------------------------------------------------