├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── ablation.json5
│   ├── data
│   │   ├── quora.json5
│   │   ├── scitail.json5
│   │   ├── snli.json5
│   │   └── wikiqa.json5
│   ├── debug.json5
│   ├── default.json5
│   └── main.json5
├── data
│   ├── prepare_quora.py
│   ├── prepare_scitail.py
│   ├── prepare_snli.py
│   └── prepare_wikiqa.py
├── evaluate.py
├── figure.png
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── evaluator.py
│   ├── interface.py
│   ├── model.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── alignment.py
│   │   ├── connection.py
│   │   ├── embedding.py
│   │   ├── encoder.py
│   │   ├── fusion.py
│   │   ├── pooling.py
│   │   └── prediction.py
│   ├── network.py
│   ├── trainer.py
│   └── utils
│       ├── __init__.py
│       ├── loader.py
│       ├── logger.py
│       ├── metrics.py
│       ├── params.py
│       ├── registry.py
│       └── vocab.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | /data/*
2 | !/data/*.py
3 | /models/
4 | /resources/
5 |
6 | /.idea
7 | __pycache__/
8 | .DS_Store
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RE2
2 |
3 | This is a PyTorch implementation of the ACL 2019 paper "Simple and Effective Text Matching with Richer Alignment Features". The original TensorFlow implementation is available at https://github.com/alibaba-edu/simple-effective-text-matching.
4 |
5 | ## Quick Links
6 |
7 | - [About](#simple-and-effective-text-matching)
8 | - [Setup](#setup)
9 | - [Usage](#usage)
10 |
11 | ## Simple and Effective Text Matching
12 |
13 | RE2 is a fast and strong neural architecture for general-purpose text matching applications.
14 | In a text matching task, a model takes two text sequences as input and predicts their relationship.
15 | This work explores what is sufficient for strong performance in these tasks.
16 | It simplifies many slow components that were previously considered core building blocks in text matching,
17 | while keeping three key features directly available for inter-sequence alignment:
18 | original point-wise features, previously aligned features, and contextual features.
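
The gist of this can be sketched in a few lines of PyTorch (a minimal illustration of the idea, not the repository's `src/network.py`; all names here are illustrative):

```python
# Minimal sketch: the alignment input of each block concatenates three feature types.
import torch

def block_features(embeddings, prev_aligned, contextual):
    # embeddings:   original point-wise features,         [batch, seq_len, embed_dim]
    # prev_aligned: aligned output of the previous block, [batch, seq_len, hidden_size]
    # contextual:   this block's encoder output,          [batch, seq_len, hidden_size]
    return torch.cat([embeddings, prev_aligned, contextual], dim=-1)

x = block_features(torch.rand(8, 20, 300), torch.rand(8, 20, 150), torch.rand(8, 20, 150))
print(x.shape)  # torch.Size([8, 20, 600])
```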
19 |
20 | RE2 achieves performance on par with the state of the art on four benchmark datasets: SNLI, SciTail, Quora and WikiQA,
21 | across the tasks of natural language inference, paraphrase identification and answer selection,
22 | with few or no task-specific adaptations. Its inference speed is at least 6 times faster than that of models with comparable performance.
23 |
24 |
![RE2 model architecture](figure.png)
25 |
26 | The following table lists major experiment results.
27 | The paper reports the average and standard deviation of 10 runs.
28 | Inference time (in seconds) is measured by processing a batch of 8 pairs of length 20 on Intel i7 CPUs.
29 | The computation time of POS features used by CSRAN and DIIN is not included.
30 |
31 | |Model|SNLI|SciTail|Quora|WikiQA|Inference Time|
32 | |---|---|---|---|---|---|
33 | |[BiMPM](https://github.com/zhiguowang/BiMPM)|86.9|-|88.2|0.731|0.05|
34 | |[ESIM](https://github.com/lukecq1231/nli)|88.0|70.6|-|-|-|
35 | |[DIIN](https://github.com/YichenGong/Densely-Interactive-Inference-Network)|88.0|-|89.1|-|1.79|
36 | |[CSRAN](https://github.com/vanzytay/EMNLP2018_NLI)|88.7|86.7|89.2|-|0.28|
37 | |RE2|88.9±0.1|86.0±0.6|89.2±0.2|0.7618±0.0040|0.03~0.05|
38 |
39 | Refer to the paper for more details of the components and experiment results.
40 |
41 | ## Setup
42 |
43 | - Install Python >= 3.6 and pip
44 | - `pip install -r requirements.txt`
45 | - Install [PyTorch](https://pytorch.org)
46 | - Download [GloVe word vectors](https://nlp.stanford.edu/projects/glove/) (glove.840B.300d) to `resources/`
47 |
48 | Data used in the paper are prepared as follows:
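
All of the prepare scripts below produce (or, for WikiQA, copy) plain tab-separated text files with one sample per line, which is the format the training code reads:

```
text1<TAB>text2<TAB>label
```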
49 |
50 | ### SNLI
51 |
52 | - Download and unzip [SNLI](https://www.dropbox.com/s/0r82spk628ksz70/SNLI.zip?dl=0)
53 | (pre-processed by [Tay et al.](https://github.com/vanzytay/EMNLP2018_NLI)) to `data/orig`.
54 | - Decompress all gzipped files in the `data/orig/SNLI` folder. (`cd data/orig/SNLI && gunzip *.gz`)
55 | - `cd data && python prepare_snli.py`
56 |
57 | ### SciTail
58 |
59 | - Download and unzip [SciTail](http://data.allenai.org.s3.amazonaws.com/downloads/SciTailV1.1.zip)
60 | dataset to `data/orig`.
61 | - `cd data && python prepare_scitail.py`
62 |
63 | ### Quora
64 |
65 | - Download and unzip [Quora](https://drive.google.com/file/d/0B0PlTAo--BnaQWlsZl9FZ3l1c28/view?usp=sharing)
66 | dataset (pre-processed by [Wang et al.](https://github.com/zhiguowang/BiMPM)) to `data/orig`.
67 | - `cd data && python prepare_quora.py`
68 |
69 | ### WikiQA
70 |
71 | - Download and unzip [WikiQA](https://www.microsoft.com/en-us/download/details.aspx?id=52419)
72 | to `data/orig`.
73 | - `cd data && python prepare_wikiqa.py`
74 | - Download and extract the [evaluation scripts](http://cs.stanford.edu/people/mengqiu/data/qg-emnlp07-data.tgz).
75 | Use the `make -B` command to compile the source files in `qg-emnlp07-data/eval/trec_eval-8.0`.
76 | Move the compiled binary `trec_eval` to `resources/`.
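
Putting these steps together, assuming you run from the repository root and the archive unpacks to `qg-emnlp07-data/` (paths may differ on your system):

```bash
wget http://cs.stanford.edu/people/mengqiu/data/qg-emnlp07-data.tgz
tar -xzf qg-emnlp07-data.tgz
make -B -C qg-emnlp07-data/eval/trec_eval-8.0    # compile the evaluation tool
mv qg-emnlp07-data/eval/trec_eval-8.0/trec_eval resources/
```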
77 |
78 | ## Usage
79 |
80 | To train a new text matching model, run the following command:
81 |
82 | ```bash
83 | python train.py $config_file.json5
84 | ```
85 |
86 | Example configuration files are provided in `configs/`:
87 |
88 | - `configs/main.json5`: replicates the main experiment results in the paper.
89 | - `configs/robustness.json5`: robustness checks.
90 | - `configs/ablation.json5`: the ablation study.
91 |
92 | Instructions for writing your own configuration files:
93 |
94 | ```json5
95 | [
96 | {
97 | name: 'exp1', // name of your experiment, can be the same across different data
98 | __parents__: [
99 | 'default', // always put the default on top
100 | 'data/quora', // data specific configurations in `configs/data`
101 |       // 'debug', // use "debug" for quickly debugging your code
102 | ],
103 |     __repeat__: 5, // how many repetitions you want
104 | blocks: 3, // other configurations for this experiment
105 | },
106 | // multiple configurations are executed sequentially
107 | {
108 | name: 'exp2', // results under the same name will be overwritten
109 | __parents__: [
110 | 'default',
111 | 'data/quora',
112 | ],
113 | __repeat__: 5,
114 | blocks: 4,
115 | }
116 | ]
117 | ```
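
Each entry above is expanded by merging its `__parents__` in order (files later in the list, and the entry's own keys, override earlier ones) and is run `__repeat__` times. The actual resolution logic lives in `src/utils/params.py`; the sketch below only illustrates the idea, and its names are not the repository's API:

```python
# Illustrative sketch of __parents__ / __repeat__ resolution, not the real params.py.
import json5

def deep_merge(base, override):
    """Recursively merge `override` into a copy of `base`."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

def resolve(entry):
    entry = dict(entry)
    config = {}
    for parent in entry.pop('__parents__', []):
        with open('configs/{}.json5'.format(parent)) as f:
            config = deep_merge(config, json5.load(f))
    repeat = entry.pop('__repeat__', 1)
    config = deep_merge(config, entry)  # the entry's own keys take precedence
    return [config] * repeat            # one training run per repetition
```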
118 |
119 | To check the configurations only, use
120 |
121 | ```bash
122 | python train.py $config_file.json5 --dry
123 | ```
124 |
125 | To evaluate an existing model, use `python evaluate.py $model_path $data_file`. For example:
126 |
127 | ```bash
128 | python evaluate.py models/snli/benchmark/best.pt data/snli/train.txt
129 | python evaluate.py models/snli/benchmark/best.pt data/snli/test.txt
130 | ```
131 |
132 | > Note that multi-GPU training is not yet supported in the PyTorch implementation. A single 16GB GPU is sufficient for training when blocks < 5 with hidden size 200 and batch size 512. All the results reported in the paper except the robustness checks can be reproduced with a single 16GB GPU.
133 |
134 | ## Citation
135 |
136 | Please cite the ACL paper if you use RE2 in your work:
137 |
138 | ```
139 | @inproceedings{yang2019simple,
140 | title={Simple and Effective Text Matching with Richer Alignment Features},
141 | author={Yang, Runqi and Zhang, Jianhai and Gao, Xing and Ji, Feng and Chen, Haiqing},
142 | booktitle={Association for Computational Linguistics (ACL)},
143 | year={2019}
144 | }
145 | ```
146 |
147 | ## License
148 | This project is licensed under the Apache License 2.0.
149 |
--------------------------------------------------------------------------------
/configs/ablation.json5:
--------------------------------------------------------------------------------
1 | [
2 | // original version
3 | {
4 | name: 'original',
5 | __parents__: [
6 | 'default',
7 | 'data/snli',
8 | ],
9 | __repeat__: 10,
10 | },
11 | {
12 | name: 'original',
13 | __parents__: [
14 | 'default',
15 | 'data/scitail',
16 | ],
17 | __repeat__: 10,
18 | },
19 | {
20 | name: 'original',
21 | __parents__: [
22 | 'default',
23 | 'data/quora',
24 | ],
25 | __repeat__: 10,
26 | },
27 | {
28 | name: 'original',
29 | __parents__: [
30 | 'default',
31 | 'data/wikiqa',
32 | ],
33 | __repeat__: 10,
34 | },
35 | // alignment alternative
36 | {
37 | name: 'alignment-alt',
38 | __parents__: [
39 | 'default',
40 | 'data/snli',
41 | ],
42 | alignment: 'identity',
43 | __repeat__: 10,
44 | },
45 | {
46 | name: 'alignment-alt',
47 | __parents__: [
48 | 'default',
49 | 'data/scitail',
50 | ],
51 | __repeat__: 10,
52 | alignment: 'linear'
53 | },
54 | {
55 | name: 'alignment-alt',
56 | __parents__: [
57 | 'default',
58 | 'data/quora',
59 | ],
60 | __repeat__: 10,
61 | alignment: 'identity'
62 | },
63 | {
64 | name: 'alignment-alt',
65 | __parents__: [
66 | 'default',
67 | 'data/wikiqa',
68 | ],
69 | __repeat__: 10,
70 | alignment: 'identity'
71 | },
72 | // prediction alternative
73 | {
74 | name: 'prediction-alt',
75 | __parents__: [
76 | 'default',
77 | 'data/snli',
78 | ],
79 | prediction: 'simple',
80 | __repeat__: 10,
81 | },
82 | {
83 | name: 'prediction-alt',
84 | __parents__: [
85 | 'default',
86 | 'data/scitail',
87 | ],
88 | __repeat__: 10,
89 | prediction: 'simple',
90 | },
91 | {
92 | name: 'prediction-alt',
93 | __parents__: [
94 | 'default',
95 | 'data/quora',
96 | ],
97 | __repeat__: 10,
98 | prediction: 'simple',
99 | },
100 | {
101 | name: 'prediction-alt',
102 | __parents__: [
103 | 'default',
104 | 'data/wikiqa',
105 | ],
106 | __repeat__: 10,
107 | prediction: 'full',
108 | },
109 | // residual connection
110 | {
111 | name: 'residual-conn',
112 | __parents__: [
113 | 'default',
114 | 'data/snli',
115 | ],
116 | connection: 'residual',
117 | __repeat__: 10,
118 | },
119 | {
120 | name: 'residual-conn',
121 | __parents__: [
122 | 'default',
123 | 'data/scitail',
124 | ],
125 | __repeat__: 10,
126 | connection: 'residual'
127 | },
128 | {
129 | name: 'residual-conn',
130 | __parents__: [
131 | 'default',
132 | 'data/quora',
133 | ],
134 | __repeat__: 10,
135 | connection: 'residual'
136 | },
137 | {
138 | name: 'residual-conn',
139 | __parents__: [
140 | 'default',
141 | 'data/wikiqa',
142 | ],
143 | __repeat__: 10,
144 | connection: 'residual'
145 | },
146 | // simple fusion
147 | {
148 | name: 'simple-fusion',
149 | __parents__: [
150 | 'default',
151 | 'data/snli',
152 | ],
153 | __repeat__: 10,
154 | fusion: 'simple'
155 | },
156 | {
157 | name: 'simple-fusion',
158 | __parents__: [
159 | 'default',
160 | 'data/scitail',
161 | ],
162 | __repeat__: 10,
163 | fusion: 'simple'
164 | },
165 | {
166 | name: 'simple-fusion',
167 | __parents__: [
168 | 'default',
169 | 'data/quora',
170 | ],
171 | __repeat__: 10,
172 | fusion: 'simple'
173 | },
174 | {
175 | name: 'simple-fusion',
176 | __parents__: [
177 | 'default',
178 | 'data/wikiqa',
179 | ],
180 | __repeat__: 10,
181 | fusion: 'simple'
182 | }
183 | ]
--------------------------------------------------------------------------------
/configs/data/quora.json5:
--------------------------------------------------------------------------------
1 | {
2 | data_dir: 'data/quora',
3 | output_dir: 'quora',
4 | metric: 'acc',
5 |
6 | model: {
7 | enc_layers: 2,
8 | blocks: 2,
9 | prediction: 'symmetric',
10 | hidden_size: 200,
11 | max_len: 100,
12 | },
13 |
14 | routine: {
15 | eval_per_samples: 12800,
16 | eval_warmup_samples: 3584000,
17 | eval_per_samples_warmup: 512000,
18 | min_samples: 5120000,
19 | tolerance_samples: 2560000,
20 | },
21 |
22 | optim: {
23 | lr: 0.0012,
24 | min_lr: 6e-5,
25 | lr_decay_samples: 256000,
26 | batch_size: 512,
27 | lr_warmup_samples: 0,
28 | },
29 | }
--------------------------------------------------------------------------------
/configs/data/scitail.json5:
--------------------------------------------------------------------------------
1 | {
2 | data_dir: 'data/scitail',
3 | output_dir: 'scitail',
4 | metric: 'acc',
5 |
6 | model: {
7 | alignment: 'identity',
8 | enc_layers: 3,
9 | blocks: 2,
10 | hidden_size: 200,
11 | },
12 |
13 | routine: {
14 | epochs: 80,
15 | log_per_samples: 1280,
16 | eval_per_samples: 6400,
17 | },
18 |
19 | optim: { // fixed learning rate
20 | lr: 0.001,
21 | lr_warmup_samples: 0,
22 | lr_decay_rate: 1.0,
23 | },
24 | }
--------------------------------------------------------------------------------
/configs/data/snli.json5:
--------------------------------------------------------------------------------
1 | {
2 | data_dir: 'data/snli',
3 | output_dir: 'snli',
4 | metric: 'acc',
5 | watch_metrics: [],
6 |
7 | model: {
8 | enc_layers: 2,
9 | blocks: 3,
10 | },
11 |
12 | routine: {
13 | eval_per_samples: 12800,
14 | eval_warmup_samples: 5120000,
15 | eval_per_samples_warmup: 512000,
16 | min_samples: 5120000,
17 | tolerance_samples: 2560000,
18 | },
19 |
20 | optim: {
21 | lr: 0.002,
22 | min_lr: 1e-4,
23 | lr_decay_samples: 256000,
24 | lr_decay_rate: 0.94,
25 | batch_size: 512,
26 | lr_warmup_samples: 2048000,
27 | },
28 | }
--------------------------------------------------------------------------------
/configs/data/wikiqa.json5:
--------------------------------------------------------------------------------
1 | {
2 | data_dir: 'data/wikiqa',
3 | output_dir: 'wikiqa',
4 | metric: 'mrr',
5 | watch_metrics: ['map'],
6 |
7 | model: {
8 | enc_layers: 3,
9 | blocks: 2,
10 | hidden_size: 200,
11 | prediction: 'simple',
12 | },
13 |
14 | routine: {
15 | log_per_samples: 256,
16 | eval_per_samples: 1280,
17 | tolerance_samples: 256000,
18 | eval_epoch: false,
19 | },
20 |
21 | optim: {
22 | lr: 0.001,
23 | lr_decay_rate: 1.0,
24 | batch_size: 128,
25 | },
26 | }
--------------------------------------------------------------------------------
/configs/debug.json5:
--------------------------------------------------------------------------------
1 | {
2 | batch_size: 8,
3 | blocks: 2,
4 | tensorboard: true,
5 | log_per_updates: 2,
6 | summary_per_logs: 1,
7 | eval_subset: 100,
8 | eval_per_updates: 50,
9 | eval_warmup_samples: 0,
10 | save_all: true,
11 | sort_by_len: true,
12 | seed: 123,
13 | pretrained_embeddings: 'resources/glove.6B.300d.txt',
14 | }
--------------------------------------------------------------------------------
/configs/default.json5:
--------------------------------------------------------------------------------
1 | {
2 | basic: {
3 | output_dir: 'default',
4 | seed: null,
5 | cuda: true,
6 | multi_gpu: false,
7 | deterministic: true, // GPU deterministic mode, will slow down training
8 | },
9 |
10 | data: {
11 | data_dir: null,
12 | min_df: 5,
13 |     max_vocab: 999999, // vocabulary capacity, including words not covered by the pretrained embeddings
14 | max_len: 999, // large enough number, treated as unlimited
15 | min_len: 1,
16 | lower_case: true, // whether to treat the data and embedding as lowercase.
17 | sort_by_len: false,
18 | pretrained_embeddings: 'resources/glove.840B.300d.txt',
19 | embedding_dim: 300,
20 | embedding_mode: 'freq', // (options: 'freq', 'last', 'avg', 'strict') what to do when duplicated embedding tokens (after normalization) are found.
21 | },
22 |
23 | model: {
24 | hidden_size: 150,
25 | dropout: 0.2,
26 | blocks: 2,
27 | fix_embeddings: true,
28 | encoder: {
29 | encoder: 'cnn', // cnn, lstm
30 | enc_layers: 2,
31 | kernel_sizes: [3],
32 | },
33 | alignment: 'linear', // linear, identity
34 | fusion: 'full', // full, simple
35 | connection: 'aug', // aug, residual
36 | prediction: 'full', // full, symmetric, simple
37 |
38 | },
39 |
40 | logging: {
41 | log_file: 'log.txt',
42 | log_per_samples: 5120,
43 | summary_per_logs: 20,
44 | tensorboard: true,
45 | },
46 |
47 | training: {
48 | epochs: 30,
49 | batch_size: 128,
50 | grad_clipping: 5,
51 | weight_decay: 0,
52 | lr: 1e-3,
53 | beta1: 0.9,
54 | beta2: 0.999,
55 | max_loss: 999., // tolerance for unstable training
56 | lr_decay_rate: 0.95, // exp decay rate for lr
57 | lr_decay_samples: 128000,
58 | min_lr: 6e-5,
59 |     lr_warmup_samples: 0, // duration of linear lr warmup, in samples
60 | },
61 |
62 | evaluation: {
63 | // available metrics: acc, auc, f1, map, mrr
64 | metric: 'acc', // for early stopping
65 | watch_metrics: ['auc', 'f1'], // shown in logs
66 | eval_file: 'dev',
67 | eval_per_samples: 6400,
68 | eval_per_samples_warmup: 40000,
69 |     eval_warmup_samples: 0, // eval warmup mode ends after this many samples
70 |     min_samples: 0, // train on at least this many samples, regardless of early stopping
71 | tolerance_samples: 400000, // early stopping
72 | eval_epoch: true, // eval after epoch
73 | eval_subset: null,
74 | },
75 |
76 | persistence: {
77 | resume: null,
78 | save: true,
79 | save_all: false,
80 | },
81 | }
--------------------------------------------------------------------------------
/configs/main.json5:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | name: 'benchmark',
4 | __parents__: [
5 | 'default',
6 | 'data/snli',
7 | ],
8 | __repeat__: 10,
9 | eval_file: 'test',
10 | },
11 | {
12 | name: 'benchmark',
13 | __parents__: [
14 | 'default',
15 | 'data/scitail',
16 | ],
17 | __repeat__: 10,
18 | eval_file: 'test',
19 | },
20 | {
21 | name: 'benchmark',
22 | __parents__: [
23 | 'default',
24 | 'data/quora',
25 | ],
26 | __repeat__: 10,
27 | eval_file: 'test',
28 | },
29 | {
30 | name: 'benchmark',
31 | __parents__: [
32 | 'default',
33 | 'data/wikiqa',
34 | ],
35 | __repeat__: 10,
36 | eval_file: 'test',
37 | },
38 | ]
--------------------------------------------------------------------------------
/data/prepare_quora.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | from tqdm import tqdm
19 |
20 |
21 | print('processing quora')
22 | os.makedirs('quora', exist_ok=True)
23 | # use the partition on https://zhiguowang.github.io
24 | for split in ('train', 'dev', 'test'):
25 | with open('orig/Quora_question_pair_partition/{}.tsv'.format(split)) as f, \
26 | open('quora/{}.txt'.format(split), 'w') as fout:
27 | n_lines = 0
28 | for _ in f:
29 | n_lines += 1
30 |         f.seek(0)  # rewind after counting lines so tqdm can show overall progress
31 | for line in tqdm(f, total=n_lines, leave=False):
32 | elements = line.rstrip().split('\t')
33 | fout.write('{}\t{}\t{}\n'.format(elements[1], elements[2], int(elements[0])))
34 |
--------------------------------------------------------------------------------
/data/prepare_scitail.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import re
18 | import os
19 | import json
20 | from tqdm import tqdm
21 | from nltk.tokenize import TweetTokenizer
22 |
23 |
24 | tokenizer = TweetTokenizer()
25 | label_map = {
26 | 'entailment': 0,
27 | 'neutral': 1,
28 | 'contradiction': 2,
29 | }
30 |
31 |
32 | def tokenize(string):
33 | string = ' '.join(tokenizer.tokenize(string))
34 | string = re.sub(r"[-.#\"/]", " ", string)
35 | string = re.sub(r"\'(?!(s|m|ve|t|re|d|ll)( |$))", " ", string)
36 | string = re.sub(r"\'s", " \'s", string)
37 | string = re.sub(r"\'m", " \'m", string)
38 | string = re.sub(r"\'ve", " \'ve", string)
39 | string = re.sub(r"n\'t", " n\'t", string)
40 | string = re.sub(r"\'re", " \'re", string)
41 | string = re.sub(r"\'d", " \'d", string)
42 | string = re.sub(r"\'ll", " \'ll", string)
43 | string = re.sub(r"\s{2,}", " ", string)
44 | return string.strip()
45 |
46 |
47 | os.makedirs('scitail', exist_ok=True)
48 |
49 |
50 | for split in ['train', 'dev', 'test']:
51 | print('processing SciTail', split)
52 | with open('orig/SciTailV1.1/snli_format/scitail_1.0_{}.txt'.format(split)) as f, \
53 | open('scitail/{}.txt'.format(split), 'w', encoding='utf8') as fout:
54 | n_lines = 0
55 | for _ in f:
56 | n_lines += 1
57 | f.seek(0)
58 | for line in tqdm(f, total=n_lines, desc=split, leave=False):
59 | sample = json.loads(line)
60 | sentence1 = tokenize(sample['sentence1'])
61 | sentence2 = tokenize(sample['sentence2'])
62 | label = sample["gold_label"]
63 | assert label in label_map
64 | label = label_map[label]
65 | fout.write('{}\t{}\t{}\n'.format(sentence1, sentence2, label))
66 |
--------------------------------------------------------------------------------
/data/prepare_snli.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | import json
19 | import string
20 | import numpy as np
21 | import msgpack
22 | from collections import Counter
23 |
24 | in_dir = 'orig/SNLI'
25 | out_dir = '../models/snli/'
26 | data_dir = 'snli'
27 | label_map = {2: '0', 1: '1', 0: '2'}
28 |
29 | os.makedirs(out_dir, exist_ok=True)
30 | os.makedirs(data_dir, exist_ok=True)
31 | with open(os.path.join(in_dir, 'env')) as f:
32 | env = json.load(f)
33 |
34 | print('convert embeddings ...')
35 | emb = np.load(os.path.join(in_dir, 'emb_glove_300.npy'))
36 | print(len(emb))
37 | with open(os.path.join(out_dir, 'embedding.msgpack'), 'wb') as f:
38 | msgpack.dump(emb.tolist(), f)
39 |
40 | print('convert_vocab ...')
41 | w2idx = env['word_index']
42 | print(len(w2idx))
43 | idx2w = {i: w for w, i in w2idx.items()}
44 | with open(os.path.join(out_dir, 'vocab.txt'), 'w') as f:
45 | for index in range(len(idx2w)):
46 | if index >= 2:
47 | f.write('{}\n'.format(idx2w[index]))
48 | with open(os.path.join(out_dir, 'target_map.txt'), 'w') as f:
49 | for label in (0, 1, 2):
50 | f.write('{}\n'.format(label))
51 |
52 | # save data files
53 | punctuations = set(string.punctuation)
54 | for split in ['train', 'dev', 'test']:
55 | labels = Counter()
56 | print('convert', split, '...')
57 | data = env[split]
58 | with open(os.path.join(data_dir, f'{split}.txt'), 'w') as f_out:
59 | for sample in data:
60 | a, b, label = sample
61 | a = a[1:-1]
62 | b = b[1:-1]
63 |             a = [w.lower() for w in a if w and w not in punctuations]
64 |             b = [w.lower() for w in b if w and w not in punctuations]
65 | assert all(w in w2idx for w in a) and all(w in w2idx for w in b)
66 | a = ' '.join(a)
67 | b = ' '.join(b)
68 | assert len(a) != 0 and len(b) != 0
69 | labels.update({label: 1})
70 | assert label in label_map
71 | label = label_map[label]
72 | f_out.write(f'{a}\t{b}\t{label}\n')
73 | print('labels:', labels)
74 |
--------------------------------------------------------------------------------
/data/prepare_wikiqa.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | from shutil import copyfile
19 |
20 |
21 | def copy(src, tgt):
22 | copyfile(os.path.abspath(src), os.path.abspath(tgt))
23 |
24 |
25 | os.makedirs('wikiqa', exist_ok=True)
26 |
27 |
28 | copy('orig/WikiQACorpus/WikiQA-dev-filtered.ref', 'wikiqa/dev.ref')
29 | copy('orig/WikiQACorpus/WikiQA-test-filtered.ref', 'wikiqa/test.ref')
30 | copy('orig/WikiQACorpus/emnlp-table/WikiQA.CNN.dev.rank', 'wikiqa/dev.rank')
31 | copy('orig/WikiQACorpus/emnlp-table/WikiQA.CNN.test.rank', 'wikiqa/test.rank')
32 | for split in ['train', 'dev', 'test']:
33 | print('processing WikiQA', split)
34 | copy('orig/WikiQACorpus/WikiQA-{}.txt'.format(split), 'wikiqa/{}.txt'.format(split))
35 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import sys
18 | from src.evaluator import Evaluator
19 |
20 |
21 | def main():
22 | argv = sys.argv
23 | if len(argv) == 3:
24 | model_path, data_file = argv[1:]
25 | evaluator = Evaluator(model_path, data_file)
26 | evaluator.evaluate()
27 | else:
28 | print('Usage: "python evaluate.py $model_path $data_file"')
29 |
30 |
31 | if __name__ == '__main__':
32 | main()
33 |
--------------------------------------------------------------------------------
/figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba-edu/simple-effective-text-matching-pytorch/05d572e30801b235e989c78c95dd24d5f5d35f74/figure.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | nltk
3 | numpy
4 | scikit-learn
5 | msgpack-python
6 | tensorboardX
7 | json5
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba-edu/simple-effective-text-matching-pytorch/05d572e30801b235e989c78c95dd24d5f5d35f74/src/__init__.py
--------------------------------------------------------------------------------
/src/evaluator.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | from pprint import pprint
19 | from .model import Model
20 | from .interface import Interface
21 | from .utils.loader import load_data
22 |
23 |
24 | class Evaluator:
25 | def __init__(self, model_path, data_file):
26 | self.model_path = model_path
27 | self.data_file = data_file
28 |
29 | def evaluate(self):
30 | data = load_data(*os.path.split(self.data_file))
31 | model, checkpoint = Model.load(self.model_path)
32 | args = checkpoint['args']
33 | interface = Interface(args)
34 | batches = interface.pre_process(data, training=False)
35 | _, stats = model.evaluate(batches)
36 | pprint(stats)
37 |
--------------------------------------------------------------------------------
/src/interface.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | import random
19 | import msgpack
20 | from .utils.vocab import Vocab, Indexer
21 | from .utils.loader import load_data, load_embeddings
22 |
23 |
24 | class Interface:
25 | def __init__(self, args, log=None):
26 | self.args = args
27 | # build/load vocab and target map
28 | vocab_file = os.path.join(args.output_dir, 'vocab.txt')
29 | target_map_file = os.path.join(args.output_dir, 'target_map.txt')
30 | if not os.path.exists(vocab_file):
31 | data = load_data(self.args.data_dir)
32 | self.target_map = Indexer.build((sample['target'] for sample in data), log=log)
33 | self.target_map.save(target_map_file)
34 | self.vocab = Vocab.build((word for sample in data
35 | for text in (sample['text1'], sample['text2'])
36 | for word in text.split()[:self.args.max_len]),
37 | lower=args.lower_case, min_df=self.args.min_df, log=log,
38 | pretrained_embeddings=args.pretrained_embeddings,
39 | dump_filtered=os.path.join(args.output_dir, 'filtered_words.txt'))
40 | self.vocab.save(vocab_file)
41 |
42 | else:
43 | self.target_map = Indexer.load(target_map_file)
44 | self.vocab = Vocab.load(vocab_file)
45 | args.num_classes = len(self.target_map)
46 | args.num_vocab = len(self.vocab)
47 | args.padding = Vocab.pad()
48 |
49 | def load_embeddings(self):
50 | """generate embeddings suited for the current vocab or load previously cached ones."""
51 | assert self.args.pretrained_embeddings
52 | embedding_file = os.path.join(self.args.output_dir, 'embedding.msgpack')
53 | if not os.path.exists(embedding_file):
54 | embeddings = load_embeddings(self.args.pretrained_embeddings, self.vocab,
55 | self.args.embedding_dim, mode=self.args.embedding_mode,
56 | lower=self.args.lower_case)
57 | with open(embedding_file, 'wb') as f:
58 | msgpack.dump(embeddings, f)
59 | else:
60 | with open(embedding_file, 'rb') as f:
61 | embeddings = msgpack.load(f)
62 | return embeddings
63 |
64 | def pre_process(self, data, training=True):
65 | result = [self.process_sample(sample) for sample in data]
66 | if training:
67 | result = list(filter(lambda x: len(x['text1']) < self.args.max_len and len(x['text2']) < self.args.max_len,
68 | result))
69 | if not self.args.sort_by_len:
70 | return result
71 | result = sorted(result, key=lambda x: (len(x['text1']), len(x['text2']), x['text1']))
72 | batch_size = self.args.batch_size
73 |         return [self.make_batch(result[i:i + batch_size]) for i in range(0, len(result), batch_size)]  # batch the filtered result, not the raw data
74 |
75 | def process_sample(self, sample, with_target=True):
76 | text1 = sample['text1']
77 | text2 = sample['text2']
78 | if self.args.lower_case:
79 | text1 = text1.lower()
80 | text2 = text2.lower()
81 | processed = {
82 | 'text1': [self.vocab.index(w) for w in text1.split()[:self.args.max_len]],
83 | 'text2': [self.vocab.index(w) for w in text2.split()[:self.args.max_len]],
84 | }
85 | if 'target' in sample and with_target:
86 | target = sample['target']
87 | assert target in self.target_map
88 | processed['target'] = self.target_map.index(target)
89 | return processed
90 |
91 | def shuffle_batch(self, data):
92 | data = random.sample(data, len(data))
93 | if self.args.sort_by_len:
94 | return data
95 | batch_size = self.args.batch_size
96 | batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
97 | return list(map(self.make_batch, batches))
98 |
99 | def make_batch(self, batch, with_target=True):
100 | batch = {key: [sample[key] for sample in batch] for key in batch[0].keys()}
101 | if 'target' in batch and not with_target:
102 | del batch['target']
103 | batch = {key: self.padding(value, min_len=self.args.min_len) if key.startswith('text') else value
104 | for key, value in batch.items()}
105 | return batch
106 |
107 | @staticmethod
108 | def padding(samples, min_len=1):
109 | max_len = max(max(map(len, samples)), min_len)
110 | batch = [sample + [Vocab.pad()] * (max_len - len(sample)) for sample in samples]
111 | return batch
112 |
113 | def post_process(self, output):
114 | final_prediction = []
115 | for prob in output:
116 | idx = max(range(len(prob)), key=prob.__getitem__)
117 | target = self.target_map[idx]
118 | final_prediction.append(target)
119 | return final_prediction
120 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | import math
19 | import random
20 | import torch
21 | import torch.nn.functional as f
22 | from tqdm import tqdm
23 | from .network import Network
24 | from .utils.metrics import registry as metrics
25 |
26 |
27 | class Model:
28 | prefix = 'checkpoint'
29 | best_model_name = 'best.pt'
30 |
31 | def __init__(self, args, state_dict=None):
32 | self.args = args
33 |
34 | # network
35 | self.network = Network(args)
36 | self.device = torch.cuda.current_device() if args.cuda else torch.device('cpu')
37 | self.network.to(self.device)
38 | # optimizer
39 | self.params = list(filter(lambda x: x.requires_grad, self.network.parameters()))
40 | self.opt = torch.optim.Adam(self.params, args.lr, betas=(args.beta1, args.beta2),
41 | weight_decay=args.weight_decay)
42 | # updates
43 | self.updates = state_dict['updates'] if state_dict else 0
44 |
45 | if state_dict:
46 | new_state = set(self.network.state_dict().keys())
47 | for k in list(state_dict['model'].keys()):
48 | if k not in new_state:
49 | del state_dict['model'][k]
50 | self.network.load_state_dict(state_dict['model'])
51 | self.opt.load_state_dict(state_dict['opt'])
52 |
53 | def _update_schedule(self):
54 | if self.args.lr_decay_rate < 1.:
55 | args = self.args
56 | t = self.updates
57 | base_ratio = args.min_lr / args.lr
58 | if t < args.lr_warmup_steps:
59 | ratio = base_ratio + (1. - base_ratio) / max(1., args.lr_warmup_steps) * t
60 | else:
61 | ratio = max(base_ratio, args.lr_decay_rate ** math.floor((t - args.lr_warmup_steps) /
62 | args.lr_decay_steps))
63 | self.opt.param_groups[0]['lr'] = args.lr * ratio
64 |
65 | def update(self, batch):
66 | self.network.train()
67 | self.opt.zero_grad()
68 | inputs, target = self.process_data(batch)
69 | output = self.network(inputs)
70 | summary = self.network.get_summary()
71 | loss = self.get_loss(output, target)
72 | loss.backward()
73 | grad_norm = torch.nn.utils.clip_grad_norm_(self.params, self.args.grad_clipping)
74 | assert grad_norm >= 0, 'encounter nan in gradients.'
75 | if isinstance(grad_norm, torch.Tensor):
76 | grad_norm = grad_norm.item()
77 | self.opt.step()
78 | self._update_schedule()
79 | self.updates += 1
80 | stats = {
81 | 'updates': self.updates,
82 | 'loss': loss.item(),
83 | 'lr': self.opt.param_groups[0]['lr'],
84 | 'gnorm': grad_norm,
85 | 'summary': summary,
86 | }
87 | return stats
88 |
89 | def evaluate(self, data):
90 | self.network.eval()
91 | targets = []
92 | probabilities = []
93 | predictions = []
94 | losses = []
95 | for batch in tqdm(data[:self.args.eval_subset], desc='evaluating', leave=False):
96 | inputs, target = self.process_data(batch)
97 | with torch.no_grad():
98 | output = self.network(inputs)
99 | loss = self.get_loss(output, target)
100 | pred = torch.argmax(output, dim=1)
101 | prob = torch.nn.functional.softmax(output, dim=1)
102 | losses.append(loss.item())
103 | targets.extend(target.tolist())
104 | probabilities.extend(prob.tolist())
105 | predictions.extend(pred.tolist())
106 | outputs = {
107 | 'target': targets,
108 | 'prob': probabilities,
109 | 'pred': predictions,
110 | 'args': self.args,
111 | }
112 | stats = {
113 | 'updates': self.updates,
114 |             'loss': sum(losses[:-1]) / (len(losses) - 1) if len(losses) > 1 else sum(losses),  # exclude the last, possibly smaller, batch from the mean
115 | }
116 | for metric in self.args.watch_metrics:
117 | if metric not in stats: # multiple metrics could be computed by the same function
118 | stats.update(metrics[metric](outputs))
119 | assert 'score' not in stats, 'metric name collides with "score"'
120 | eval_score = stats[self.args.metric]
121 | stats['score'] = eval_score
122 | return eval_score, stats # first value is for early stopping
123 |
124 | def predict(self, batch):
125 | self.network.eval()
126 | inputs, _ = self.process_data(batch)
127 | with torch.no_grad():
128 | output = self.network(inputs)
129 | output = torch.nn.functional.softmax(output, dim=1)
130 | return output.tolist()
131 |
132 | def process_data(self, batch):
133 | text1 = torch.LongTensor(batch['text1']).to(self.device)
134 | text2 = torch.LongTensor(batch['text2']).to(self.device)
135 | mask1 = torch.ne(text1, self.args.padding).unsqueeze(2)
136 | mask2 = torch.ne(text2, self.args.padding).unsqueeze(2)
137 | inputs = {
138 | 'text1': text1,
139 | 'text2': text2,
140 | 'mask1': mask1,
141 | 'mask2': mask2,
142 | }
143 | if 'target' in batch:
144 | target = torch.LongTensor(batch['target']).to(self.device)
145 | return inputs, target
146 | return inputs, None
147 |
148 | @staticmethod
149 | def get_loss(logits, target):
150 | return f.cross_entropy(logits, target)
151 |
152 | def save(self, states, name=None):
153 | if name:
154 | filename = os.path.join(self.args.summary_dir, name)
155 | else:
156 | filename = os.path.join(self.args.summary_dir, f'{self.prefix}_{self.updates}.pt')
157 | params = {
158 | 'state_dict': {
159 | 'model': self.network.state_dict(),
160 | 'opt': self.opt.state_dict(),
161 | 'updates': self.updates,
162 | },
163 | 'args': self.args,
164 | 'random_state': random.getstate(),
165 | 'torch_state': torch.random.get_rng_state()
166 | }
167 | params.update(states)
168 | if self.args.cuda:
169 | params['torch_cuda_state'] = torch.cuda.get_rng_state()
170 | torch.save(params, filename)
171 |
172 | @classmethod
173 | def load(cls, file):
174 | checkpoint = torch.load(file, map_location=(
175 | lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
176 | ))
177 | prev_args = checkpoint['args']
178 | # update args
179 | prev_args.output_dir = os.path.dirname(os.path.dirname(file))
180 | prev_args.summary_dir = os.path.join(prev_args.output_dir, prev_args.name)
181 | prev_args.cuda = prev_args.cuda and torch.cuda.is_available()
182 | return cls(prev_args, state_dict=checkpoint['state_dict']), checkpoint
183 |
184 | def num_parameters(self, exclude_embed=False):
185 | num_params = sum(p.numel() for p in self.network.parameters() if p.requires_grad)
186 | if exclude_embed:
187 | num_params -= 0 if self.args.fix_embeddings else next(self.network.embedding.parameters()).numel()
188 | return num_params
189 |
190 | def set_embeddings(self, embeddings):
191 | self.network.embedding.set_(embeddings)
192 |
--------------------------------------------------------------------------------
/src/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | from typing import Collection
18 | import math
19 | import torch
20 | import torch.nn as nn
21 |
22 |
23 | class Module(nn.Module):
24 | def __init__(self):
25 | super().__init__()
26 | self.summary = {}
27 |
28 | def add_summary(self, name, val):
29 | if self.training:
30 | self.summary[name] = val.clone().detach().cpu().numpy()
31 |
32 | def get_summary(self, base_name=''):
33 | summary = {}
34 | if base_name:
35 | base_name += '/'
36 | if self.summary:
37 | summary.update({base_name + name: val for name, val in self.summary.items()})
38 | for name, child in self.named_children():
39 | if hasattr(child, 'get_summary'):
40 | name = base_name + name
41 | summary.update(child.get_summary(name))
42 | return summary
43 |
44 |
45 | class ModuleList(nn.ModuleList):
46 | def get_summary(self, base_name=''):
47 | summary = {}
48 | if base_name:
49 | base_name += '/'
50 | for i, module in enumerate(self):
51 | if hasattr(module, 'get_summary'):
52 | name = base_name + str(i)
53 | summary.update(module.get_summary(name))
54 | return summary
55 |
56 |
57 | class ModuleDict(nn.ModuleDict):
58 | def get_summary(self, base_name=''):
59 | summary = {}
60 | if base_name:
61 | base_name += '/'
62 | for key, module in self.items():
63 | if hasattr(module, 'get_summary'):
64 | name = base_name + key
65 | summary.update(module.get_summary(name))
66 | return summary
67 |
68 |
69 | class GeLU(nn.Module):
70 | def forward(self, x):
71 |         return 0.5 * x * (1. + torch.tanh(x * 0.7978845608 * (1. + 0.044715 * x * x)))  # tanh approximation of GELU; 0.7978845608 ≈ sqrt(2/pi)
72 |
73 |
74 | class Linear(nn.Module):
75 | def __init__(self, in_features, out_features, activations=False):
76 | super().__init__()
77 | linear = nn.Linear(in_features, out_features)
78 | nn.init.normal_(linear.weight, std=math.sqrt((2. if activations else 1.) / in_features))
79 | nn.init.zeros_(linear.bias)
80 | modules = [nn.utils.weight_norm(linear)]
81 | if activations:
82 | modules.append(GeLU())
83 | self.model = nn.Sequential(*modules)
84 |
85 | def forward(self, x):
86 | return self.model(x)
87 |
88 |
89 | class Conv1d(Module):
90 | def __init__(self, in_channels, out_channels, kernel_sizes: Collection[int]):
91 | super().__init__()
92 | assert all(k % 2 == 1 for k in kernel_sizes), 'only support odd kernel sizes'
93 |         assert out_channels % len(kernel_sizes) == 0, 'out channels must be divisible by the number of kernel sizes'
94 | out_channels = out_channels // len(kernel_sizes)
95 | convs = []
96 | for kernel_size in kernel_sizes:
97 | conv = nn.Conv1d(in_channels, out_channels, kernel_size,
98 | padding=(kernel_size - 1) // 2)
99 | nn.init.normal_(conv.weight, std=math.sqrt(2. / (in_channels * kernel_size)))
100 | nn.init.zeros_(conv.bias)
101 | convs.append(nn.Sequential(nn.utils.weight_norm(conv), GeLU()))
102 | self.model = nn.ModuleList(convs)
103 |
104 | def forward(self, x):
105 | return torch.cat([encoder(x) for encoder in self.model], dim=-1)
106 |
--------------------------------------------------------------------------------
/src/modules/alignment.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import math
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as f
21 | from functools import partial
22 | from src.utils.registry import register
23 | from . import Linear, Module
24 |
25 | registry = {}
26 | register = partial(register, registry=registry)
27 |
28 |
29 | @register('identity')
30 | class Alignment(Module):
31 | def __init__(self, args, __):
32 | super().__init__()
33 | self.temperature = nn.Parameter(torch.tensor(1 / math.sqrt(args.hidden_size)))
34 |
35 | def _attention(self, a, b):
36 | return torch.matmul(a, b.transpose(1, 2)) * self.temperature
37 |
38 | def forward(self, a, b, mask_a, mask_b):
39 | attn = self._attention(a, b)
40 | mask = torch.matmul(mask_a.float(), mask_b.transpose(1, 2).float())
41 |         if tuple(int(v) for v in torch.__version__.split('.')[:2]) < (1, 2):  # compare numerically, not lexicographically
42 | mask = mask.byte()
43 | else:
44 | mask = mask.bool()
45 | attn.masked_fill_(~mask, -1e7)
46 | attn_a = f.softmax(attn, dim=1)
47 | attn_b = f.softmax(attn, dim=2)
48 | feature_b = torch.matmul(attn_a.transpose(1, 2), a)
49 | feature_a = torch.matmul(attn_b, b)
50 | self.add_summary('temperature', self.temperature)
51 | self.add_summary('attention_a', attn_a)
52 | self.add_summary('attention_b', attn_b)
53 | return feature_a, feature_b
54 |
55 |
56 | @register('linear')
57 | class MappedAlignment(Alignment):
58 | def __init__(self, args, input_size):
59 | super().__init__(args, input_size)
60 | self.projection = nn.Sequential(
61 | nn.Dropout(args.dropout),
62 | Linear(input_size, args.hidden_size, activations=True),
63 | )
64 |
65 | def _attention(self, a, b):
66 | a = self.projection(a)
67 | b = self.projection(b)
68 | return super()._attention(a, b)
69 |
--------------------------------------------------------------------------------
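
The attention matrix here is batch x len_a x len_b: softmax over dim=1 normalizes across positions of a (yielding the summaries attended back to b), softmax over dim=2 across positions of b. A minimal shape sketch with dummy tensors (hypothetical sizes; assumes the repo root is on PYTHONPATH):

    import torch
    from types import SimpleNamespace
    from src.modules.alignment import Alignment

    args = SimpleNamespace(hidden_size=150)          # hypothetical config value
    a = torch.randn(2, 7, 150)                       # batch x len_a x hidden
    b = torch.randn(2, 5, 150)                       # batch x len_b x hidden
    mask_a = torch.ones(2, 7, 1, dtype=torch.bool)
    mask_b = torch.ones(2, 5, 1, dtype=torch.bool)
    align = Alignment(args, None)                    # second argument is unused here
    feature_a, feature_b = align(a, b, mask_a, mask_b)
    print(feature_a.shape, feature_b.shape)          # (2, 7, 150) and (2, 5, 150)
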
/src/modules/connection.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import math
18 | import torch
19 | import torch.nn as nn
20 | from . import Linear
21 | from functools import partial
22 | from src.utils.registry import register
23 | registry = {}
24 | register = partial(register, registry=registry)
25 |
26 |
27 | @register('none')
28 | class NullConnection(nn.Module):
29 | def __init__(self, _):
30 | super().__init__()
31 |
32 | def forward(self, x, _, __):
33 | return x
34 |
35 |
36 | @register('residual')
37 | class Residual(nn.Module):
38 | def __init__(self, args):
39 | super().__init__()
40 | self.linear = Linear(args.embedding_dim, args.hidden_size)
41 |
42 | def forward(self, x, res, i):
43 | if i == 1:
44 | res = self.linear(res)
45 | return (x + res) * math.sqrt(0.5)
46 |
47 |
48 | @register('aug')
49 | class AugmentedResidual(nn.Module):
50 | def __init__(self, _):
51 | super().__init__()
52 |
53 | def forward(self, x, res, i):
54 | if i == 1:
55 | return torch.cat([x, res], dim=-1) # res is embedding
56 | hidden_size = x.size(-1)
57 | x = (res[:, :, :hidden_size] + x) * math.sqrt(0.5)
58 | return torch.cat([x, res[:, :, hidden_size:]], dim=-1) # latter half of res is embedding
59 |
--------------------------------------------------------------------------------
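
The augmented residual keeps the raw embedding concatenated to the hidden state across all blocks: block 1 produces [hidden; embedding], and later blocks average only the hidden half with the new output before re-appending the embedding half. A small shape sketch (hypothetical sizes):

    import torch
    from src.modules.connection import AugmentedResidual

    aug = AugmentedResidual(None)         # constructor ignores its argument
    emb = torch.randn(2, 7, 300)          # block-1 residual: the raw embedding
    x = torch.randn(2, 7, 150)            # block-1 fused output (hidden_size 150)
    res = aug(x, emb, 1)                  # -> 2 x 7 x 450: [hidden; embedding]
    x = torch.randn(2, 7, 150)            # block-2 fused output
    res = aug(x, res, 2)                  # averaged hidden half + embedding half
    print(res.shape)                      # torch.Size([2, 7, 450])
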
/src/modules/embedding.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as f
20 |
21 |
22 | class Embedding(nn.Module):
23 | def __init__(self, args):
24 | super().__init__()
25 | self.fix_embeddings = args.fix_embeddings
26 | self.embedding = nn.Embedding(args.num_vocab, args.embedding_dim, padding_idx=0)
27 | self.dropout = args.dropout
28 |
29 | def set_(self, value):
30 | self.embedding.weight.requires_grad = not self.fix_embeddings
31 | self.embedding.load_state_dict({'weight': torch.tensor(value)})
32 |
33 | def forward(self, x):
34 | x = self.embedding(x)
35 | x = f.dropout(x, self.dropout, self.training)
36 | return x
37 |
--------------------------------------------------------------------------------
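
set_ loads a pretrained matrix into the embedding table and, when fix_embeddings is set, freezes it. A minimal sketch (hypothetical sizes; the matrix would normally come from load_embeddings in src/utils/loader.py):

    import torch
    from types import SimpleNamespace
    from src.modules.embedding import Embedding

    args = SimpleNamespace(fix_embeddings=True, num_vocab=5, embedding_dim=4, dropout=0.2)
    emb = Embedding(args)
    emb.set_([[0.0] * 4] * 5)                   # 5 x 4 list of pretrained vectors
    print(emb.embedding.weight.requires_grad)   # False
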
/src/modules/encoder.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import torch.nn as nn
18 | import torch.nn.functional as f
19 | from . import Conv1d
20 |
21 |
22 | class Encoder(nn.Module):
23 | def __init__(self, args, input_size):
24 | super().__init__()
25 | self.dropout = args.dropout
26 | self.encoders = nn.ModuleList([Conv1d(
27 | in_channels=input_size if i == 0 else args.hidden_size,
28 | out_channels=args.hidden_size,
29 | kernel_sizes=args.kernel_sizes) for i in range(args.enc_layers)])
30 |
31 | def forward(self, x, mask):
32 | x = x.transpose(1, 2) # B x C x L
33 | mask = mask.transpose(1, 2)
34 | for i, encoder in enumerate(self.encoders):
35 | x.masked_fill_(~mask, 0.)
36 | if i > 0:
37 | x = f.dropout(x, self.dropout, self.training)
38 | x = encoder(x)
39 | x = f.dropout(x, self.dropout, self.training)
40 | return x.transpose(1, 2) # B x L x C
41 |
--------------------------------------------------------------------------------
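
The encoder takes batch x length x channels plus a broadcastable boolean mask, re-applies the mask before every convolution so padding stays zero, and returns batch x length x hidden_size. A minimal run (hypothetical config values):

    import torch
    from types import SimpleNamespace
    from src.modules.encoder import Encoder

    args = SimpleNamespace(dropout=0.2, hidden_size=150, kernel_sizes=[3], enc_layers=2)
    enc = Encoder(args, input_size=300)
    x = torch.randn(2, 7, 300)
    mask = torch.ones(2, 7, 1, dtype=torch.bool)
    print(enc(x, mask).shape)   # torch.Size([2, 7, 150])
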
/src/modules/fusion.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as f
20 | from functools import partial
21 | from src.utils.registry import register
22 | from . import Linear
23 |
24 | registry = {}
25 | register = partial(register, registry=registry)
26 |
27 |
28 | @register('simple')
29 | class Fusion(nn.Module):
30 | def __init__(self, args, input_size):
31 | super().__init__()
32 | self.fusion = Linear(input_size * 2, args.hidden_size, activations=True)
33 |
34 | def forward(self, x, align):
35 | return self.fusion(torch.cat([x, align], dim=-1))
36 |
37 |
38 | @register('full')
39 | class FullFusion(nn.Module):
40 | def __init__(self, args, input_size):
41 | super().__init__()
42 | self.dropout = args.dropout
43 | self.fusion1 = Linear(input_size * 2, args.hidden_size, activations=True)
44 | self.fusion2 = Linear(input_size * 2, args.hidden_size, activations=True)
45 | self.fusion3 = Linear(input_size * 2, args.hidden_size, activations=True)
46 | self.fusion = Linear(args.hidden_size * 3, args.hidden_size, activations=True)
47 |
48 | def forward(self, x, align):
49 | x1 = self.fusion1(torch.cat([x, align], dim=-1))
50 | x2 = self.fusion2(torch.cat([x, x - align], dim=-1))
51 | x3 = self.fusion3(torch.cat([x, x * align], dim=-1))
52 | x = torch.cat([x1, x2, x3], dim=-1)
53 | x = f.dropout(x, self.dropout, self.training)
54 | return self.fusion(x)
55 |
--------------------------------------------------------------------------------
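
FullFusion compares the original and aligned sequences from three views (raw concatenation, difference, elementwise product), each with its own projection, before a final projection back to hidden_size. A shape sketch (hypothetical sizes):

    import torch
    from types import SimpleNamespace
    from src.modules.fusion import registry as fusion

    args = SimpleNamespace(dropout=0.2, hidden_size=150)
    fuse = fusion['full'](args, input_size=450)   # e.g. embedding + hidden in block 0
    x = torch.randn(2, 7, 450)                    # block input concatenated with encoding
    align = torch.randn(2, 7, 450)                # attended summary of the other sequence
    print(fuse(x, align).shape)                   # torch.Size([2, 7, 150])
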
/src/modules/pooling.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import torch.nn as nn
18 |
19 |
20 | class Pooling(nn.Module):
21 | def forward(self, x, mask):
22 | return x.masked_fill_(~mask, -float('inf')).max(dim=1)[0]
23 |
--------------------------------------------------------------------------------
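
This is max-over-time pooling: padded positions are filled with -inf so they never win the max. Note that masked_fill_ mutates its input in place, which is safe in the Network below because pooling is the last consumer of the tensor. A tiny demo:

    import torch
    from src.modules.pooling import Pooling

    x = torch.tensor([[[1., 5.], [9., 2.], [0., 0.]]])   # batch 1, length 3, 2 channels
    mask = torch.tensor([[[True], [True], [False]]])     # last position is padding
    print(Pooling()(x, mask))                            # tensor([[9., 5.]])
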
/src/modules/prediction.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import torch
18 | import torch.nn as nn
19 | from functools import partial
20 | from src.utils.registry import register
21 | from . import Linear
22 |
23 | registry = {}
24 | register = partial(register, registry=registry)
25 |
26 |
27 | @register('simple')
28 | class Prediction(nn.Module):
29 | def __init__(self, args, inp_features=2):
30 | super().__init__()
31 | self.dense = nn.Sequential(
32 | nn.Dropout(args.dropout),
33 | Linear(args.hidden_size * inp_features, args.hidden_size, activations=True),
34 | nn.Dropout(args.dropout),
35 | Linear(args.hidden_size, args.num_classes),
36 | )
37 |
38 | def forward(self, a, b):
39 | return self.dense(torch.cat([a, b], dim=-1))
40 |
41 |
42 | @register('full')
43 | class AdvancedPrediction(Prediction):
44 | def __init__(self, args):
45 | super().__init__(args, inp_features=4)
46 |
47 | def forward(self, a, b):
48 | return self.dense(torch.cat([a, b, a - b, a * b], dim=-1))
49 |
50 |
51 | @register('symmetric')
52 | class SymmetricPrediction(AdvancedPrediction):
53 | def forward(self, a, b):
54 | return self.dense(torch.cat([a, b, (a - b).abs(), a * b], dim=-1))
55 |
--------------------------------------------------------------------------------
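
The 'full' variant augments the two pooled vectors with their difference and elementwise product (hence inp_features=4); 'symmetric' swaps the signed difference for its absolute value so the interaction features no longer encode which sequence came first. A shape sketch (hypothetical sizes):

    import torch
    from types import SimpleNamespace
    from src.modules.prediction import registry as prediction

    args = SimpleNamespace(dropout=0.2, hidden_size=150, num_classes=3)
    pred = prediction['full'](args)
    a, b = torch.randn(2, 150), torch.randn(2, 150)   # pooled sequence vectors
    print(pred(a, b).shape)                           # torch.Size([2, 3]), class logits
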
/src/network.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import torch
18 | from .modules import Module, ModuleList, ModuleDict
19 | from .modules.embedding import Embedding
20 | from .modules.encoder import Encoder
21 | from .modules.alignment import registry as alignment
22 | from .modules.fusion import registry as fusion
23 | from .modules.connection import registry as connection
24 | from .modules.pooling import Pooling
25 | from .modules.prediction import registry as prediction
26 |
27 |
28 | class Network(Module):
29 | def __init__(self, args):
30 | super().__init__()
31 | self.dropout = args.dropout
32 | self.embedding = Embedding(args)
33 | input_emb_size = args.embedding_dim if args.connection == 'aug' else 0
34 | self.blocks = ModuleList([ModuleDict({
35 | 'encoder': Encoder(args, args.embedding_dim if i == 0 else input_emb_size + args.hidden_size),
36 | 'alignment': alignment[args.alignment](
37 | args, args.embedding_dim + args.hidden_size if i == 0 else input_emb_size + args.hidden_size * 2),
38 | 'fusion': fusion[args.fusion](
39 | args, args.embedding_dim + args.hidden_size if i == 0 else input_emb_size + args.hidden_size * 2),
40 | }) for i in range(args.blocks)])
41 |
42 | self.connection = connection[args.connection](args)
43 | self.pooling = Pooling()
44 | self.prediction = prediction[args.prediction](args)
45 |
46 | def forward(self, inputs):
47 | a = inputs['text1']
48 | b = inputs['text2']
49 | mask_a = inputs['mask1']
50 | mask_b = inputs['mask2']
51 |
52 | a = self.embedding(a)
53 | b = self.embedding(b)
54 | res_a, res_b = a, b
55 |
56 | for i, block in enumerate(self.blocks):
57 | if i > 0:
58 | a = self.connection(a, res_a, i)
59 | b = self.connection(b, res_b, i)
60 | res_a, res_b = a, b
61 | a_enc = block['encoder'](a, mask_a)
62 | b_enc = block['encoder'](b, mask_b)
63 | a = torch.cat([a, a_enc], dim=-1)
64 | b = torch.cat([b, b_enc], dim=-1)
65 | align_a, align_b = block['alignment'](a, b, mask_a, mask_b)
66 | a = block['fusion'](a, align_a)
67 | b = block['fusion'](b, align_b)
68 | a = self.pooling(a, mask_a)
69 | b = self.pooling(b, mask_b)
70 | return self.prediction(a, b)
71 |
--------------------------------------------------------------------------------
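
The size arithmetic in the block list is the subtle part: with the 'aug' connection, blocks after the first receive the augmented residual (hidden + embedding), and alignment/fusion always see that input concatenated with the fresh encoding. A worked example with hypothetical embedding_dim=300 and hidden_size=150:

    embedding_dim, hidden_size = 300, 150   # hypothetical config values
    input_emb_size = embedding_dim          # because connection == 'aug'
    for i in range(3):
        enc_in = embedding_dim if i == 0 else input_emb_size + hidden_size
        align_in = embedding_dim + hidden_size if i == 0 else input_emb_size + hidden_size * 2
        print(f'block {i}: encoder in={enc_in}, alignment/fusion in={align_in}')
    # block 0: encoder in=300, alignment/fusion in=450
    # block 1: encoder in=450, alignment/fusion in=600
    # block 2: encoder in=450, alignment/fusion in=600
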
/src/trainer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | import random
19 | import json5
20 | import torch
21 | from datetime import datetime
22 | from pprint import pformat
23 | from .utils.loader import load_data
24 | from .utils.logger import Logger
25 | from .utils.params import validate_params
26 | from .model import Model
27 | from .interface import Interface
28 |
29 |
30 | class Trainer:
31 | def __init__(self, args):
32 | self.args = args
33 | self.log = Logger(self.args)
34 |
35 | def train(self):
36 | start_time = datetime.now()
37 | model, interface, states = self.build_model()
38 | train = load_data(self.args.data_dir, 'train')
39 | dev = load_data(self.args.data_dir, self.args.eval_file)
40 | self.log(f'train ({len(train)}) | {self.args.eval_file} ({len(dev)})')
41 | train_batches = interface.pre_process(train)
42 | dev_batches = interface.pre_process(dev, training=False)
43 | self.log('setup complete: {}s.'.format(str(datetime.now() - start_time).split(".")[0]))
44 |
45 | try:
46 | for epoch in range(states['start_epoch'], self.args.epochs + 1):
47 | states['epoch'] = epoch
48 | self.log.set_epoch(epoch)
49 |
50 | batches = interface.shuffle_batch(train_batches)
51 | for batch_id, batch in enumerate(batches):
52 | stats = model.update(batch)
53 | self.log.update(stats)
54 | eval_per_updates = self.args.eval_per_updates \
55 | if model.updates > self.args.eval_warmup_steps else self.args.eval_per_updates_warmup
56 | if model.updates % eval_per_updates == 0 or (self.args.eval_epoch and batch_id + 1 == len(batches)):
57 | self.log.newline()
58 | score, dev_stats = model.evaluate(dev_batches)
59 | if score > states['best_eval']:
60 | states['best_eval'], states['best_epoch'], states['best_step'] = score, epoch, model.updates
61 | if self.args.save:
62 | model.save(states, name=model.best_model_name)
63 | self.log.log_eval(dev_stats)
64 | if self.args.save_all:
65 | model.save(states)
66 | model.save(states, name='last')
67 | if model.updates - states['best_step'] > self.args.early_stopping \
68 | and model.updates > self.args.min_steps:
69 | self.log('[Tolerance reached. Training is stopped early.]')
70 | raise EarlyStop('[Tolerance reached. Training is stopped early.]')
71 | if stats['loss'] > self.args.max_loss:
72 | raise EarlyStop('[Loss exceeds tolerance. Unstable training is stopped early.]')
73 | if stats['lr'] < self.args.min_lr - 1e-6:
74 | raise EarlyStop('[Learning rate has decayed below min_lr. Training is stopped early.]')
75 | self.log.newline()
76 | self.log('Training complete.')
77 | except KeyboardInterrupt:
78 | self.log.newline()
79 |             self.log('Training interrupted. Stopped early.')
80 | except EarlyStop as e:
81 | self.log.newline()
82 | self.log(str(e))
83 | self.log(f'best dev score {states["best_eval"]} at step {states["best_step"]} '
84 | f'(epoch {states["best_epoch"]}).')
85 | self.log(f'best eval stats [{self.log.best_eval_str}]')
86 | training_time = str(datetime.now() - start_time).split('.')[0]
87 | self.log(f'Training time: {training_time}.')
88 | states['start_time'] = str(start_time).split('.')[0]
89 | states['training_time'] = training_time
90 | return states
91 |
92 | def build_model(self):
93 | states = {}
94 | interface = Interface(self.args, self.log)
95 | self.log(f'#classes: {self.args.num_classes}; #vocab: {self.args.num_vocab}')
96 | if self.args.seed:
97 | random.seed(self.args.seed)
98 | torch.manual_seed(self.args.seed)
99 | if self.args.cuda:
100 | torch.cuda.manual_seed(self.args.seed)
101 | if self.args.deterministic:
102 | torch.backends.cudnn.deterministic = True
103 |
104 | model = Model(self.args)
105 | if self.args.pretrained_embeddings:
106 | embeddings = interface.load_embeddings()
107 | model.set_embeddings(embeddings)
108 |
109 | # set initial states
110 | states['start_epoch'] = 1
111 | states['best_eval'] = 0.
112 | states['best_epoch'] = 0
113 | states['best_step'] = 0
114 |
115 | self.log(f'trainable params: {model.num_parameters():,d}')
116 | self.log(f'trainable params (exclude embeddings): {model.num_parameters(exclude_embed=True):,d}')
117 | validate_params(self.args)
118 | with open(os.path.join(self.args.summary_dir, 'args.json5'), 'w') as f:
119 | json5.dump(self.args.__dict__, f, indent=2)
120 | self.log(pformat(vars(self.args), indent=2, width=120))
121 | return model, interface, states
122 |
123 |
124 | class EarlyStop(Exception):
125 | pass
126 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/utils/loader.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | import numpy as np
19 |
20 |
21 | def load_data(data_dir, split=None):
22 | data = []
23 | if split is None:
24 | files = [os.path.join(data_dir, file) for file in os.listdir(data_dir) if file.endswith('.txt')]
25 | else:
26 | if not split.endswith('.txt'):
27 | split += '.txt'
28 |         files = [os.path.join(data_dir, split)]
29 | for file in files:
30 | with open(file) as f:
31 | for line in f:
32 | text1, text2, label = line.rstrip().split('\t')
33 | data.append({
34 | 'text1': text1,
35 | 'text2': text2,
36 | 'target': label,
37 | })
38 | return data
39 |
40 |
41 | def load_embeddings(file, vocab, dim, lower, mode='freq'):
42 | embedding = np.zeros((len(vocab), dim))
43 | count = np.zeros((len(vocab), 1))
44 | with open(file) as f:
45 | for line in f:
46 | elems = line.rstrip().split()
47 | if len(elems) != dim + 1:
48 | continue
49 | token = elems[0]
50 | if lower and mode != 'strict':
51 | token = token.lower()
52 | if token in vocab:
53 | index = vocab.index(token)
54 | vector = [float(x) for x in elems[1:]]
55 | if mode == 'freq' or mode == 'strict':
56 | if not count[index]:
57 | embedding[index] = vector
58 | count[index] = 1.
59 | elif mode == 'last':
60 | embedding[index] = vector
61 | count[index] = 1.
62 | elif mode == 'avg':
63 | embedding[index] += vector
64 | count[index] += 1.
65 | else:
66 | raise NotImplementedError('Unknown embedding loading mode: ' + mode)
67 | if mode == 'avg':
68 | inverse_mask = np.where(count == 0, 1., 0.)
69 | embedding /= count + inverse_mask
70 | return embedding.tolist()
71 |
--------------------------------------------------------------------------------
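
load_embeddings resolves duplicate rows for the same token according to mode: 'freq' and 'strict' keep the first vector encountered, 'last' keeps the final one, and 'avg' averages all occurrences (with a guard against dividing by zero counts). A toy demo (hypothetical file contents; the vocab is stubbed as a plain list, which supports the same `in` and index operations):

    import os, tempfile
    from src.utils.loader import load_embeddings

    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
        f.write('cat 1.0 2.0\ncat 3.0 4.0\n')   # duplicate entries for one token
        path = f.name
    vocab = ['<pad>', '<unk>', 'cat']
    print(load_embeddings(path, vocab, dim=2, lower=True, mode='freq'))  # cat -> [1.0, 2.0]
    print(load_embeddings(path, vocab, dim=2, lower=True, mode='avg'))   # cat -> [2.0, 3.0]
    os.remove(path)
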
/src/utils/logger.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | import sys
19 | import logging
20 |
21 |
22 | class Logger:
23 | def __init__(self, args):
24 | log = logging.getLogger(args.summary_dir)
25 | if not log.handlers:
26 | log.setLevel(logging.DEBUG)
27 | fh = logging.FileHandler(os.path.join(args.summary_dir, args.log_file))
28 | fh.setLevel(logging.INFO)
29 | ch = ProgressHandler()
30 | ch.setLevel(logging.DEBUG)
31 | formatter = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S')
32 | fh.setFormatter(formatter)
33 | ch.setFormatter(formatter)
34 | log.addHandler(fh)
35 | log.addHandler(ch)
36 | self.log = log
37 | # setup TensorBoard
38 | if args.tensorboard:
39 | from tensorboardX import SummaryWriter
40 | self.writer = SummaryWriter(os.path.join(args.summary_dir, 'viz'))
41 |             self.log.info('TensorBoard activated.')
42 | else:
43 | self.writer = None
44 | self.log_per_updates = args.log_per_updates
45 | self.summary_per_updates = args.summary_per_updates
46 | self.grad_clipping = args.grad_clipping
47 | self.clips = 0
48 | self.train_meters = {}
49 | self.epoch = None
50 | self.best_eval = 0.
51 | self.best_eval_str = ''
52 |
53 | def set_epoch(self, epoch):
54 | self(f'Epoch: {epoch}')
55 | self.epoch = epoch
56 |
57 | @staticmethod
58 | def _format_number(x):
59 | return f'{x:.4f}' if float(x) > 1e-3 else f'{x:.4e}'
60 |
61 | def update(self, stats):
62 | updates = stats.pop('updates')
63 | summary = stats.pop('summary')
64 | if updates % self.log_per_updates == 0:
65 | self.clips += int(stats['gnorm'] > self.grad_clipping)
66 | stats_str = ' '.join(f'{key}: ' + self._format_number(val) for key, val in stats.items())
67 | for key, val in stats.items():
68 | if key not in self.train_meters:
69 | self.train_meters[key] = AverageMeter()
70 | self.train_meters[key].update(val)
71 | msg = f'epoch {self.epoch} updates {updates} {stats_str}'
72 | if self.log_per_updates != 1:
73 | msg = '> ' + msg
74 | self.log.info(msg)
75 | if self.writer and updates % self.summary_per_updates == 0:
76 | for key, val in stats.items():
77 | self.writer.add_scalar(f'train/{key}', val, updates)
78 | for key, val in summary.items():
79 | self.writer.add_histogram(key, val, updates)
80 |
81 | def newline(self):
82 | self.log.debug('')
83 |
84 | def log_eval(self, valid_stats):
85 | self.newline()
86 | updates = valid_stats.pop('updates')
87 | eval_score = valid_stats.pop('score')
88 | # report the exponential averaged training stats, while reporting the full dev set stats
89 | if self.train_meters:
90 | train_stats_str = ' '.join(f'{key}: ' + self._format_number(val) for key, val in self.train_meters.items())
91 | train_stats_str += ' ' + f'clip: {self.clips}'
92 | self.log.info(f'train {train_stats_str}')
93 | valid_stats_str = ' '.join(f'{key}: ' + self._format_number(val) for key, val in valid_stats.items())
94 | if eval_score > self.best_eval:
95 | self.best_eval_str = valid_stats_str
96 | self.best_eval = eval_score
97 | valid_stats_str += ' [NEW BEST]'
98 | else:
99 | valid_stats_str += f' [BEST: {self._format_number(self.best_eval)}]'
100 | self.log.info(f'valid {valid_stats_str}')
101 | if self.writer:
102 | for key in valid_stats.keys():
103 | group = {'valid': valid_stats[key]}
104 | if self.train_meters and key in self.train_meters:
105 | group['train'] = float(self.train_meters[key])
106 | self.writer.add_scalars(f'valid/{key}', group, updates)
107 | self.train_meters = {}
108 | self.clips = 0
109 |
110 | def __call__(self, msg):
111 | self.log.info(msg)
112 |
113 |
114 | class ProgressHandler(logging.Handler):
115 | def __init__(self, level=logging.NOTSET):
116 | super().__init__(level)
117 |
118 | def emit(self, record):
119 | log_entry = self.format(record)
120 | if record.message.startswith('> '):
121 | sys.stdout.write('{}\r'.format(log_entry.rstrip()))
122 | sys.stdout.flush()
123 | else:
124 | sys.stdout.write('{}\n'.format(log_entry))
125 |
126 |
127 | class AverageMeter(object):
128 | """Keep exponential weighted averages."""
129 | def __init__(self, beta=0.99):
130 | self.beta = beta
131 | self.moment = 0.
132 | self.value = 0.
133 | self.t = 0.
134 |
135 | def update(self, val):
136 | self.t += 1
137 | self.moment = self.beta * self.moment + (1 - self.beta) * val
138 | # bias correction
139 | self.value = self.moment / (1 - self.beta ** self.t)
140 |
141 | def __format__(self, spec):
142 | return format(self.value, spec)
143 |
144 | def __float__(self):
145 | return self.value
146 |
--------------------------------------------------------------------------------
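
AverageMeter applies the same bias correction Adam uses, value_t = moment_t / (1 - beta^t), so early readings are not biased toward the zero initialization. A quick check (a constant stream should read exactly as the constant from the first update):

    from src.utils.logger import AverageMeter

    m = AverageMeter(beta=0.9)
    for v in (1.0, 1.0, 1.0):
        m.update(v)
    print(float(m))   # 1.0
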
/src/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | import subprocess
19 | from functools import partial
20 | import numpy as np
21 | from sklearn import metrics
22 |
23 | from .registry import register
24 |
25 | registry = {}
26 | register = partial(register, registry=registry)
27 |
28 |
29 | @register('acc')
30 | def acc(outputs):
31 | target = outputs['target']
32 | pred = outputs['pred']
33 | return {
34 | 'acc': metrics.accuracy_score(target, pred).item(),
35 | }
36 |
37 |
38 | @register('f1')
39 | def f1(outputs):
40 | target = outputs['target']
41 | pred = outputs['pred']
42 | return {
43 | 'f1': metrics.f1_score(target, pred).item(),
44 | }
45 |
46 |
47 | @register('auc')
48 | def auc(outputs):
49 | target = outputs['target']
50 | prob = np.array(outputs['prob'])
51 | return {
52 | 'auc': metrics.roc_auc_score(target, prob[:, 1]).item(),
53 | }
54 |
55 |
56 | @register('map')
57 | @register('mrr')
58 | def ranking(outputs):
59 | args = outputs['args']
60 | prediction = [o[1] for o in outputs['prob']]
61 | ref_file = os.path.join(args.data_dir, '{}.ref'.format(args.eval_file))
62 | rank_file = os.path.join(args.data_dir, '{}.rank'.format(args.eval_file))
63 | tmp_file = os.path.join(args.summary_dir, 'tmp-pred.txt')
64 | with open(rank_file) as f:
65 | prefix = []
66 | for line in f:
67 | prefix.append(line.strip().split())
68 | assert len(prefix) == len(prediction), \
69 | 'prefix {}, while prediction {}'.format(len(prefix), len(prediction))
70 | with open(tmp_file, 'w') as f:
71 |         for elems, pred in zip(prefix, prediction):
72 |             elems[-2] = str(pred)
73 |             f.write(' '.join(elems) + '\n')
74 | sp = subprocess.Popen('./resources/trec_eval {} {} | egrep "map|recip_rank"'.format(ref_file, tmp_file),
75 | shell=True,
76 | stdout=subprocess.PIPE, stderr=subprocess.PIPE)
77 | stdout, stderr = sp.communicate()
78 | stdout, stderr = stdout.decode(), stderr.decode()
79 | os.remove(tmp_file)
80 |     map_, mrr = [float(s.split()[-1]) for s in stdout.strip().split('\n')]
81 | return {
82 | 'map': map_,
83 | 'mrr': mrr,
84 | }
85 |
--------------------------------------------------------------------------------
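
Each metric takes the collected outputs dict and returns a dict of named scores, so the evaluator can merge every watched metric into one stats mapping. A sketch with toy values (the 'map'/'mrr' entry is omitted here since it shells out to trec_eval):

    from src.utils.metrics import registry

    outputs = {
        'target': [0, 1, 1, 0],
        'pred': [0, 1, 0, 0],
        'prob': [[0.8, 0.2], [0.3, 0.7], [0.6, 0.4], [0.9, 0.1]],
    }
    stats = {}
    for name in ('acc', 'f1', 'auc'):
        stats.update(registry[name](outputs))
    print(stats)   # {'acc': 0.75, 'f1': 0.666..., 'auc': 1.0}
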
/src/utils/params.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | import math
19 | import shutil
20 | from datetime import datetime
21 | import torch
22 | import json5
23 |
24 |
25 | class Object:
26 | """
27 | @DynamicAttrs
28 | """
29 | pass
30 |
31 |
32 | def parse(config_file):
33 | root = os.path.dirname(config_file) # __parent__ in config is a relative path
34 | config_group = _load_param('', config_file)
35 | if type(config_group) is dict:
36 | config_group = [config_group]
37 | configs = []
38 | for config in config_group:
39 | try:
40 | choice = config.pop('__iter__')
41 | assert len(choice) == 1, 'only support iterating over 1 variable'
42 | key, values = next(iter(choice.items()))
43 | except KeyError:
44 | key, value = config.popitem()
45 | values = [value]
46 | for value in values:
47 | config[key] = value
48 | repeat = config.get('__repeat__', 1)
49 | for index in range(repeat):
50 | config_ = config.copy()
51 | config_['__index__'] = index
52 | if repeat > 1:
53 | config_['name'] += '-' + str(index)
54 | args = _parse_args(root, config_)
55 | configs.append((args, config_))
56 | return configs
57 |
58 |
59 | def _parse_args(root, config):
60 | args = Object()
61 | assert type(config) is dict
62 | parents = config.get('__parents__', [])
63 | for parent in parents:
64 | parent = _load_param(root, parent)
65 | assert type(parent) is dict, 'only top-level configs can be a sequence'
66 | _add_param(args, parent)
67 | _add_param(args, config)
68 | _post_process(args)
69 | return args
70 |
71 |
72 | def _add_param(args, x: dict):
73 | for k, v in x.items():
74 | if type(v) is dict:
75 | _add_param(args, v)
76 | else:
77 | k = _validate_param(k)
78 | if hasattr(args, k):
79 | previous_type = type(getattr(args, k))
80 | current_type = type(v)
81 | assert previous_type is current_type or \
82 | isinstance(None, previous_type) or \
83 | isinstance(None, current_type) or \
84 | (previous_type is float and current_type is int), \
85 | f'param "{k}" of type {previous_type} can not be overwritten by type {current_type}'
86 | setattr(args, k, v)
87 |
88 |
89 | def _load_param(root, file: str):
90 | file = os.path.join(root, file)
91 | if not file.endswith('.json5'):
92 | file += '.json5'
93 | with open(file) as f:
94 | config = json5.load(f)
95 | return config
96 |
97 |
98 | def _post_process(args: Object):
99 | if not args.output_dir.startswith('models'):
100 | args.output_dir = os.path.join('models', args.output_dir)
101 | os.makedirs(args.output_dir, exist_ok=True)
102 | if not args.name:
103 | args.name = str(datetime.now())
104 | args.summary_dir = os.path.join(args.output_dir, args.name)
105 | if os.path.exists(args.summary_dir):
106 | shutil.rmtree(args.summary_dir)
107 | os.makedirs(args.summary_dir)
108 | data_config_file = os.path.join(args.output_dir, 'data_config.json5')
109 | if os.path.exists(data_config_file):
110 | with open(data_config_file) as f:
111 | config = json5.load(f)
112 | for k, v in config.items():
113 | if not hasattr(args, k) or getattr(args, k) != v:
114 | print('ERROR: Data configurations are different. Please use another output_dir or '
115 | 'remove the older one manually.')
116 | exit()
117 | else:
118 | with open(data_config_file, 'w') as f:
119 | keys = ['data_dir', 'min_df', 'max_vocab', 'max_len', 'min_len', 'lower_case',
120 | 'pretrained_embeddings', 'embedding_mode']
121 | json5.dump({k: getattr(args, k) for k in keys}, f)
122 | args.metric = args.metric.lower()
123 | args.watch_metrics = [m.lower() for m in args.watch_metrics]
124 | if args.metric not in args.watch_metrics:
125 | args.watch_metrics.append(args.metric)
126 | args.cuda = args.cuda and torch.cuda.is_available()
127 | args.fix_embeddings = args.pretrained_embeddings and args.fix_embeddings
128 |
129 | def samples2steps(n):
130 | return int(math.ceil(n / args.batch_size))
131 |
132 | if not hasattr(args, 'log_per_updates'):
133 | args.log_per_updates = samples2steps(args.log_per_samples)
134 | if not hasattr(args, 'eval_per_updates'):
135 | args.eval_per_updates = samples2steps(args.eval_per_samples)
136 | if not hasattr(args, 'eval_per_updates_warmup'):
137 | args.eval_per_updates_warmup = samples2steps(args.eval_per_samples_warmup)
138 | if not hasattr(args, 'eval_warmup_steps'):
139 | args.eval_warmup_steps = samples2steps(args.eval_warmup_samples)
140 | if not hasattr(args, 'min_steps'):
141 | args.min_steps = samples2steps(args.min_samples)
142 | if not hasattr(args, 'early_stopping'):
143 | args.early_stopping = samples2steps(args.tolerance_samples)
144 | if not hasattr(args, 'lr_warmup_steps'):
145 | args.lr_warmup_steps = samples2steps(args.lr_warmup_samples)
146 | if not hasattr(args, 'lr_decay_steps'):
147 | args.lr_decay_steps = samples2steps(args.lr_decay_samples)
148 | if not hasattr(args, 'summary_per_updates'):
149 | args.summary_per_updates = args.summary_per_logs * args.log_per_updates
150 |     assert args.lr >= args.min_lr, 'initial learning rate must be no smaller than min_lr'
151 |
152 |
153 | def validate_params(args):
154 | """validate params after interface initialization"""
155 | assert args.num_classes == 2 or ('f1' not in args.watch_metrics and 'auc' not in args.watch_metrics), \
156 | f'F1 and AUC are only valid for 2 classes.'
157 | assert args.num_classes == 2 or 'ranking' not in args.watch_metrics, \
158 | f'ranking metrics are only valid for 2 classes.'
159 | assert args.num_vocab > 0
160 |
161 |
162 | def _validate_param(name):
163 | name = name.replace('-', '_')
164 | if not str.isidentifier(name):
165 | raise ValueError(f'Invalid param name: {name}')
166 | return name
167 |
--------------------------------------------------------------------------------
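
parse composes configurations: __parents__ pulls in shared defaults (paths relative to the config file), __iter__ sweeps exactly one variable, and __repeat__ replicates each setting with an index suffix on its name. A hypothetical configs/sweep.json5 illustrating the shape (field names as consumed by the code above):

    // hypothetical sweep: 3 learning rates x 2 repeats = 6 runs
    [
      {
        __parents__: ['default', 'data/snli'],
        name: 'lr-sweep',
        __iter__: {
          lr: [1e-3, 2e-3, 3e-3],
        },
        __repeat__: 2,   // runs are named lr-sweep-0, lr-sweep-1, ...
      },
    ]
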
/src/utils/registry.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | def register(name=None, registry=None):
18 | def decorator(fn, registration_name=None):
19 | module_name = registration_name or _default_name(fn)
20 | if module_name in registry:
21 | raise LookupError(f"module {module_name} already registered.")
22 | registry[module_name] = fn
23 | return fn
24 | return lambda fn: decorator(fn, name)
25 |
26 |
27 | def _default_name(obj_class):
28 | return obj_class.__name__
29 |
--------------------------------------------------------------------------------
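
register is curried per module: each consumer builds its own registry dict and partials it in, so a decorator call like @register('full') simply maps a config string to a class. Minimal usage sketch:

    from functools import partial
    from src.utils.registry import register

    registry = {}
    register = partial(register, registry=registry)

    @register('example')
    class Example:
        pass

    print(registry)   # {'example': <class '__main__.Example'>}
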
/src/utils/vocab.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | from collections import Counter
18 |
19 |
20 | class Indexer:
21 | def __init__(self):
22 | self.w2id = {}
23 | self.id2w = {}
24 |
25 | @property
26 | def n_spec(self):
27 | return 0
28 |
29 | def __len__(self):
30 | return len(self.w2id)
31 |
32 | def __getitem__(self, index):
33 | if index not in self.id2w:
34 | raise IndexError(f'invalid index {index} in indices.')
35 | return self.id2w[index]
36 |
37 | def __contains__(self, item):
38 | return item in self.w2id
39 |
40 | def index(self, symbol):
41 | if symbol in self.w2id:
42 | return self.w2id[symbol]
43 | raise IndexError(f'Unknown symbol {symbol}')
44 |
45 | def keys(self):
46 | return self.w2id.keys()
47 |
48 | def indices(self):
49 | return self.id2w.keys()
50 |
51 | def add_symbol(self, symbol):
52 | if symbol not in self.w2id:
53 | self.id2w[len(self.id2w)] = symbol
54 | self.w2id[symbol] = len(self.w2id)
55 |
56 | @classmethod
57 | def build(cls, symbols, min_counts=1, dump_filtered=None, log=print):
58 | counter = Counter(symbols)
59 | symbols = sorted([t for t, c in counter.items() if c >= min_counts],
60 | key=counter.get, reverse=True)
61 | log(f'''{len(symbols)} symbols found: {' '.join(symbols[:15]) + ('...' if len(symbols) > 15 else '')}''')
62 | filtered = sorted(list(counter.keys() - set(symbols)), key=counter.get, reverse=True)
63 | if filtered:
64 | log('filtered classes:')
65 | if len(filtered) > 20:
66 | log('{} ... {}'.format(' '.join(filtered[:10]), ' '.join(filtered[-10:])))
67 | else:
68 | log(' '.join(filtered))
69 | if dump_filtered:
70 | with open(dump_filtered, 'w') as f:
71 | for name in filtered:
72 | f.write(f'{name} {counter.get(name)}\n')
73 | indexer = cls()
74 | try: # restore numeric order if labels are represented by integers already
75 | symbols = list(map(int, symbols))
76 | symbols.sort()
77 | symbols = list(map(str, symbols))
78 | except ValueError:
79 | pass
80 | for symbol in symbols:
81 | if symbol:
82 | indexer.add_symbol(symbol)
83 | return indexer
84 |
85 | def save(self, file):
86 | with open(file, 'w') as f:
87 | for symbol, index in self.w2id.items():
88 | if index < self.n_spec:
89 | continue
90 | f.write('{}\n'.format(symbol))
91 |
92 | @classmethod
93 | def load(cls, file):
94 | indexer = cls()
95 | with open(file) as f:
96 | for line in f:
97 | symbol = line.rstrip()
98 | assert len(symbol) > 0, 'Empty symbol encountered.'
99 | indexer.add_symbol(symbol)
100 | return indexer
101 |
102 |
103 | class RobustIndexer(Indexer):
104 | def __init__(self, validate=True):
105 | super().__init__()
106 | self.w2id.update({self.unk_symbol(): self.unk()})
107 | self.id2w = {i: w for w, i in self.w2id.items()}
108 | if validate:
109 | self.validate_spec()
110 |
111 | @property
112 | def n_spec(self):
113 | return 1
114 |
115 | def index(self, symbol):
116 | return self.w2id[symbol] if symbol in self.w2id else self.unk()
117 |
118 | @staticmethod
119 | def unk():
120 | return 0
121 |
122 | @staticmethod
123 | def unk_symbol():
124 |         return '<unk>'
125 |
126 | def validate_spec(self):
127 | assert self.n_spec == len(self.w2id), f'{self.n_spec}, {len(self.w2id)}'
128 | assert len(self.w2id) == max(self.id2w.keys()) + 1, "empty indices found in special tokens"
129 | assert len(self.w2id) == len(self.id2w), "index conflict in special tokens"
130 |
131 |
132 | class Vocab(RobustIndexer):
133 | def __init__(self):
134 | super().__init__(validate=False)
135 | self.w2id.update({
136 | self.pad_symbol(): self.pad(),
137 | })
138 | self.id2w = {i: w for w, i in self.w2id.items()}
139 | self.validate_spec()
140 |
141 | @classmethod
142 | def build(cls, words, lower=False, min_df=1, max_tokens=float('inf'), pretrained_embeddings=None,
143 | dump_filtered=None, log=print):
144 | if pretrained_embeddings:
145 | wv_vocab = cls.load_embedding_vocab(pretrained_embeddings, lower)
146 | else:
147 | wv_vocab = set()
148 | if lower:
149 | words = (word.lower() for word in words)
150 | counter = Counter(words)
151 | candidate_tokens = sorted([t for t, c in counter.items() if t in wv_vocab or c >= min_df],
152 | key=counter.get, reverse=True)
153 | if len(candidate_tokens) > max_tokens:
154 | tokens = []
155 | for i, token in enumerate(candidate_tokens):
156 | if i < max_tokens:
157 | tokens.append(token)
158 | elif token in wv_vocab:
159 | tokens.append(token)
160 | else:
161 | tokens = candidate_tokens
162 | total = sum(counter.values())
163 | matched = sum(counter[t] for t in tokens)
164 | stats = (len(tokens), len(counter), total - matched, total, (total - matched) / total * 100)
165 | log('vocab coverage {}/{} | OOV occurrences {}/{} ({:.4f}%)'.format(*stats))
166 | tokens_set = set(tokens)
167 | if pretrained_embeddings:
168 | oop_samples = sorted(list(tokens_set - wv_vocab), key=counter.get, reverse=True)
169 |             log('Covered by pretrained vectors {:.4f}%.'.format(len(tokens_set & wv_vocab) / len(tokens) * 100) +
170 |                 (' outside pretrained: ' + ' '.join(oop_samples[:10]) +
171 |                  (' ...' if len(oop_samples) > 10 else '') if oop_samples else ''))
172 | log('top words:\n{}'.format(' '.join(tokens[:10])))
173 | filtered = sorted(list(counter.keys() - set(tokens)), key=counter.get, reverse=True)
174 | if filtered:
175 | if len(filtered) > 20:
176 | log('filtered words:\n{} ... {}'.format(' '.join(filtered[:10]), ' '.join(filtered[-10:])))
177 | else:
178 | log('filtered words:\n' + ' '.join(filtered))
179 | if dump_filtered:
180 | with open(dump_filtered, 'w') as f:
181 | for name in filtered:
182 | f.write(f'{name} {counter.get(name)}\n')
183 |
184 | vocab = cls()
185 | for token in tokens:
186 | vocab.add_symbol(token)
187 | return vocab
188 |
189 | @staticmethod
190 | def load_embedding_vocab(file, lower):
191 | wv_vocab = set()
192 | with open(file) as f:
193 | for line in f:
194 | token = line.rstrip().split(' ')[0]
195 | if lower:
196 | token = token.lower()
197 | wv_vocab.add(token)
198 | return wv_vocab
199 |
200 | @staticmethod
201 | def pad():
202 | return 0
203 |
204 | @staticmethod
205 | def unk():
206 | return 1
207 |
208 | @property
209 | def n_spec(self):
210 | return 2
211 |
212 | @staticmethod
213 | def pad_symbol():
214 |         return '<pad>'
215 |
216 | char_map = { # escape special characters for safe serialization
217 |         '\n': '<newline>',
218 | }
219 |
220 | def save(self, file):
221 | with open(file, 'w') as f:
222 | for symbol, index in self.w2id.items():
223 | if index < self.n_spec:
224 | continue
225 | symbol = self.char_map.get(symbol, symbol)
226 | f.write(f'{symbol}\n')
227 |
228 | @classmethod
229 | def load(cls, file):
230 | vocab = cls()
231 | reverse_char_map = {v: k for k, v in cls.char_map.items()}
232 | with open(file) as f:
233 | for line in f:
234 | symbol = line.rstrip('\n')
235 | symbol = reverse_char_map.get(symbol, symbol)
236 | vocab.add_symbol(symbol)
237 | return vocab
238 |
--------------------------------------------------------------------------------
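
Vocab reserves index 0 for the pad symbol (matching padding_idx=0 in Embedding) and index 1 for unk; any unseen word indexes to unk. A tiny sketch (the log callback is silenced for brevity):

    from src.utils.vocab import Vocab

    vocab = Vocab.build('the cat sat on the mat'.split(), log=lambda *a: None)
    print(len(vocab))             # 2 special symbols + 5 distinct words = 7
    print(vocab.index('the'))     # 2: the most frequent word gets the first real index
    print(vocab.index('zebra'))   # 1: unknown words fall back to unk()
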
/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (C) 2019 Alibaba Group Holding Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | import os
18 | import sys
19 | import json5
20 | from pprint import pprint
21 | from src.utils import params
22 | from src.trainer import Trainer
23 |
24 |
25 | def main():
26 | argv = sys.argv
27 | if len(argv) == 2:
28 | arg_groups = params.parse(sys.argv[1])
29 | for args, config in arg_groups:
30 | trainer = Trainer(args)
31 | states = trainer.train()
32 | with open('models/log.jsonl', 'a') as f:
33 | f.write(json5.dumps({
34 | 'data': os.path.basename(args.data_dir),
35 | 'params': config,
36 | 'state': states,
37 | }))
38 | f.write('\n')
39 | elif len(argv) == 3 and '--dry' in argv:
40 | argv.remove('--dry')
41 | arg_groups = params.parse(sys.argv[1])
42 | pprint([args.__dict__ for args, _ in arg_groups])
43 | else:
44 |         print('Usage: "python train.py configs/xxx.json5 [--dry]"')
45 |
46 |
47 | if __name__ == '__main__':
48 | main()
49 |
--------------------------------------------------------------------------------