├── .idea
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── t2d_discourseparser.iml
│   ├── vcs.xml
│   └── workspace.xml
├── LICENSE
├── README.md
├── dataset
│   ├── __init__.py
│   └── cdtb.py
├── evaluate.py
├── interface.py
├── new_ctb.py
├── parse.py
├── pipeline.py
├── requirements.txt
├── rst_dev_description.txt
├── sample.txt
├── sample.xml
├── segmenter
│   ├── __init__.py
│   ├── gcn
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── segmenter.py
│   │   ├── test.py
│   │   └── train.py
│   ├── rnn
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── segmenter.py
│   │   ├── test.py
│   │   └── train.py
│   └── svm
│       ├── __init__.py
│       ├── model.py
│       ├── segmenter.py
│       ├── test.py
│       └── train.py
├── structure
│   ├── __init__.py
│   ├── nodes.py
│   └── vocab.py
├── treebuilder
│   ├── __init__.py
│   ├── partptr
│   │   ├── __init__.py
│   │   ├── model.py
│   │   ├── parser.py
│   │   ├── test.py
│   │   ├── train.py
│   │   └── train_b.py
│   └── shiftreduce
│       ├── __init__.py
│       ├── model.py
│       ├── parser.py
│       └── train.py
└── util
    ├── __init__.py
    ├── berkely.py
    ├── eval.py
    └── ltp.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Top-down Text-level DRS Parser
2 |
3 | I often have trouble accessing GitHub, so please just send an email to zzlynx@outlook.com (Longyin Zhang) if you have any questions.
4 |
5 | -- General Information
6 | ```
7 | This project presents the top-down DRS parser described in the paper "Longyin Zhang,
8 | Yuqing Xing, Fang Kong, Peifeng Li, and Guodong Zhou. A Top-Down Neural Architecture
9 | towards Text-Level Parsing of Discourse Rhetorical Structure (ACL2020)".
10 | ```
11 |
12 | #### Installation
13 | - Python 3.6
14 | - Java (required by the Berkeley Parser)
15 | - the other packages listed in requirements.txt
16 |
17 | #### Project Structure
18 | ```
19 | ---ChineseDiscourseParser
20 | |-berkeleyparser / the Berkeley Parser
21 | |-data / corpus and models
22 | | |-cache / processed data
23 | | |-CDTB / the CDTB corpus
24 | | |-CTB / the Chinese Treebank corpus
25 | | |-CTB_auto / CTB sentences parsed automatically with the Berkeley Parser
26 | | |-log / tensorboard log files
27 | | |-models / selected models
28 | | |-pretrained / word vectors
29 | |-dataset / data loading utilities
30 | | |-cdtb.py / utilities for CDTB processing
31 | |-pub / released pre-trained models
32 | | |-models / related pre-trained models
33 | | |-pyltp_models / third-party models of pyltp
34 | |-segmenter / EDU segmentation
35 | | |-gcn / GCN-based EDU segmenter
36 | | |-rnn / LSTM-based EDU segmenter
37 | | |-svm / SVM-based EDU segmenter
38 | |-structure / tree structures
39 | | |-nodes.py / tree nodes
40 | | |-vocab.py / vocabulary list
41 | |-treebuilder / tree parsers
42 | | |-partptr / the top-down DRS parser
43 | | |-shiftreduce / the transition-based DRS parser
44 | |-util / some utilities
45 | | |-berkely.py / interface to the Berkeley Parser
46 | | |-eval.py / evaluation methods
47 | | |-ltp.py / PyLTP tools
48 | |-evaluate.py / evaluation
49 | |-interface.py / abstract interfaces (SegmenterI, ParserI, PipelineI)
50 | |-parse.py / parsing with pre-trained parsers
51 | |-pipeline.py / a pipelined framework
52 | ```
53 |
54 | #### Project Functions
55 |
56 | 1. discourse rhetorical structure parsing
57 |
58 | Run the following command for DRS parsing:
59 | ```shell
60 | python3 parse.py source save [-schema schema_name] [-segmenter_name segmenter_name] [--encoding utf-8] [--draw] [--use_gpu]
61 | ```
62 |
63 | - source: the path of the input text file, where each line is a paragraph;
64 | - save: the path to save the parse trees;
65 | - schema: `topdown` or `shiftreduce`, the parsing strategy;
66 | - segmenter_name: `svm` or `gcn`, the EDU segmentation strategy;
67 | - encoding: the encoding of the input file, UTF-8 by default;
68 | - draw: whether to draw the parse tree with the tkinter tool;
69 | - use_gpu: whether to use the GPU.
70 |
71 |
72 | 2. performance evaluation
73 |
74 | Run the following command for performance evaluation:
75 | `python3 evaluate.py data [--ctb_dir ctb_dir] [--cache_dir cache_dir] [-schema topdown|shiftreduce] [-segmenter_name svm|gcn] [--use_gold_edu] [--use_gpu]`
76 |
77 | - data: the path of the CDTB corpus;
78 | - ctb_dir: the path of the CTB corpus, where CTB contains gold-standard syntax and CTB_auto contains automatically parsed syntax;
79 | - cache_dir: the path of the cached data;
80 | - schema: the parsing schema (topdown or shiftreduce) to evaluate;
81 | - segmenter_name: the EDU segmenter to use;
82 | - use_gold_edu: whether to use gold EDU segmentation;
83 | - use_gpu: whether to use the GPU.
84 |
85 | 3. model training
86 |
87 | Taking the GCN EDU segmenter as an example:
88 | ```shell
89 | python -m segmenter.gcn.train data/CDTB -model_save data/models/segmenter.gcn.model -pretrained data/pretrained/sgns.renmin.word --ctb_dir data/CTB --cache_dir data/cache --w2v_freeze --use_gpu
90 | ```
91 |
92 |
93 | #### Key classes and interfaces
94 |
95 | 1. SegmenterI
96 |
97 | The segmenter interface provides three methods: segmenting a paragraph into sentences, segmenting a sentence into EDUs, and segmenting a paragraph directly into EDUs in one step.
98 |
99 | 2. ParserI
100 |
101 | The parser interface builds a discourse tree from the EDUs of a Chinese paragraph.
102 |
103 | 3. PipelineI
104 |
105 | The pipeline interface, which assembles a SegmenterI and a ParserI into a complete DRS parser.
106 |
107 | 4. EDU, Relation, Sentence, Paragraph, and Discourse
108 |
109 | They correspond to the data structures of EDU, relation, sentence, paragraph, and discourse, respectively,
110 | which can also be regarded as list containers. Among them, the Paragraph structure represents the discourse
111 | tree in the Chinese CDTB corpus, and it can be visualized by calling the draw method.
112 |
113 |
114 | #### Evaluations
115 |
116 | In the paper, we report performance based on the **soft** micro-averaged F1-score, as implemented in
117 | the evaluation programs. In addition, this project contains an unpublished **strict** evaluation method where
118 | the split position is also taken into consideration, which gives more accurate NR prediction performance.
119 | Specifically, given two adjacent gold-standard text spans (1, 5) and (6, 8), the upper-layer span is (1, 8).
120 | In the original soft metric, we judge the correctness of a predicted span only by the span boundaries 1 and 8,
121 | and the NR tag between the two child nodes is assigned on that basis. As a result, if the predicted split
122 | position is 3 and the obtained child spans are (1, 3) and (4, 8), the soft metric still considers the span
123 | correct, and the NR relation predicted between (1, 3) and (4, 8) is evaluated against the gold relation between (1, 5) and (6, 8).
124 | Obviously, the soft metric is far from rigorous. In this project, we therefore also present a strict evaluation
125 | method where both the span boundaries and the split point are considered for span prediction evaluation, and
126 | the performance of this top-down DRS parser under the strict evaluation is **(84.0, 59.0, 54.2, 47.8) (macro-averaged)**.
127 | One can directly use these evaluation scripts for performance calculation.
128 |
129 | #### Application
130 |
131 | We've published a pre-trained end-to-end DRS parser at https://pan.baidu.com/s/1wY4em6ViF0T8LfZPAzTuOg, and the passcode is n6hx.
132 |
133 |
134 |
135 | -- License
136 | ```
137 | Copyright (c) 2019, Soochow University NLP research group. All rights reserved.
138 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that
139 | the following conditions are met:
140 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
141 | following disclaimer.
142 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
143 | following disclaimer in the documentation and/or other materials provided with the distribution.
144 | ```
145 |
--------------------------------------------------------------------------------
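The interfaces described under "Key classes and interfaces" in README.md above come together as in the
following sketch, which mirrors what parse.py and pipeline.py (both included later in this dump) do; it
assumes the pre-trained models under pub/models are in place, and the output file name is illustrative.

```python
# Minimal usage sketch of the SegmenterI / ParserI / PipelineI stack.
from pipeline import build_pipeline
from structure import Discourse

# Assemble a pipeline: a GCN EDU segmenter plus the top-down (partptr) tree builder.
pipeline = build_pipeline(schema="topdown", segmenter_name="gcn", use_gpu=False)

doc = Discourse()
with open("sample.txt", "r", encoding="utf-8") as fd:
    for line in fd:
        line = line.strip()
        if not line:
            continue
        # Calling the pipeline runs full_parse: sentence splitting, EDU
        # segmentation, and tree building in one step.  The steps can also be
        # run separately via cut_sent, cut_edu, and parse.
        para = pipeline(line)
        doc.append(para)

doc.to_xml("sample_parsed.xml", encoding="utf-8")
```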
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | from .cdtb import CDTB
4 |
--------------------------------------------------------------------------------
/dataset/cdtb.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import os
4 | import re
5 | import pickle
6 | import gzip
7 | import hashlib
8 | import thulac
9 | import tqdm
10 | import logging
11 | from itertools import chain
12 | from nltk.tree import Tree as ParseTree
13 | from structure import Discourse, Sentence, EDU, TEXT, Connective, node_type_filter
14 |
15 |
16 | # note: this module will print a junk line "Model loaded succeed" to stdout when it is initialized
17 | thulac = thulac.thulac()
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | class CDTB:
22 | def __init__(self, cdtb_dir, train, validate, test, encoding="UTF-8",
23 | ctb_dir=None, ctb_encoding="UTF-8", cache_dir=None, preprocess=False):
24 | if cache_dir:
25 | hash_figure = "-".join(map(str, [cdtb_dir, train, validate, test, ctb_dir, preprocess]))
26 | hash_key = hashlib.md5(hash_figure.encode()).hexdigest()
27 | cache = os.path.join(cache_dir, hash_key + ".gz")
28 | else:
29 | cache = None
30 |
31 | if cache is not None and os.path.isfile(cache):
32 | with gzip.open(cache, "rb") as cache_fd:
33 | self.preprocess, self.ctb = pickle.load(cache_fd)
34 | self.train = pickle.load(cache_fd)
35 | self.validate = pickle.load(cache_fd)
36 | self.test = pickle.load(cache_fd)
37 | logger.info("load cached dataset from %s" % cache)
38 | return
39 |
40 | self.preprocess = preprocess
41 | self.ctb = self.load_ctb(ctb_dir, ctb_encoding) if ctb_dir else {}
42 | self.train = self.load_dir(cdtb_dir, train, encoding=encoding)
43 | self.validate = self.load_dir(cdtb_dir, validate, encoding=encoding)
44 | self.test = self.load_dir(cdtb_dir, test, encoding=encoding)
45 |
46 | if preprocess:
47 | for discourse in tqdm.tqdm(chain(self.train, self.validate, self.test), desc="preprocessing"):
48 | self.preprocessing(discourse)
49 |
50 | if cache is not None:
51 | with gzip.open(cache, "wb") as cache_fd:
52 | pickle.dump((self.preprocess, self.ctb), cache_fd)
53 | pickle.dump(self.train, cache_fd)
54 | pickle.dump(self.validate, cache_fd)
55 | pickle.dump(self.test, cache_fd)
56 | logger.info("saved cached dataset to %s" % cache)
57 |
58 | def report(self):
59 | # TODO
60 | raise NotImplementedError()
61 |
62 | def preprocessing(self, discourse):
63 | for paragraph in discourse:
64 | for sentence in paragraph.iterfind(filter=node_type_filter(Sentence)):
65 | if self.ctb and (sentence.sid is not None) and (sentence.sid in self.ctb):
66 | parse = self.ctb[sentence.sid]
67 | pairs = [(node[0], node.label()) for node in parse.subtrees()
68 | if node.height() == 2 and node.label() != "-NONE-"]
69 | words, tags = list(zip(*pairs))
70 | else:
71 | words, tags = list(zip(*thulac.cut(sentence.text)))
72 | setattr(sentence, "words", list(words))
73 | setattr(sentence, "tags", list(tags))
74 |
75 | offset = 0
76 | for textnode in sentence.iterfind(filter=node_type_filter([TEXT, Connective, EDU]),
77 | terminal=node_type_filter([TEXT, Connective, EDU])):
78 | if isinstance(textnode, EDU):
79 | edu_words = []
80 | edu_tags = []
81 | cur = 0
82 | for word, tag in zip(sentence.words, sentence.tags):
83 | if offset <= cur < cur + len(word) <= offset + len(textnode.text):
84 | edu_words.append(word)
85 | edu_tags.append(tag)
86 | cur += len(word)
87 | setattr(textnode, "words", edu_words)
88 | setattr(textnode, "tags", edu_tags)
89 | offset += len(textnode.text)
90 | return discourse
91 |
92 | @staticmethod
93 | def load_dir(path, sub, encoding="UTF-8"):
94 | train_path = os.path.join(path, sub)
95 | discourses = []
96 | for file in os.listdir(train_path):
97 | file = os.path.join(train_path, file)
98 | discourse = Discourse.from_xml(file, encoding=encoding)
99 | discourses.append(discourse)
100 | return discourses
101 |
102 | @staticmethod
103 | def load_ctb(ctb_dir, encoding="UTF-8"):
104 | ctb = {}
105 |         s_pat = re.compile(r"<S ID=(?P<sid>\S+?)>(?P<sparse>.*?)</S>", re.M | re.DOTALL)
106 | for file in os.listdir(ctb_dir):
107 | with open(os.path.join(ctb_dir, file), "r", encoding=encoding) as fd:
108 | doc = fd.read()
109 | for match in s_pat.finditer(doc):
110 | sid = match.group("sid")
111 | sparse = ParseTree.fromstring(match.group("sparse"))
112 | ctb[sid] = sparse
113 | return ctb
114 |
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import logging
3 | import argparse
4 | from itertools import chain
5 | from dataset import CDTB
6 | from structure import node_type_filter, EDU, Paragraph, TEXT, Sentence
7 | from pipeline import build_pipeline
8 | from util import eval
9 | from tqdm import tqdm
10 | from util.eval import edu_eval, gen_edu_report
11 |
12 | logger = logging.getLogger("evaluation")
13 |
14 |
15 | def evaluate(args):
16 | pipeline = build_pipeline(schema=args.schema, segmenter_name=args.segmenter_name, use_gpu=args.use_gpu)
17 | cdtb = CDTB(args.data, "TRAIN", "VALIDATE", "TEST", ctb_dir=args.ctb_dir, preprocess=True, cache_dir=args.cache_dir)
18 | golds = list(filter(lambda d: d.root_relation(), chain(*cdtb.test)))
19 | parses = []
20 |
21 | if args.use_gold_edu:
22 | logger.info("evaluation with gold edu segmentation")
23 | else:
24 | logger.info("evaluation with auto edu segmentation")
25 |
26 | for para in tqdm(golds, desc="parsing", unit=" para"):
27 | if args.use_gold_edu:
28 | edus = []
29 | for edu in para.edus():
30 | edu_copy = EDU([TEXT(edu.text)])
31 | setattr(edu_copy, "words", edu.words)
32 | setattr(edu_copy, "tags", edu.tags)
33 | edus.append(edu_copy)
34 | else:
35 | sentences = []
36 | for sentence in para.sentences():
37 | if list(sentence.iterfind(node_type_filter(EDU))):
38 | copy_sentence = Sentence([TEXT([sentence.text])])
39 | if hasattr(sentence, "words"):
40 | setattr(copy_sentence, "words", sentence.words)
41 | if hasattr(sentence, "tags"):
42 | setattr(copy_sentence, "tags", sentence.tags)
43 | setattr(copy_sentence, "parse", cdtb.ctb[sentence.sid])
44 | sentences.append(copy_sentence)
45 | para = pipeline.cut_edu(Paragraph(sentences))
46 | edus = []
47 | for edu in para.edus():
48 | edu_copy = EDU([TEXT(edu.text)])
49 | setattr(edu_copy, "words", edu.words)
50 | setattr(edu_copy, "tags", edu.tags)
51 | edus.append(edu_copy)
52 | parse = pipeline.parse(Paragraph(edus))
53 | parses.append(parse)
54 |
55 | # edu score
56 | scores = edu_eval(golds, parses)
57 | logger.info("EDU segmentation scores:")
58 | logger.info(gen_edu_report(scores))
59 |
60 | # parser score
61 | cdtb_macro_scores = eval.parse_eval(parses, golds, average="macro")
62 | logger.info("CDTB macro (strict) scores:")
63 | logger.info(eval.gen_parse_report(*cdtb_macro_scores))
64 |
65 | # nuclear scores
66 | nuclear_scores = eval.nuclear_eval(parses, golds)
67 | logger.info("nuclear scores:")
68 | logger.info(eval.gen_category_report(nuclear_scores))
69 |
70 | # relation scores
71 | ctype_scores, ftype_scores = eval.relation_eval(parses, golds)
72 | logger.info("coarse relation scores:")
73 | logger.info(eval.gen_category_report(ctype_scores))
74 | logger.info("fine relation scores:")
75 | logger.info(eval.gen_category_report(ftype_scores))
76 |
77 | # height eval
78 | height_scores = eval.height_eval(parses, golds)
79 | logger.info("structure precision by node height:")
80 | logger.info(eval.gen_height_report(height_scores))
81 |
82 |
83 | if __name__ == '__main__':
84 | logging.basicConfig(level=logging.INFO)
85 | arg_parser = argparse.ArgumentParser()
86 | # dataset parameters
87 | arg_parser.add_argument("data")
88 | arg_parser.add_argument("--ctb_dir")
89 | arg_parser.add_argument("--cache_dir")
90 | arg_parser.add_argument("-schema", default="topdown")
91 | arg_parser.add_argument("-segmenter_name", default="gcn")
92 | arg_parser.add_argument("--use_gold_edu", dest="use_gold_edu", action="store_true")
93 | arg_parser.set_defaults(use_gold_edu=False)
94 | arg_parser.add_argument("--use_gpu", dest="use_gpu", action="store_true")
95 | arg_parser.set_defaults(use_gpu=False)
96 | evaluate(arg_parser.parse_args())
97 |
--------------------------------------------------------------------------------
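To make the **soft** versus **strict** span criterion from the README's Evaluations section concrete, here is
a small illustrative sketch; the helper functions are hypothetical and are not part of util.eval, which holds
the actual evaluation code used by evaluate.py above.

```python
# Contrast between the soft and strict span criteria (hypothetical helpers).

def soft_match(pred_span, gold_span):
    # Soft: a predicted span is correct if its outer boundaries match the gold
    # span, no matter where it was split into child spans.
    return pred_span == gold_span


def strict_match(pred_span, pred_split, gold_span, gold_split):
    # Strict: the split point between the two child spans must also match.
    return pred_span == gold_span and pred_split == gold_split


# Gold: children (1, 5) and (6, 8) under the upper-layer span (1, 8), i.e. split after EDU 5.
gold_span, gold_split = (1, 8), 5
# Prediction: the same outer span (1, 8), but split after EDU 3 into (1, 3) and (4, 8).
pred_span, pred_split = (1, 8), 3

print(soft_match(pred_span, gold_span))                              # True
print(strict_match(pred_span, pred_split, gold_span, gold_split))    # False
```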
/interface.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | from abc import abstractmethod
3 |
4 | from typing import List
5 | from structure.nodes import Sentence, Paragraph, EDU
6 |
7 |
8 | class SegmenterI:
9 | @abstractmethod
10 | def cut(self, text: str) -> Paragraph:
11 |         raise NotImplementedError()
12 |
13 | @abstractmethod
14 | def cut_sent(self, text: str, sid=None) -> List[Sentence]:
15 |         raise NotImplementedError()
16 |
17 | @abstractmethod
18 | def cut_edu(self, sent: Sentence) -> List[EDU]:
19 |         raise NotImplementedError()
20 |
21 |
22 | class ParserI:
23 | @abstractmethod
24 | def parse(self, para: Paragraph) -> Paragraph:
25 |         raise NotImplementedError()
26 |
27 |
28 | class PipelineI:
29 | @abstractmethod
30 | def cut_sent(self, text: str) -> Paragraph:
31 |         raise NotImplementedError()
32 |
33 | @abstractmethod
34 | def cut_edu(self, para: Paragraph) -> Paragraph:
35 |         raise NotImplementedError()
36 |
37 | @abstractmethod
38 | def parse(self, para: Paragraph) -> Paragraph:
39 |         raise NotImplementedError()
40 |
41 | @abstractmethod
42 | def full_parse(self, text: str):
43 |         raise NotImplementedError()
44 |
45 | def __call__(self, text: str):
46 | return self.full_parse(text)
47 |
--------------------------------------------------------------------------------
/new_ctb.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import re
3 | import os
4 | from nltk.tree import Tree as ParseTree
5 | from util.berkely import BerkeleyParser
6 | from tqdm import tqdm
7 |
8 |
9 | if __name__ == '__main__':
10 | ctb_dir = "data/CTB"
11 | save_dir = "data/CTB_auto"
12 | encoding = "UTF-8"
13 | ctb = {}
14 |     s_pat = re.compile(r"<S ID=(?P<sid>\S+?)>(?P<sparse>.*?)</S>", re.M | re.DOTALL)
15 | parser = BerkeleyParser()
16 | for file in tqdm(os.listdir(ctb_dir)):
17 | if os.path.isfile(os.path.join(save_dir, file)):
18 | continue
19 | print(file)
20 | with open(os.path.join(ctb_dir, file), "r", encoding=encoding) as fd:
21 | doc = fd.read()
22 | parses = []
23 | for match in s_pat.finditer(doc):
24 | sid = match.group("sid")
25 | sparse = ParseTree.fromstring(match.group("sparse"))
26 | pairs = [(node[0], node.label()) for node in sparse.subtrees()
27 | if node.height() == 2 and node.label() != "-NONE-"]
28 | words, tags = list(zip(*pairs))
29 | print(sid, " ".join(words))
30 | if sid == "5133":
31 | parse = sparse
32 | else:
33 | parse = parser.parse(words, timeout=2000)
34 | parses.append((sid, parse))
35 | with open(os.path.join(save_dir, file), "w+", encoding=encoding) as save_fd:
36 | for sid, parse in parses:
37 |                 save_fd.write("<S ID=%s>\n" % sid)
38 |                 save_fd.write(str(parse))
39 |                 save_fd.write("\n</S>\n")
40 |
--------------------------------------------------------------------------------
/parse.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # coding: UTF-8
3 | import argparse
4 | from structure import Discourse
5 | from pipeline import build_pipeline
6 | import logging
7 | import tqdm
8 |
9 |
10 | def run(args):
11 | logger = logging.getLogger("dp")
12 | doc = Discourse()
13 | pipeline = build_pipeline(schema=args.schema, segmenter_name=args.segmenter_name, use_gpu=args.use_gpu)
14 | with open(args.source, "r", encoding=args.encoding) as source_fd:
15 | for line in tqdm.tqdm(source_fd, desc="parsing %s" % args.source, unit=" para"):
16 | line = line.strip()
17 | if line:
18 | para = pipeline(line)
19 | if args.draw:
20 | para.draw()
21 | doc.append(para)
22 | logger.info("save parsing to %s" % args.save)
23 | doc.to_xml(args.save, encoding=args.encoding)
24 |
25 |
26 | if __name__ == '__main__':
27 | logging.basicConfig(level=logging.INFO)
28 | arg_parser = argparse.ArgumentParser()
29 | arg_parser.add_argument("source")
30 | arg_parser.add_argument("save")
31 | arg_parser.add_argument("-schema", default="topdown")
32 | arg_parser.add_argument("-segmenter_name", default="svm")
33 | arg_parser.add_argument("--encoding", default="utf-8")
34 | arg_parser.add_argument("--draw", dest="draw", action="store_true")
35 | arg_parser.add_argument("--use_gpu", dest="use_gpu", action="store_true")
36 | arg_parser.set_defaults(use_gpu=False)
37 | arg_parser.set_defaults(draw=False)
38 | run(arg_parser.parse_args())
39 |
--------------------------------------------------------------------------------
/pipeline.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import logging
3 | import pickle
4 | import threading
5 | from interface import PipelineI, SegmenterI, ParserI
6 | from segmenter.gcn import GCNSegmenter
7 | from segmenter.svm import SVMSegmenter
8 | from treebuilder.partptr import PartPtrParser
9 | from treebuilder.shiftreduce import ShiftReduceParser
10 | from structure import Paragraph, Sentence
11 | import torch
12 |
13 |
14 | class BasicPipeline(PipelineI):
15 | def __init__(self, segmenter, parser):
16 | super(BasicPipeline, self).__init__()
17 | self.segmenter = segmenter # type: SegmenterI
18 | self.parser = parser # type: ParserI
19 |
20 | def cut_sent(self, text: str, sid=None):
21 | return Paragraph(self.segmenter.cut_sent(text, sid=sid))
22 |
23 | def cut_edu(self, para: Paragraph) -> Paragraph:
24 | edus = []
25 | for sentence in para.sentences():
26 | edus.extend(self.segmenter.cut_edu(sentence))
27 | return Paragraph(edus)
28 |
29 | def parse(self, para: Paragraph) -> Paragraph:
30 | return self.parser.parse(para)
31 |
32 | def full_parse(self, text: str):
33 | para = self.cut_sent(text)
34 | para = self.cut_edu(para)
35 | para = self.parse(para)
36 | return para
37 |
38 |
39 | class ShiftReducePipeline(BasicPipeline):
40 | def __init__(self, segmenter_name="gcn", use_gpu=False):
41 | if use_gpu and (not torch.cuda.is_available()):
42 | raise Warning("cuda is not available, set use_gpu to False")
43 | if segmenter_name == "svm":
44 | with open("pub/models/segmenter.svm.model", "rb") as segmenter_fd:
45 | segmenter_model = pickle.load(segmenter_fd)
46 | segmenter = SVMSegmenter(segmenter_model)
47 | elif segmenter_name == "gcn":
48 | with open("pub/models/segmenter.gcn.model", "rb") as segmenter_fd:
49 | segmenter_model = torch.load(segmenter_fd, map_location="cpu")
50 | segmenter_model.use_gpu = False
51 | if use_gpu:
52 | segmenter_model.cuda()
53 | segmenter_model.use_gpu = True
54 | segmenter_model.eval()
55 | segmenter = GCNSegmenter(segmenter_model)
56 | with open("pub/models/treebuilder.shiftreduce.model", "rb") as parser_fd:
57 | parser_model = torch.load(parser_fd, map_location="cpu")
58 | parser_model.use_gpu = False
59 | if use_gpu:
60 | parser_model.cuda()
61 | parser_model.use_gpu = True
62 | parser_model.eval()
63 | parser = ShiftReduceParser(parser_model)
64 | super(ShiftReducePipeline, self).__init__(segmenter, parser)
65 |
66 |
67 | class TopDownPipeline(BasicPipeline):
68 | def __init__(self, segmenter_name="gcn", use_gpu=False):
69 | if use_gpu and (not torch.cuda.is_available()):
70 | raise Warning("cuda is not available, set use_gpu to False")
71 | if segmenter_name == "svm":
72 | with open("pub/models/segmenter.svm.model", "rb") as segmenter_fd:
73 | segmenter_model = pickle.load(segmenter_fd)
74 | segmenter = SVMSegmenter(segmenter_model)
75 | elif segmenter_name == "gcn":
76 | with open("pub/models/segmenter.gcn.model", "rb") as segmenter_fd:
77 | segmenter_model = torch.load(segmenter_fd, map_location="cpu")
78 | segmenter_model.use_gpu = False
79 | if use_gpu:
80 | segmenter_model.cuda()
81 | segmenter_model.use_gpu = True
82 | segmenter_model.eval()
83 | segmenter = GCNSegmenter(segmenter_model)
84 | else:
85 |             raise NotImplementedError("no segmenter found for name \"%s\"" % segmenter_name)
86 | with open("pub/models/treebuilder.partptr.model", "rb") as parser_fd:
87 | parser_model = torch.load(parser_fd, map_location="cpu")
88 | parser_model.use_gpu = False
89 | if use_gpu:
90 | parser_model.cuda()
91 | parser_model.use_gpu = True
92 | parser_model.eval()
93 | parser = PartPtrParser(parser_model)
94 | super(TopDownPipeline, self).__init__(segmenter, parser)
95 |
96 |
97 | def build_pipeline(schema="topdown", segmenter_name="gcn", use_gpu=False):
98 | logging.info("parsing thread %s build pipeline with %s schema and %s segmenter" %
99 | (threading.current_thread().name, schema, segmenter_name))
100 |
101 | if schema == "topdown":
102 | return TopDownPipeline(segmenter_name=segmenter_name, use_gpu=use_gpu)
103 | elif schema == "shiftreduce":
104 | return ShiftReducePipeline(segmenter_name=segmenter_name, use_gpu=use_gpu)
105 | else:
106 |         raise NotImplementedError("no schema found for \"%s\"" % schema)
107 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | nltk
2 | torch>=0.4
3 | thulac
4 | tqdm
5 | gensim
6 | numpy
7 | tensorboardX
8 |
9 | matplotlib
10 |
--------------------------------------------------------------------------------
/rst_dev_description.txt:
--------------------------------------------------------------------------------
1 | RST-DT does not have an established dev set, so we extracted 34 articles from the training set to serve as the dev corpus.
2 | Since RST-DT is not freely available, I provide only the names of these articles:
3 | wsj_0614
4 | wsj_0615
5 | wsj_0617
6 | wsj_0620
7 | wsj_0624
8 | wsj_0635
9 | wsj_0651
10 | wsj_0652
11 | wsj_0662
12 | wsj_0671
13 | wsj_0687
14 | wsj_0688
15 | wsj_0692
16 | wsj_1100
17 | wsj_1101
18 | wsj_1119
19 | wsj_1133
20 | wsj_1149
21 | wsj_1162
22 | wsj_1193
23 | wsj_1194
24 | wsj_1195
25 | wsj_1313
26 | wsj_1337
27 | wsj_1341
28 | wsj_1344
29 | wsj_1355
30 | wsj_1962
31 | wsj_1970
32 | wsj_1973
33 | wsj_2365
34 | wsj_2367
35 | wsj_2393
36 | wsj_2396
37 |
--------------------------------------------------------------------------------
/sample.txt:
--------------------------------------------------------------------------------
1 | 浦东开发开放是一项振兴上海,建设现代化经济、贸易、金融中心的跨世纪工程,因此大量出现的是以前不曾遇到过的新情况、新问题。对此,浦东不是简单的采取“干一段时间,等积累了经验以后再制定法规条例”的做法,而是借鉴发达国家和深圳等特区的经验教训,聘请国内外有关专家学者,积极、及时地制定和推出法规性文件,使这些经济活动一出现就被纳入法制轨道。去年初浦东新区诞生的中国第一家医疗机构药品采购服务中心,正因为一开始就比较规范,运转至今,成交药品一亿多元,没有发现一例回扣。
2 | 建筑是开发浦东的一项主要经济活动,这些年有数百家建筑公司、四千余个建筑工地遍布在这片热土上。为规范建筑行为,防止出现无序现象,新区管委会根据国家和上海市的有关规定,结合浦东开发实际,及时出台了一系列规范建设市场的文件,其中包括工程施工招投标管理办法、拆迁工作若干规定、整治违章建筑实施办法、通信设施及管线配套建设意见、建设工地施工环境管理暂行办法等,基本做到了每个环节都有明确而又具体的规定。建筑公司进区,有关部门先送上这些法规性文件,然后有专门队伍进行监督检查。
--------------------------------------------------------------------------------
/sample.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | 浦东开发开放是一项振兴上海,
8 |
9 |
10 |
11 | 建设现代化经济、贸易、金融中心的跨世纪工程,
12 |
13 |
14 | 因此大量出现的是以前不曾遇到过的新情况、新问题。
15 |
16 |
17 |
18 |
19 |
20 |
21 | 对此,浦东不是简单的采取“干一段时间,等积累了经验以后再制定法规条例”的做法,
22 |
23 |
24 |
25 |
26 | 而是借鉴发达国家和深圳等特区的经验教训,
27 |
28 |
29 |
30 | 聘请国内外有关专家学者,
31 |
32 |
33 | 积极、及时地制定和推出法规性文件,
34 |
35 |
36 |
37 |
38 | 使这些经济活动一出现就被纳入法制轨道。
39 |
40 |
41 |
42 |
43 |
44 | 去年初浦东新区诞生的中国第一家医疗机构药品采购服务中心,正因为一开始就比较规范,
45 |
46 |
47 |
48 | 运转至今,
49 |
50 |
51 |
52 | 成交药品一亿多元,
53 |
54 |
55 | 没有发现一例回扣。
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 | 建筑是开发浦东的一项主要经济活动,
68 |
69 |
70 | 这些年有数百家建筑公司、四千余个建筑工地遍布在这片热土上。
71 |
72 |
73 |
74 |
75 | 为规范建筑行为,
76 |
77 |
78 |
79 | 防止出现无序现象,
80 |
81 |
82 |
83 | 新区管委会根据国家和上海市的有关规定,
84 |
85 |
86 |
87 |
88 | 结合浦东开发实际,
89 |
90 |
91 |
92 | 及时出台了一系列规范建设市场的文件,
93 |
94 |
95 |
96 | 其中包括工程施工招投标管理办法、拆迁工作若干规定、整治违章建筑实施办法、通信设施及管线配套建设意见、建设工地施工环境管理暂行办法等,
97 |
98 |
99 | 基本做到了每个环节都有明确而又具体的规定。
100 |
101 |
102 |
103 |
104 |
105 |
106 | 建筑公司进区,
107 |
108 |
109 |
110 | 有关部门先送上这些法规性文件,
111 |
112 |
113 | 然后有专门队伍进行监督检查。
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
--------------------------------------------------------------------------------
/segmenter/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NLP-Discourse-SoochowU/t2d_discourseparser/ce552908b1907cf8b59db11802811a6468c9bfc9/segmenter/__init__.py
--------------------------------------------------------------------------------
/segmenter/gcn/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | from .segmenter import GCNSegmenter
3 |
--------------------------------------------------------------------------------
/segmenter/gcn/model.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
6 |
7 |
8 | class SyntacticGCN(nn.Module):
9 | def __init__(self, input_size, hidden_size, num_labels, bias=True):
10 | super(SyntacticGCN, self).__init__()
11 | self.input_size = input_size
12 | self.hidden_size = hidden_size
13 | self.num_labels = num_labels
14 | self.W = nn.Parameter(torch.empty(num_labels, input_size, hidden_size, dtype=torch.float))
15 | nn.init.xavier_normal_(self.W)
16 |         self.bias = bias
17 |         if bias:
18 |             self.b = nn.Parameter(torch.empty(num_labels, hidden_size, dtype=torch.float))
19 |             nn.init.xavier_normal_(self.b)
20 |
21 | def forward(self, graph, nodes):
22 | # graph (b, n, n, l)
23 | # nodes (b, n, input_size)
24 | b, n, _ = nodes.size()
25 | l, input_size, hidden_size = self.num_labels, self.input_size, self.hidden_size
26 | # graph (b, n*l, n)
27 | g = graph.transpose(2, 3).float().contiguous().view(b, n*l, n)
28 | # x: (b, n, l*input_size)
29 | x = g.bmm(nodes).view(b, n, l*input_size)
30 | # h: (b, n, hidden_size)
31 | h = x.matmul(self.W.view(l*input_size, hidden_size))
32 | if self.bias:
33 | bias = (graph.float().view(b*n*n, l) @ self.b).view(b, n, n, hidden_size)
34 | bias = bias.sum(2)
35 | h = h + bias
36 | norm = graph.view(b, n, n*l).sum(-1).float().unsqueeze(-1) + 1e-10
37 | # h: (b, n, hidden_size)
38 | hidden = F.relu(h / norm)
39 | return hidden
40 |
41 |
42 | class GCNSegmenterModel(nn.Module):
43 | def __init__(self, hidden_size, dropout, rnn_layers, gcn_layers, word_vocab, pos_vocab, gcn_vocab, tag_label,
44 | pos_size=30, pretrained=None, w2v_size=None, w2v_freeze=False,
45 | use_gpu=False):
46 | super(GCNSegmenterModel, self).__init__()
47 | self.use_gpu = use_gpu
48 | self.word_vocab = word_vocab
49 | self.pos_vocab = pos_vocab
50 | self.gcn_vocab = gcn_vocab
51 | self.tag_label = tag_label
52 | self.word_emb = word_vocab.embedding(pretrained=pretrained, dim=w2v_size, freeze=w2v_freeze, use_gpu=use_gpu)
53 | self.w2v_size = self.word_emb.weight.shape[-1]
54 | self.pos_emb = pos_vocab.embedding(dim=pos_size, use_gpu=use_gpu)
55 | self.pos_size = pos_size
56 | self.hidden_size = hidden_size
57 | self.dropout_p = dropout
58 | self.rnn_layers = rnn_layers
59 | self.rnn = nn.LSTM(self.w2v_size+self.pos_size, self.hidden_size // 2,
60 | num_layers=rnn_layers, dropout=dropout, bidirectional=True, batch_first=True)
61 | self.gcn_layers = gcn_layers
62 | self.gcns = nn.ModuleList([SyntacticGCN(hidden_size, hidden_size, len(gcn_vocab)) for _ in range(gcn_layers)])
63 | self.tagger = nn.Linear(hidden_size, len(tag_label))
64 |
65 | def forward(self, word_ids, pos_ids, graph, masks=None):
66 | self.rnn.flatten_parameters()
67 | word_emb = self.word_emb(word_ids)
68 | pos_emb = self.pos_emb(pos_ids)
69 | rnn_inputs = torch.cat([word_emb, pos_emb], dim=-1)
70 | if masks is not None:
71 | lengths = masks.sum(-1)
72 | rnn_inputs = rnn_inputs * masks.unsqueeze(-1).float()
73 | rnn_inputs_packed = pack_padded_sequence(rnn_inputs, lengths, batch_first=True)
74 | rnn_outputs_packed, _ = self.rnn(rnn_inputs_packed)
75 | rnn_outputs, _ = pad_packed_sequence(rnn_outputs_packed, batch_first=True)
76 | else:
77 | rnn_outputs, _ = self.rnn(rnn_inputs)
78 |
79 | for gcn in self.gcns:
80 | gcn_outputs = gcn(graph, rnn_outputs)
81 | tag_score = self.tagger(gcn_outputs)
82 | return tag_score
83 |
84 | def loss(self, inputs, target):
85 | word_ids, pos_ids, graph, masks = inputs
86 | batch_size, max_seqlen = word_ids.size()
87 | pred = F.log_softmax(self(word_ids, pos_ids, graph, masks), dim=-1)
88 | pred = pred.view(batch_size*max_seqlen, -1)
89 | target = target.view(-1)
90 | masks = masks.view(-1)
91 | losses = F.nll_loss(pred, target, reduction='none')
92 | loss = (losses * masks.float()).sum() / masks.sum().float()
93 | return loss
94 |
--------------------------------------------------------------------------------
/segmenter/gcn/segmenter.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import torch
3 | from interface import SegmenterI
4 | from structure import Sentence, EDU, Paragraph, TEXT
5 | from util.berkely import BerkeleyParser
6 | from util.ltp import LTPParser
7 |
8 |
9 | class GCNSegmenter(SegmenterI):
10 | def __init__(self, model):
11 | self._eos = ['!', '。', '?']
12 | self._pairs = {'“': "”", "「": "」"}
13 | self.model = model
14 | self.model.eval()
15 | self.parser = BerkeleyParser()
16 | self.dep_parser = LTPParser()
17 |
18 | def cut(self, text):
19 | sentences = self.cut_sent(text)
20 | for i, sent in enumerate(sentences):
21 | sentences[i] = Sentence(self.cut_edu(sent))
22 | return Paragraph(sentences)
23 |
24 | def cut_sent(self, text, sid=None):
25 | last_cut = 0
26 | sentences = []
27 | for i in range(0, len(text) - 1):
28 | if text[i] in self._eos:
29 | sentences.append(Sentence([TEXT(text[last_cut: i + 1])]))
30 | last_cut = i + 1
31 | if last_cut < len(text) - 1:
32 | sentences.append(Sentence([TEXT(text[last_cut:])]))
33 | return sentences
34 |
35 | def cut_edu(self, sent):
36 | if (not hasattr(sent, "words")) or (not hasattr(sent, "tags")):
37 | if hasattr(sent, "parse"):
38 | parse = getattr(sent, "parse")
39 | else:
40 | parse = self.parser.parse(sent.text)
41 | children = list(parse.subtrees(lambda t: t.height() == 2 and t.label() != '-NONE-'))
42 | setattr(sent, "words", [child[0] for child in children])
43 | setattr(sent, "tags", [child.label() for child in children])
44 |
45 | if not hasattr(sent, "dependency"):
46 | dep = self.dep_parser.parse(sent.words)
47 | setattr(sent, "dependency", dep)
48 |
49 | word_ids = [self.model.word_vocab[word] for word in sent.words]
50 | pos_ids = [self.model.pos_vocab[pos] for pos in sent.tags]
51 | word_ids = torch.tensor([word_ids]).long()
52 | pos_ids = torch.tensor([pos_ids]).long()
53 | graph = torch.zeros((1, word_ids.size(1), word_ids.size(1), len(self.model.gcn_vocab)), dtype=torch.uint8)
54 | for i, token in enumerate(sent.dependency):
55 | graph[0, i, i, self.model.gcn_vocab['self']] = 1
56 | graph[0, i, token.head-1, self.model.gcn_vocab['head']] = 1
57 | graph[0, token.head-1, i, self.model.gcn_vocab['dep']] = 1
58 | if self.model.use_gpu:
59 | word_ids = word_ids.cuda()
60 | pos_ids = pos_ids.cuda()
61 | graph = graph.cuda()
62 | pred = self.model(word_ids, pos_ids, graph).squeeze(0)
63 | labels = [self.model.tag_label.id2label[t] for t in pred.argmax(-1)]
64 |
65 | edus = []
66 | last_edu_words = []
67 | last_edu_tags = []
68 | for word, pos, label in zip(sent.words, sent.tags, labels):
69 | last_edu_words.append(word)
70 | last_edu_tags.append(pos)
71 | if label == "B":
72 | text = "".join(last_edu_words)
73 | edu = EDU([TEXT(text)])
74 | setattr(edu, "words", last_edu_words)
75 | setattr(edu, "tags", last_edu_tags)
76 | edus.append(edu)
77 | last_edu_words = []
78 | last_edu_tags = []
79 | if last_edu_words:
80 | text = "".join(last_edu_words)
81 | edu = EDU([TEXT(text)])
82 | setattr(edu, "words", last_edu_words)
83 | setattr(edu, "tags", last_edu_tags)
84 | edus.append(edu)
85 | return edus
86 |
--------------------------------------------------------------------------------
/segmenter/gcn/test.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import logging
3 | import torch
4 | import tqdm
5 | from itertools import chain
6 | from segmenter.gcn import GCNSegmenter
7 | from dataset import CDTB
8 | from structure import node_type_filter, Sentence, Paragraph, EDU
9 | from util.eval import edu_eval, gen_edu_report
10 | from .train import preprocessing
11 |
12 |
13 | logger = logging.getLogger("test gcn segmenter")
14 |
15 |
16 | if __name__ == '__main__':
17 | logging.basicConfig(level=logging.INFO)
18 | with open("data/models/segmenter.gcn.model", "rb") as model_fd:
19 | model = torch.load(model_fd, map_location='cpu')
20 | model.use_gpu = False
21 | model.eval()
22 | segmenter = GCNSegmenter(model)
23 | cdtb = CDTB("data/CDTB", "TRAIN", "VALIDATE", "TEST", ctb_dir="data/CTB", preprocess=True, cache_dir="data/cache")
24 | preprocessing(cdtb)
25 | ctb = cdtb.ctb
26 |
27 | golds = []
28 | segs = []
29 | for paragraph in tqdm.tqdm(chain(*cdtb.test), desc="segmenting"):
30 | seged_sents = []
31 | for sentence in paragraph.sentences():
32 | # make sure sentence has edus
33 | if list(sentence.iterfind(node_type_filter(EDU))):
34 | setattr(sentence, "parse", ctb[sentence.sid])
35 | seged_sents.append(Sentence(segmenter.cut_edu(sentence)))
36 | if seged_sents:
37 | segs.append(Paragraph(seged_sents))
38 | golds.append(paragraph)
39 | scores = edu_eval(segs, golds)
40 | logger.info(gen_edu_report(scores))
41 |
--------------------------------------------------------------------------------
/segmenter/gcn/train.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import random
3 | from collections import Counter
4 | from structure.vocab import Vocab, Label
5 | from structure.nodes import node_type_filter, EDU, Sentence, Paragraph
6 | from itertools import chain
7 | import numpy as np
8 | from dataset import CDTB
9 | import logging
10 | import tqdm
11 | import argparse
12 | import torch
13 | import torch.optim as optim
14 | from segmenter.gcn.model import GCNSegmenterModel
15 | from segmenter.gcn import GCNSegmenter
16 | from util.eval import edu_eval, gen_edu_report
17 | from util.ltp import LTPParser
18 |
19 | logger = logging.getLogger("train gcn segmenter")
20 |
21 |
22 | def preprocessing(cdtb):
23 | with LTPParser() as parser:
24 | deps = {}
25 | logger.info("add dependency information to sentence")
26 | for paragraph in tqdm.tqdm(chain(*chain(cdtb.train, cdtb.validate, cdtb.test)), desc="parsing", unit=" sentence"):
27 | for sentence in paragraph.sentences():
28 | if sentence.sid not in deps:
29 | dep = parser.parse(sentence.words)
30 | deps[sentence.sid] = dep
31 | else:
32 | dep = deps[sentence.sid]
33 | setattr(sentence, "dependency", dep)
34 |
35 |
36 | def build_vocab(dataset):
37 | word_freq = Counter()
38 | pos_freq = Counter()
39 | for paragraph in chain(*dataset):
40 | for edu in paragraph.edus():
41 | word_freq.update(edu.words)
42 | pos_freq.update(edu.tags)
43 | word_vocab = Vocab("word", word_freq)
44 | pos_vocab = Vocab("part of speech", pos_freq)
45 | gcn_vocab = Vocab("gcn tag", Counter(["dep", "head", "self"]))
46 | return word_vocab, pos_vocab, gcn_vocab
47 |
48 |
49 | def gen_train_instances(dataset):
50 | instances = []
51 | tags = []
52 | for paragraph in chain(*dataset):
53 | for sentence in paragraph.sentences():
54 | edus = list(sentence.iterfind(node_type_filter(EDU)))
55 | if edus:
56 | sent_words = []
57 | sent_poses = []
58 | sent_tags = []
59 | graph = []
60 | for i, edu in enumerate(edus):
61 | words = edu.words
62 | poses = edu.tags
63 | label = ['O'] * (len(words) - 1)
64 | label += ['B'] if i < len(edus) - 1 else ['O']
65 | sent_words.extend(words)
66 | sent_poses.extend(poses)
67 | sent_tags.extend(label)
68 | for i, token in enumerate(sentence.dependency):
69 | graph.append((i, i, "self"))
70 | if token.head > 0:
71 | graph.append((i, token.head-1, "head"))
72 | graph.append((token.head-1, i, "dep"))
73 | instances.append((sent_words, sent_poses, graph))
74 | tags.append(sent_tags)
75 | return instances, tags
76 |
77 |
78 | def numericalize(instances, tags, word_vocab, pos_vocab, gcn_vocab, tag_label):
79 | trainset = []
80 | for (words, poses, graph), tags in zip(instances, tags):
81 | word_ids = [word_vocab[word] for word in words]
82 | pos_ids = [pos_vocab[pos] for pos in poses]
83 | tag_ids = [tag_label[tag] for tag in tags]
84 | graph_ids = [(x, y, gcn_vocab[l]) for x, y, l in graph]
85 | trainset.append((word_ids, pos_ids, graph_ids, tag_ids))
86 | return trainset
87 |
88 |
89 | def gen_batch_iter(trainset, batch_size, num_gcn_label, use_gpu=False):
90 | random_instances = np.random.permutation(trainset)
91 | num_instances = len(trainset)
92 | offset = 0
93 | while offset < num_instances:
94 | batch = random_instances[offset: min(num_instances, offset + batch_size)]
95 | num_batch = batch.shape[0]
96 | lengths = np.zeros(num_batch, dtype=np.int)
97 | for i, (word_ids, pos_ids, graph_ids, tag_ids) in enumerate(batch):
98 | lengths[i] = len(word_ids)
99 | sort_indices = np.argsort(-lengths)
100 | lengths = lengths[sort_indices]
101 | batch = batch[sort_indices]
102 | max_seqlen = lengths.max()
103 | word_inputs = np.zeros([num_batch, max_seqlen], dtype=np.long)
104 | pos_inputs = np.zeros([num_batch, max_seqlen], dtype=np.long)
105 | graph_inputs = np.zeros([num_batch, max_seqlen, max_seqlen, num_gcn_label], np.uint8)
106 | tag_outputs = np.zeros([num_batch, max_seqlen], dtype=np.long)
107 | masks = np.zeros([num_batch, max_seqlen], dtype=np.uint8)
108 | for i, (word_ids, pos_ids, graph_ids, tag_ids) in enumerate(batch):
109 | seqlen = len(word_ids)
110 | word_inputs[i][:seqlen] = word_ids
111 | pos_inputs[i][:seqlen] = pos_ids
112 | tag_outputs[i][:seqlen] = tag_ids
113 | for x, y, z in graph_ids:
114 | graph_inputs[i, x, y, z] = 1
115 | masks[i][:seqlen] = 1
116 | offset = offset + batch_size
117 |
118 | word_inputs = torch.from_numpy(word_inputs).long()
119 | pos_inputs = torch.from_numpy(pos_inputs).long()
120 | tag_outputs = torch.from_numpy(tag_outputs).long()
121 | graph_inputs = torch.from_numpy(graph_inputs).byte()
122 | masks = torch.from_numpy(masks).byte()
123 |
124 | if use_gpu:
125 | word_inputs = word_inputs.cuda()
126 | pos_inputs = pos_inputs.cuda()
127 | tag_outputs = tag_outputs.cuda()
128 | graph_inputs = graph_inputs.cuda()
129 | masks = masks.cuda()
130 | yield (word_inputs, pos_inputs, graph_inputs, masks), tag_outputs
131 |
132 |
133 | def evaluate(dataset, model):
134 | model.eval()
135 | segmenter = GCNSegmenter(model)
136 | golds = []
137 | segs = []
138 | for paragraph in chain(*dataset):
139 | seged_sents = []
140 | for sentence in paragraph.sentences():
141 | # make sure sentence has edus
142 | if list(sentence.iterfind(node_type_filter(EDU))):
143 | seged_sents.append(Sentence(segmenter.cut_edu(sentence)))
144 | if seged_sents:
145 | segs.append(Paragraph(seged_sents))
146 | golds.append(paragraph)
147 | return edu_eval(segs, golds)
148 |
149 |
150 | def get_lr(optimizer):
151 | for param_group in optimizer.param_groups:
152 | return param_group['lr']
153 |
154 |
155 | def main(args):
156 | random.seed(args.seed)
157 | torch.random.manual_seed(args.seed)
158 | np.random.seed(args.seed)
159 |
160 | logger.info("args:" + str(args))
161 | # load dataset
162 | cdtb = CDTB(args.data, "TRAIN", "VALIDATE", "TEST", ctb_dir=args.ctb_dir, preprocess=True, cache_dir=args.cache_dir)
163 | # preprocess and build training data
164 | preprocessing(cdtb)
165 | word_vocab, pos_vocab, gcn_vocab = build_vocab(cdtb.train)
166 | instances, tags = gen_train_instances(cdtb.train)
167 | tag_label = Label("tag", Counter(chain(*tags)))
168 | trainset = numericalize(instances, tags, word_vocab, pos_vocab, gcn_vocab, tag_label)
169 |
170 | # build model
171 | model = GCNSegmenterModel(hidden_size=args.hidden_size, dropout=args.dropout,
172 | rnn_layers=args.rnn_layers, gcn_layers=args.gcn_layers,
173 | word_vocab=word_vocab, pos_vocab=pos_vocab, gcn_vocab=gcn_vocab, tag_label=tag_label,
174 | pos_size=args.pos_size, pretrained=args.pretrained, w2v_freeze=args.w2v_freeze,
175 | use_gpu=args.use_gpu)
176 | if args.use_gpu:
177 | model.cuda()
178 | logger.info(model)
179 |
180 | # train
181 | step = 0
182 | best_model_f1 = 0
183 | wait_count = 0
184 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)
185 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3)
186 | for nepoch in range(1, args.epoch+1):
187 | batch_iter = gen_batch_iter(trainset, args.batch_size, len(gcn_vocab), use_gpu=args.use_gpu)
188 | for nbatch, (inputs, target) in enumerate(batch_iter, start=1):
189 | step += 1
190 | model.train()
191 | optimizer.zero_grad()
192 | loss = model.loss(inputs, target)
193 | loss.backward()
194 | optimizer.step()
195 | if nbatch > 0 and nbatch % args.log_every == 0:
196 | logger.info("step %d, patient %d, lr %f, epoch %d, batch %d, train loss %.4f" %
197 | (step, wait_count, get_lr(optimizer), nepoch, nbatch, loss.item()))
198 |
199 | # model selection
200 | score = evaluate(cdtb.validate, model)
201 | f1 = score[-1]
202 | scheduler.step(f1, nepoch)
203 | logger.info("evaluation score:")
204 | logger.info("\n" + gen_edu_report(score))
205 | if f1 > best_model_f1:
206 | wait_count = 0
207 | best_model_f1 = f1
208 | logger.info("save new best model to %s" % args.model_save)
209 | with open(args.model_save, "wb+") as model_fd:
210 | torch.save(model, model_fd)
211 | logger.info("test on new best model...")
212 | test_score = evaluate(cdtb.test, model)
213 | logger.info("test score:")
214 | logger.info("\n" + gen_edu_report(test_score))
215 | else:
216 | wait_count += 1
217 | if wait_count > args.patient:
218 | logger.info("early stopping...")
219 | break
220 |
221 | with open(args.model_save, "rb") as model_fd:
222 | best_model = torch.load(model_fd)
223 | test_score = evaluate(cdtb.test, best_model)
224 | logger.info("test score on final best model:")
225 | logger.info("\n" + gen_edu_report(test_score))
226 |
227 |
228 | if __name__ == '__main__':
229 | logging.basicConfig(level=logging.INFO)
230 | arg_parser = argparse.ArgumentParser()
231 | # dataset parameters
232 | arg_parser.add_argument("data")
233 | arg_parser.add_argument("--ctb_dir")
234 | arg_parser.add_argument("--cache_dir")
235 | arg_parser.add_argument("--seed", default=21, type=int)
236 | arg_parser.add_argument("-model_save", required=True)
237 |
238 | # model parameter
239 | arg_parser.add_argument("-hidden_size", default=256, type=int)
240 | arg_parser.add_argument("-rnn_layers", default=3, type=int)
241 | arg_parser.add_argument("-gcn_layers", default=2, type=int)
242 | arg_parser.add_argument("-dropout", default=0.33, type=float)
243 | w2v_group = arg_parser.add_mutually_exclusive_group(required=True)
244 | w2v_group.add_argument("-pretrained")
245 | w2v_group.add_argument("-w2v_size", type=int)
246 | arg_parser.add_argument("-pos_size", default=30, type=int)
247 | arg_parser.add_argument("--w2v_freeze", dest="w2v_freeze", action="store_true")
248 | arg_parser.set_defaults(w2v_freeze=False)
249 |
250 | # train parameter
251 | arg_parser.add_argument("-epoch", default=20, type=int)
252 | arg_parser.add_argument("-lr", default=0.001, type=float)
253 | arg_parser.add_argument("-l2", default=1e-6, type=float)
254 | arg_parser.add_argument("-patient", default=4, type=int)
255 | arg_parser.add_argument("-log_every", default=5, type=int)
256 | arg_parser.add_argument("-batch_size", default=64, type=int)
257 | arg_parser.add_argument("--use_gpu", dest="use_gpu", action="store_true")
258 | arg_parser.set_defaults(use_gpu=False)
259 | main(arg_parser.parse_args())
260 |
--------------------------------------------------------------------------------
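
Note: gen_train_instances above stores each dependency arc as a (source, target, label) triple, and gen_batch_iter later densifies those triples into a one-hot adjacency tensor for the GCN layers. A minimal sketch of that densification, assuming made-up head indices (1-based, 0 = root) and a hand-written label-to-id map standing in for the gcn_vocab:

import numpy as np

# toy dependency heads for a three-token sentence; 1-based, 0 marks the root
heads = [2, 0, 2]
label_ids = {"self": 0, "head": 1, "dep": 2}   # stand-in for the gcn_vocab lookup

arcs = []
for i, head in enumerate(heads):
    arcs.append((i, i, "self"))              # every token gets a self loop
    if head > 0:                             # the root attachment is skipped
        arcs.append((i, head - 1, "head"))   # token -> its head
        arcs.append((head - 1, i, "dep"))    # head -> its dependent

# densify into the [seq_len, seq_len, num_labels] one-hot tensor fed to the GCN
seq_len, num_labels = len(heads), len(label_ids)
graph = np.zeros((seq_len, seq_len, num_labels), dtype=np.uint8)
for x, y, label in arcs:
    graph[x, y, label_ids[label]] = 1

print(int(graph.sum()))   # 7 arcs in total: 3 self loops, 2 head arcs, 2 dep arcs
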
/segmenter/rnn/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | from .segmenter import RNNSegmenter
3 |
--------------------------------------------------------------------------------
/segmenter/rnn/model.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
6 |
7 |
8 | class RNNSegmenterModel(nn.Module):
9 | def __init__(self, hidden_size, dropout, rnn_layers, word_vocab, pos_vocab, tag_label,
10 | pos_size=30, pretrained=None, w2v_size=None, w2v_freeze=False,
11 | use_gpu=False):
12 | super(RNNSegmenterModel, self).__init__()
13 | self.use_gpu = use_gpu
14 | self.word_vocab = word_vocab
15 | self.pos_vocab = pos_vocab
16 | self.tag_label = tag_label
17 | self.word_emb = word_vocab.embedding(pretrained=pretrained, dim=w2v_size, freeze=w2v_freeze, use_gpu=use_gpu)
18 | self.w2v_size = self.word_emb.weight.shape[-1]
19 | self.pos_emb = pos_vocab.embedding(dim=pos_size, use_gpu=use_gpu)
20 | self.pos_size = pos_size
21 | self.hidden_size = hidden_size
22 | self.dropout_p = dropout
23 | self.rnn = nn.LSTM(self.w2v_size+self.pos_size, self.hidden_size // 2,
24 | num_layers=rnn_layers, dropout=dropout, bidirectional=True, batch_first=True)
25 | self.tagger = nn.Linear(hidden_size, len(tag_label))
26 |
27 | def forward(self, word_ids, pos_ids, masks=None):
28 | self.rnn.flatten_parameters()
29 | word_emb = self.word_emb(word_ids)
30 | pos_emb = self.pos_emb(pos_ids)
31 | rnn_inputs = torch.cat([word_emb, pos_emb], dim=-1)
32 | if masks is not None:
33 | lengths = masks.sum(-1)
34 | rnn_inputs = rnn_inputs * masks.unsqueeze(-1).float()
35 | rnn_inputs_packed = pack_padded_sequence(rnn_inputs, lengths, batch_first=True)
36 | rnn_outputs_packed, _ = self.rnn(rnn_inputs_packed)
37 | rnn_outputs, _ = pad_packed_sequence(rnn_outputs_packed, batch_first=True)
38 | else:
39 | rnn_outputs, _ = self.rnn(rnn_inputs)
40 | tag_score = self.tagger(rnn_outputs)
41 | return tag_score
42 |
43 | def loss(self, inputs, target):
44 | word_ids, pos_ids, masks = inputs
45 | batch_size, max_seqlen = word_ids.size()
46 | pred = F.log_softmax(self(word_ids, pos_ids, masks), dim=-1)
47 | pred = pred.view(batch_size*max_seqlen, -1)
48 | target = target.view(-1)
49 | masks = masks.view(-1)
50 | losses = F.nll_loss(pred, target, reduction='none')
51 | loss = (losses * masks.float()).sum() / masks.sum().float()
52 | return loss
53 |
--------------------------------------------------------------------------------
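
Note: RNNSegmenterModel.loss above averages the per-token negative log-likelihood over non-padding positions only. The same masked averaging in isolation, with dummy tensors whose sizes are chosen purely for illustration:

import torch
import torch.nn.functional as F

batch_size, max_seqlen, num_tags = 2, 5, 2
scores = torch.randn(batch_size, max_seqlen, num_tags)        # raw tagger scores
target = torch.randint(num_tags, (batch_size, max_seqlen))    # gold B/O tag ids
masks = torch.tensor([[1, 1, 1, 0, 0],
                      [1, 1, 1, 1, 1]], dtype=torch.uint8)     # 1 marks a real token

pred = F.log_softmax(scores, dim=-1).view(batch_size * max_seqlen, -1)
losses = F.nll_loss(pred, target.view(-1), reduction='none')  # one loss per position
flat_masks = masks.view(-1).float()
loss = (losses * flat_masks).sum() / flat_masks.sum()         # mean over real tokens only
print(loss.item())
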
/segmenter/rnn/segmenter.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import torch
3 | from interface import SegmenterI
4 | from structure import Sentence, EDU, Paragraph, TEXT
5 | from util.berkely import BerkeleyParser
6 |
7 |
8 | class RNNSegmenter(SegmenterI):
9 | def __init__(self, model):
10 | self._eos = ['!', '。', '?']
11 | self._pairs = {'“': "”", "「": "」"}
12 | self.model = model
13 | self.model.eval()
14 | self.parser = BerkeleyParser()
15 |
16 | def cut(self, text):
17 | sentences = self.cut_sent(text)
18 | for i, sent in enumerate(sentences):
19 | sentences[i] = Sentence(self.cut_edu(sent))
20 | return Paragraph(sentences)
21 |
22 | def cut_sent(self, text, sid=None):
23 | last_cut = 0
24 | sentences = []
25 | for i in range(0, len(text) - 1):
26 | if text[i] in self._eos:
27 | sentences.append(Sentence([TEXT(text[last_cut: i + 1])]))
28 | last_cut = i + 1
29 | if last_cut < len(text) - 1:
30 | sentences.append(Sentence([TEXT(text[last_cut:])]))
31 | return sentences
32 |
33 | def cut_edu(self, sent):
34 | if (not hasattr(sent, "words")) or (not hasattr(sent, "tags")):
35 | if hasattr(sent, "parse"):
36 | parse = getattr(sent, "parse")
37 | else:
38 | parse = self.parser.parse(sent.text)
39 | children = list(parse.subtrees(lambda t: t.height() == 2 and t.label() != '-NONE-'))
40 | setattr(sent, "words", [child[0] for child in children])
41 | setattr(sent, "tags", [child.label() for child in children])
42 | word_ids = [self.model.word_vocab[word] for word in sent.words]
43 | pos_ids = [self.model.pos_vocab[pos] for pos in sent.tags]
44 | word_ids = torch.tensor([word_ids]).long()
45 | pos_ids = torch.tensor([pos_ids]).long()
46 | if self.model.use_gpu:
47 | word_ids = word_ids.cuda()
48 | pos_ids = pos_ids.cuda()
49 | pred = self.model(word_ids, pos_ids).squeeze(0)
50 | labels = [self.model.tag_label.id2label[t] for t in pred.argmax(-1)]
51 |
52 | edus = []
53 | last_edu_words = []
54 | last_edu_tags = []
55 | for word, pos, label in zip(sent.words, sent.tags, labels):
56 | last_edu_words.append(word)
57 | last_edu_tags.append(pos)
58 | if label == "B":
59 | text = "".join(last_edu_words)
60 | edu = EDU([TEXT(text)])
61 | setattr(edu, "words", last_edu_words)
62 | setattr(edu, "tags", last_edu_tags)
63 | edus.append(edu)
64 | last_edu_words = []
65 | last_edu_tags = []
66 | if last_edu_words:
67 | text = "".join(last_edu_words)
68 | edu = EDU([TEXT(text)])
69 | setattr(edu, "words", last_edu_words)
70 | setattr(edu, "tags", last_edu_tags)
71 | edus.append(edu)
72 | return edus
73 |
--------------------------------------------------------------------------------
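
Note: cut_edu above closes an EDU whenever the tagger emits a 'B' label (the last word of every non-final EDU) and flushes any remaining words as the final EDU. The grouping logic alone, on a made-up word/label sequence:

words = ["他", "来", "了", ",", "我", "走", "了", "。"]
labels = ["O", "O", "O", "B", "O", "O", "O", "O"]

edus, current = [], []
for word, label in zip(words, labels):
    current.append(word)
    if label == "B":            # 'B' marks the last word of a (non-final) EDU
        edus.append("".join(current))
        current = []
if current:                     # whatever remains becomes the final EDU
    edus.append("".join(current))

print(edus)                     # ['他来了,', '我走了。']
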
/segmenter/rnn/test.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import logging
3 | import torch
4 | import tqdm
5 | from itertools import chain
6 | from segmenter.rnn import RNNSegmenter
7 | from dataset import CDTB
8 | from structure import node_type_filter, Sentence, Paragraph, EDU
9 | from util.eval import edu_eval, gen_edu_report
10 |
11 |
12 | logger = logging.getLogger("test rnn segmenter")
13 |
14 |
15 | if __name__ == '__main__':
16 | logging.basicConfig(level=logging.INFO)
17 | with open("data/models/segmenter.rnn.model", "rb") as model_fd:
18 | model = torch.load(model_fd, map_location='cpu')
19 | model.use_gpu = False
20 | model.eval()
21 | segmenter = RNNSegmenter(model)
22 | cdtb = CDTB("data/CDTB", "TRAIN", "VALIDATE", "TEST", ctb_dir="data/CTB", preprocess=True, cache_dir="data/cache")
23 |
24 | golds = []
25 | segs = []
26 | for paragraph in tqdm.tqdm(chain(*cdtb.test), desc="segmenting"):
27 | seged_sents = []
28 | for sentence in paragraph.sentences():
29 | # make sure sentence has edus
30 | if list(sentence.iterfind(node_type_filter(EDU))):
31 | seged_sents.append(Sentence(segmenter.cut_edu(sentence)))
32 | if seged_sents:
33 | segs.append(Paragraph(seged_sents))
34 | golds.append(paragraph)
35 | scores = edu_eval(segs, golds)
36 | logger.info(gen_edu_report(scores))
37 |
--------------------------------------------------------------------------------
/segmenter/rnn/train.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import random
3 | from collections import Counter
4 | from structure.vocab import Vocab, Label
5 | from structure.nodes import node_type_filter, EDU, Sentence, Paragraph
6 | from itertools import chain
7 | import numpy as np
8 | from dataset import CDTB
9 | import logging
10 | import argparse
11 | import torch
12 | import torch.optim as optim
13 | from segmenter.rnn.model import RNNSegmenterModel
14 | from segmenter.rnn import RNNSegmenter
15 | from util.eval import edu_eval, gen_edu_report
16 |
17 | logger = logging.getLogger("train rnn segmenter")
18 |
19 |
20 | def build_vocab(dataset):
21 | word_freq = Counter()
22 | pos_freq = Counter()
23 | for paragraph in chain(*dataset):
24 | for edu in paragraph.edus():
25 | word_freq.update(edu.words)
26 | pos_freq.update(edu.tags)
27 | word_vocab = Vocab("word", word_freq)
28 | pos_vocab = Vocab("part of speech", pos_freq)
29 | return word_vocab, pos_vocab
30 |
31 |
32 | def gen_train_instances(dataset):
33 | instances = []
34 | tags = []
35 | for paragraph in chain(*dataset):
36 | for sentence in paragraph.sentences():
37 | edus = list(sentence.iterfind(node_type_filter(EDU)))
38 | if edus:
39 | sent_words = []
40 | sent_poses = []
41 | sent_tags = []
42 | for i, edu in enumerate(edus):
43 | words = edu.words
44 | poses = edu.tags
45 | label = ['O'] * (len(words) - 1)
46 | label += ['B'] if i < len(edus) - 1 else ['O']
47 | sent_words.extend(words)
48 | sent_poses.extend(poses)
49 | sent_tags.extend(label)
50 | instances.append((sent_words, sent_poses))
51 | tags.append(sent_tags)
52 | return instances, tags
53 |
54 |
55 | def numericalize(instances, tags, word_vocab, pos_vocab, tag_label):
56 | trainset = []
57 | for (words, poses), tags in zip(instances, tags):
58 | word_ids = [word_vocab[word] for word in words]
59 | pos_ids = [pos_vocab[pos] for pos in poses]
60 | tag_ids = [tag_label[tag] for tag in tags]
61 | trainset.append((word_ids, pos_ids, tag_ids))
62 | return trainset
63 |
64 |
65 | def gen_batch_iter(trainset, batch_size, use_gpu=False):
66 | random_instances = np.random.permutation(trainset)
67 | num_instances = len(trainset)
68 | offset = 0
69 | while offset < num_instances:
70 | batch = random_instances[offset: min(num_instances, offset + batch_size)]
71 | num_batch = batch.shape[0]
72 | lengths = np.zeros(num_batch, dtype=np.int)
73 | for i, (word_ids, pos_ids, tag_ids) in enumerate(batch):
74 | lengths[i] = len(word_ids)
75 | sort_indices = np.argsort(-lengths)
76 | lengths = lengths[sort_indices]
77 | batch = batch[sort_indices]
78 | max_seqlen = lengths.max()
79 | word_inputs = np.zeros([num_batch, max_seqlen], dtype=np.long)
80 | pos_inputs = np.zeros([num_batch, max_seqlen], dtype=np.long)
81 | tag_outputs = np.zeros([num_batch, max_seqlen], dtype=np.long)
82 | masks = np.zeros([num_batch, max_seqlen], dtype=np.uint8)
83 | for i, (word_ids, pos_ids, tag_ids) in enumerate(batch):
84 | seqlen = len(word_ids)
85 | word_inputs[i][:seqlen] = word_ids
86 | pos_inputs[i][:seqlen] = pos_ids
87 | tag_outputs[i][:seqlen] = tag_ids
88 | masks[i][:seqlen] = 1
89 | offset = offset + batch_size
90 |
91 | word_inputs = torch.from_numpy(word_inputs).long()
92 | pos_inputs = torch.from_numpy(pos_inputs).long()
93 | tag_outputs = torch.from_numpy(tag_outputs).long()
94 | masks = torch.from_numpy(masks).byte()
95 |
96 | if use_gpu:
97 | word_inputs = word_inputs.cuda()
98 | pos_inputs = pos_inputs.cuda()
99 | tag_outputs = tag_outputs.cuda()
100 | masks = masks.cuda()
101 | yield (word_inputs, pos_inputs, masks), tag_outputs
102 |
103 |
104 | def evaluate(dataset, model):
105 | model.eval()
106 | segmenter = RNNSegmenter(model)
107 | golds = []
108 | segs = []
109 | for paragraph in chain(*dataset):
110 | seged_sents = []
111 | for sentence in paragraph.sentences():
112 | # make sure sentence has edus
113 | if list(sentence.iterfind(node_type_filter(EDU))):
114 | seged_sents.append(Sentence(segmenter.cut_edu(sentence)))
115 | if seged_sents:
116 | segs.append(Paragraph(seged_sents))
117 | golds.append(paragraph)
118 | return edu_eval(segs, golds)
119 |
120 |
121 | def get_lr(optimizer):
122 | for param_group in optimizer.param_groups:
123 | return param_group['lr']
124 |
125 |
126 | def main(args):
127 | random.seed(args.seed)
128 | torch.random.manual_seed(args.seed)
129 | np.random.seed(args.seed)
130 |
131 | logger.info("args:" + str(args))
132 | # load dataset
133 | cdtb = CDTB(args.data, "TRAIN", "VALIDATE", "TEST", ctb_dir=args.ctb_dir, preprocess=True, cache_dir=args.cache_dir)
134 | word_vocab, pos_vocab = build_vocab(cdtb.train)
135 | instances, tags = gen_train_instances(cdtb.train)
136 | tag_label = Label("tag", Counter(chain(*tags)))
137 | trainset = numericalize(instances, tags, word_vocab, pos_vocab, tag_label)
138 |
139 | # build model
140 | model = RNNSegmenterModel(hidden_size=args.hidden_size, dropout=args.dropout, rnn_layers=args.rnn_layers,
141 | word_vocab=word_vocab, pos_vocab=pos_vocab, tag_label=tag_label,
142 | pos_size=args.pos_size, pretrained=args.pretrained, w2v_size=args.w2v_size, w2v_freeze=args.w2v_freeze,
143 | use_gpu=args.use_gpu)
144 | if args.use_gpu:
145 | model.cuda()
146 | logger.info(model)
147 |
148 | # train
149 | step = 0
150 | best_model_f1 = 0
151 | wait_count = 0
152 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)
153 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=3)
154 | for nepoch in range(1, args.epoch+1):
155 | batch_iter = gen_batch_iter(trainset, args.batch_size, use_gpu=args.use_gpu)
156 | for nbatch, (inputs, target) in enumerate(batch_iter, start=1):
157 | step += 1
158 | model.train()
159 | optimizer.zero_grad()
160 | loss = model.loss(inputs, target)
161 | loss.backward()
162 | optimizer.step()
163 | if nbatch > 0 and nbatch % args.log_every == 0:
164 | logger.info("step %d, patient %d, lr %f, epoch %d, batch %d, train loss %.4f" %
165 | (step, wait_count, get_lr(optimizer), nepoch, nbatch, loss.item()))
166 | # model selection
167 | score = evaluate(cdtb.validate, model)
168 | f1 = score[-1]
169 | scheduler.step(f1, nepoch)
170 | logger.info("evaluation score:")
171 | logger.info("\n" + gen_edu_report(score))
172 | if f1 > best_model_f1:
173 | wait_count = 0
174 | best_model_f1 = f1
175 | logger.info("save new best model to %s" % args.model_save)
176 | with open(args.model_save, "wb+") as model_fd:
177 | torch.save(model, model_fd)
178 | logger.info("test on new best model...")
179 | test_score = evaluate(cdtb.test, model)
180 | logger.info("test score:")
181 | logger.info("\n" + gen_edu_report(test_score))
182 | else:
183 | wait_count += 1
184 | if wait_count > args.patient:
185 | logger.info("early stopping...")
186 | break
187 |
188 | with open(args.model_save, "rb") as model_fd:
189 | best_model = torch.load(model_fd)
190 | test_score = evaluate(cdtb.test, best_model)
191 | logger.info("test score on final best model:")
192 | logger.info("\n" + gen_edu_report(test_score))
193 |
194 |
195 | if __name__ == '__main__':
196 | logging.basicConfig(level=logging.INFO)
197 | arg_parser = argparse.ArgumentParser()
198 | # dataset parameters
199 | arg_parser.add_argument("data")
200 | arg_parser.add_argument("--ctb_dir")
201 | arg_parser.add_argument("--cache_dir")
202 | arg_parser.add_argument("--seed", default=21, type=int)
203 | arg_parser.add_argument("-model_save", required=True)
204 |
205 | # model parameter
206 | arg_parser.add_argument("-hidden_size", default=256, type=int)
207 | arg_parser.add_argument("-rnn_layers", default=3, type=int)
208 | arg_parser.add_argument("-dropout", default=0.33, type=float)
209 | w2v_group = arg_parser.add_mutually_exclusive_group(required=True)
210 | w2v_group.add_argument("-pretrained")
211 | w2v_group.add_argument("-w2v_size", type=int)
212 | arg_parser.add_argument("-pos_size", default=30, type=int)
213 | arg_parser.add_argument("--w2v_freeze", dest="w2v_freeze", action="store_true")
214 | arg_parser.set_defaults(w2v_freeze=False)
215 |
216 | # train parameter
217 | arg_parser.add_argument("-epoch", default=20, type=int)
218 | arg_parser.add_argument("-lr", default=0.001, type=float)
219 | arg_parser.add_argument("-l2", default=1e-6, type=float)
220 | arg_parser.add_argument("-patient", default=4, type=int)
221 | arg_parser.add_argument("-log_every", default=5, type=int)
222 | arg_parser.add_argument("-batch_size", default=64, type=int)
223 | arg_parser.add_argument("--use_gpu", dest="use_gpu", action="store_true")
224 | arg_parser.set_defaults(use_gpu=False)
225 | main(arg_parser.parse_args())
226 |
--------------------------------------------------------------------------------
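
Note: gen_batch_iter above sorts every batch by descending length before padding, which is what pack_padded_sequence in the model expects. The sort-and-pad step alone, on toy word-id sequences:

import numpy as np
import torch

batch = [[3, 1, 4], [2, 7], [5, 9, 2, 6]]            # toy word-id sequences
lengths = np.array([len(seq) for seq in batch])

order = np.argsort(-lengths)                          # longest sequence first
lengths = lengths[order]
batch = [batch[i] for i in order]

max_seqlen = int(lengths.max())
word_inputs = np.zeros((len(batch), max_seqlen), dtype=np.int64)
masks = np.zeros((len(batch), max_seqlen), dtype=np.uint8)
for i, seq in enumerate(batch):
    word_inputs[i, :len(seq)] = seq                   # left-aligned, zero padded
    masks[i, :len(seq)] = 1

word_inputs = torch.from_numpy(word_inputs).long()
masks = torch.from_numpy(masks).byte()
print(word_inputs)
print(masks.sum(-1))                                  # recovered lengths: tensor([4, 3, 2])
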
/segmenter/svm/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | from .segmenter import SVMSegmenter
4 |
--------------------------------------------------------------------------------
/segmenter/svm/model.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import logging
4 | from collections import OrderedDict
5 | from nltk import ParentedTree
6 | from sklearn.feature_extraction import DictVectorizer
7 | from sklearn.svm import LinearSVC
8 |
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | class SVMCommaClassifier:
14 | def __init__(self, connectives, candidate=",,;", seed=21):
15 | self.connectives = connectives
16 | self.candidate = candidate
17 | self.fet_vector = DictVectorizer()
18 | self.clf = LinearSVC(random_state=seed)
19 |
20 | def predict(self, comma_pos, parse):
21 | fet = self.extract_features(comma_pos, parse)
22 | x = self.fet_vector.transform([fet])
23 | return self.clf.predict(x)[0]
24 |
25 | def predict_many(self, x):
26 | fets = []
27 | for comma_pos, parse in x:
28 | fets.append(self.extract_features(comma_pos, parse))
29 | x = self.fet_vector.transform(fets)
30 | return self.clf.predict(x)
31 |
32 | def extract_features(self, comma_pos, parse):
33 | childs = list(parse.subtrees(lambda t: t.height() == 2 and t.label() != '-NONE-'))
34 | offset = 0
35 | comma = None
36 | comma_index = -1
37 | for i, child in enumerate(childs):
38 | if offset == comma_pos:
39 | comma = child
40 | comma_index = i
41 | offset += len(child[0])
42 |
43 | if comma is None:
44 | return {}
45 |
46 | comma_prev = []
47 | comma_post = []
48 | if comma_index > 0:
49 | for child in childs[comma_index-1::-1]:
50 | if child[0] == ',' or child[0] == ',':
51 | break
52 | else:
53 | comma_prev.append(child)
54 | comma_prev = comma_prev[::-1]
55 | for child in childs[comma_index+1:]:
56 | if child[0] == ',' or child[0] == ',':
57 | break
58 | else:
59 | comma_post.append(child)
60 |
61 | # extract feature
62 | fet = OrderedDict()
63 | for i, prev in enumerate(comma_prev[:3]):
64 | fet['F1_P_%d' % (i+1)] = prev.label()
65 | fet['F1_W_%d' % (i+1)] = prev[0]
66 | for i, prev in enumerate(comma_prev[-3:]):
67 | fet['F2_P_%d' % (i+1)] = prev.label()
68 | fet['F2_W_%d' % (i+1)] = prev[0]
69 |
70 | if comma_post:
71 | fet['F3'] = comma_post[0].label()
72 | fet['F4'] = comma_post[0][0]
73 |
74 | for node in comma_prev:
75 | if node[0] in self.connectives:
76 | fet['F5_1'] = node[0]
77 | for node in comma_post:
78 | if node[0] in self.connectives:
79 | fet['F5_2'] = node[0]
80 |
81 | lsibling = comma.left_sibling()
82 | rsibling = comma.right_sibling()
83 | while isinstance(lsibling, ParentedTree) and lsibling.label() == '-NONE-':
84 | lsibling = lsibling.left_sibling()
85 | while isinstance(rsibling, ParentedTree) and rsibling.label() == '-NONE-':
86 | rsibling = rsibling.right_sibling()
87 |
88 | if lsibling:
89 | fet['F6'] = lsibling.label()
90 | if rsibling:
91 | fet['F7'] = rsibling.label()
92 | if lsibling and rsibling:
93 | fet['F8'] = '%s_%s' % (fet['F6'], fet['F7'])
94 | fet['F9'] = '%s_%s_%s' % (fet['F6'], comma.parent().label(), fet['F7'])
95 |
96 | for node in comma_prev:
97 | if node.label().startswith('VC'):
98 | fet['F10_1'] = 'True'
99 | if node.label().startswith('VA'):
100 | fet['F10_2'] = 'True'
101 | if node.label().startswith('VE'):
102 | fet['F10_3'] = 'True'
103 | if node.label().startswith('VV'):
104 | fet['F10_4'] = 'True'
105 | if node.label().startswith('CS'):
106 | fet['F10_5'] = 'True'
107 | for node in comma_post:
108 | if node.label().startswith('VC'):
109 | fet['F11_1'] = 'True'
110 | if node.label().startswith('VA'):
111 | fet['F11_2'] = 'True'
112 | if node.label().startswith('VE'):
113 | fet['F11_3'] = 'True'
114 | if node.label().startswith('VV'):
115 | fet['F11_4'] = 'True'
116 | if node.label().startswith('CS'):
117 | fet['F11_5'] = 'True'
118 |
119 | pcomma = comma.parent()
120 | if 'F9' in fet and fet['F9'] == 'IP_IP_IP':
121 | fet['F12'] = 'True'
122 | if parse.height() - pcomma.height() == 1:
123 | fet['F13'] = 'True'
124 | if 'F12' in fet and fet['F12'] and 'F13' in fet and fet['F13']:
125 | fet['F14'] = 'True'
126 |
127 | punct = []
128 | for child in childs:
129 | if child[0] in ',.?!,。?!':
130 | punct.append(child[0])
131 | fet['F15'] = '_'.join(punct)
132 |
133 | pre_len = len(''.join([node[0] for node in comma_prev]))
134 | post_len = len(''.join(node[0] for node in comma_post))
135 | if pre_len < 5:
136 | fet['F16'] = 'True'
137 | if abs(pre_len - post_len) > 7:
138 | fet['F17'] = 'True'
139 |
140 | comma_dept = 0
141 | tmp_node = comma
142 | while tmp_node.parent() and tmp_node.parent() is not parse:
143 | comma_dept += 1
144 | tmp_node = tmp_node.parent()
145 | fet['F18'] = comma_dept
146 | del tmp_node
147 |
148 | if pcomma and pcomma.label().startswith('NP'):
149 | fet['F19'] = 'True'
150 | if isinstance(lsibling, ParentedTree) and lsibling.label().startswith('NP'):
151 | fet['F20'] = 'True'
152 | if isinstance(rsibling, ParentedTree) and rsibling.label().startswith('NP'):
153 | fet['F21'] = 'True'
154 |
155 | if len(comma_prev) >= 2:
156 | fet['F22'] = comma_prev[0].label() + '_' + comma_prev[-1].label()
157 | fet['F23'] = comma_prev[0][0] + '_' + comma_prev[-1][0]
158 |
159 | comma_prev_set = set([(node.label(), node[0]) for node in comma_prev if node.label() != 'PU'])
160 | comma_post_set = set([(node.label(), node[0]) for node in comma_post if node.label() != 'PU'])
161 | if comma_prev_set & comma_post_set:
162 | fet['F24'] = list(comma_prev_set & comma_post_set)[0][0]
163 | return fet
164 |
--------------------------------------------------------------------------------
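
Note: SVMCommaClassifier above vectorizes the sparse feature dicts produced by extract_features with a DictVectorizer and feeds them to a LinearSVC. A minimal end-to-end sketch with made-up feature dicts (the keys only mimic the F* feature names used above):

from sklearn.feature_extraction import DictVectorizer
from sklearn.svm import LinearSVC

train_feats = [{"F3": "VP", "F6": "IP", "F16": "True"},
               {"F3": "NP", "F6": "NP"},
               {"F3": "VP", "F7": "IP"}]
train_labels = [1, 0, 1]                     # 1 = the comma is an EDU boundary

vec = DictVectorizer()
clf = LinearSVC(random_state=21)
clf.fit(vec.fit_transform(train_feats), train_labels)   # fit vocabulary and classifier

test_feat = {"F3": "VP", "F6": "IP"}
print(clf.predict(vec.transform([test_feat]))[0])        # predicted boundary label
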
/segmenter/svm/segmenter.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | from typing import List
3 | from interface import SegmenterI
4 | from nltk import ParentedTree
5 | from structure.nodes import EDU, TEXT, Sentence, Paragraph
6 | from util.berkely import BerkeleyParser
7 |
8 |
9 | class SVMSegmenter(SegmenterI):
10 | def __init__(self, model):
11 | self._eos = ['!', '。', '?']
12 | self._pairs = {'“': "”", "「": "」"}
13 | self.model = model
14 | self.candidate = model.candidate
15 | self.parser = BerkeleyParser()
16 |
17 | def cut(self, text):
18 | sentences = self.cut_sent(text)
19 | for i, sent in enumerate(sentences):
20 | sentences[i] = Sentence(self.cut_edu(sent))
21 | return Paragraph(sentences)
22 |
23 | def cut_sent(self, text: str, sid=None) -> List[Sentence]:
24 | last_cut = 0
25 | sentences = []
26 | for i in range(0, len(text)-1):
27 | if text[i] in self._eos:
28 | sentences.append(Sentence([TEXT(text[last_cut: i+1])]))
29 | last_cut = i + 1
30 | if last_cut < len(text)-1:
31 | sentences.append(Sentence([TEXT(text[last_cut:])]))
32 | return sentences
33 |
34 | def cut_edu(self, sent: Sentence) -> List[EDU]:
35 | if not hasattr(sent, "parse"):
36 | print(sent.text)
37 | parse = self.parser.parse(sent.text)
38 | else:
39 | parse = getattr(sent, "parse")
40 | parse = ParentedTree.fromstring(parse.pformat())
41 | children = list(parse.subtrees(lambda t: t.height() == 2 and t.label() != '-NONE-'))
42 | edus = []
43 | last_edu_words = []
44 | last_edu_tags = []
45 | offset = 0
46 | for child in children:
47 | if child[0] == '-LRB-':
48 | child[0] = '('
49 | if child[0] == '-RRB-':
50 | child[0] = ')'
51 | last_edu_words.append(child[0])
52 | last_edu_tags.append(child.label())
53 | if child[0] in self._eos or (child[0] in self.candidate and self.model.predict(offset, parse)):
54 | text = "".join(last_edu_words)
55 | edu = EDU([TEXT(text)])
56 | setattr(edu, "words", last_edu_words)
57 | setattr(edu, "tags", last_edu_tags)
58 | edus.append(edu)
59 | last_edu_words = []
60 | last_edu_tags = []
61 | offset += len(child[0])
62 | if last_edu_words:
63 | text = "".join(last_edu_words)
64 | edu = EDU([TEXT(text)])
65 | setattr(edu, "words", last_edu_words)
66 | setattr(edu, "tags", last_edu_tags)
67 | edus.append(edu)
68 | return edus
69 |
--------------------------------------------------------------------------------
/segmenter/svm/test.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import logging
3 | import pickle
4 | import tqdm
5 | from itertools import chain
6 | from segmenter.svm import SVMSegmenter
7 | from dataset import CDTB
8 | from structure import node_type_filter, Sentence, Paragraph, EDU
9 | from util.eval import edu_eval, gen_edu_report
10 |
11 |
12 | logger = logging.getLogger("test svm segmenter")
13 |
14 |
15 | if __name__ == '__main__':
16 | logging.basicConfig(level=logging.INFO)
17 | with open("data/models/segmenter.svm.model", "rb") as model_fd:
18 | model = pickle.load(model_fd)
19 | segmenter = SVMSegmenter(model)
20 | cdtb = CDTB("data/CDTB", "TRAIN", "VALIDATE", "TEST", ctb_dir="data/CTB", preprocess=True, cache_dir="data/cache")
21 | ctb = cdtb.ctb
22 |
23 | golds = []
24 | segs = []
25 | for paragraph in tqdm.tqdm(chain(*cdtb.test), desc="segmenting"):
26 | seged_sents = []
27 | for sentence in paragraph.sentences():
28 | # make sure sentence has edus
29 | if list(sentence.iterfind(node_type_filter(EDU))):
30 | setattr(sentence, "parse", ctb[sentence.sid])
31 | seged_sents.append(Sentence(segmenter.cut_edu(sentence)))
32 | if seged_sents:
33 | segs.append(Paragraph(seged_sents))
34 | golds.append(paragraph)
35 | scores = edu_eval(segs, golds)
36 | logger.info(gen_edu_report(scores))
37 |
--------------------------------------------------------------------------------
/segmenter/svm/train.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import argparse
3 | import logging
4 | import random
5 | import re
6 | import pickle
7 | from itertools import chain
8 | import numpy as np
9 | from nltk.tree import ParentedTree
10 | from tqdm import tqdm
11 |
12 | from dataset import CDTB
13 | from segmenter.svm.model import SVMCommaClassifier
14 | from structure.nodes import node_type_filter, EDU, Sentence
15 | from sklearn.metrics import classification_report
16 | from util.berkely import BerkeleyParser
17 |
18 | parser = BerkeleyParser()
19 | logger = logging.getLogger("train svm segmenter")
20 |
21 |
22 | def gen_instances(dataset, parses, model):
23 | instances = []
24 | labels = []
25 | candidate_re = re.compile("[%s]" % model.candidate)
26 | for paragraph in tqdm(chain(*dataset)):
27 | root = paragraph.root_relation()
28 | if root:
29 | sentences = list(root.iterfind(filter=node_type_filter(Sentence)))
30 | # character offsets on both sides of each gold split point
31 | for sentence in sentences:
32 | segments = set()  # offsets on both sides of the gold split points
33 | candidates = set()  # offsets of candidate segmentation tokens
34 | edus = list(sentence.iterfind(filter=node_type_filter(EDU)))
35 | offset = 0
36 | for edu in edus:
37 | segments.add(offset)
38 | segments.add(offset+len(edu.text)-1)
39 | offset += len(edu.text)
40 | # convert tree in parented tree for feature extraction
41 | parse = ParentedTree.fromstring(parser.parse(sentence.text).pformat())
42 | for m in candidate_re.finditer(sentence.text):
43 | candidate = m.start()
44 | instances.append(model.extract_features(candidate, parse))
45 | labels.append(1 if candidate in segments else 0)
46 | return instances, labels
47 |
48 |
49 | def main(args):
50 | random.seed(args.seed)
51 | np.random.seed(args.seed)
52 |
53 | # load dataset
54 | cdtb = CDTB(args.data, "TRAIN", "VALIDATE", "TEST", ctb_dir=args.ctb_dir, preprocess=True, cache_dir=args.cache_dir)
55 | # load connectives
56 | with open(args.connectives, "r", encoding="UTF-8") as connective_fd:
57 | connectives = connective_fd.read().split()
58 | # build model
59 | model = SVMCommaClassifier(connectives, seed=args.seed)
60 | # gen trainning instances
61 | feats, labels = gen_instances(cdtb.train, cdtb.ctb, model)
62 |
63 | # train model
64 | vect = model.fet_vector.fit_transform(feats)
65 | model.clf.fit(vect, labels)
66 | # validate
67 | feats_eval, labels_eval = gen_instances(cdtb.validate, cdtb.ctb, model)
68 | vect_eval = model.fet_vector.transform(feats_eval)
69 | pred_eval = model.clf.predict(vect_eval)
70 | logger.info("validate score:")
71 | logger.info("\n" + classification_report(labels_eval, pred_eval))
72 | # test
73 | feats_test, labels_test = gen_instances(cdtb.test, cdtb.ctb, model)
74 | vect_test = model.fet_vector.transform(feats_test)
75 | pred_test = model.clf.predict(vect_test)
76 | logger.info("test score:")
77 | logger.info("\n" + classification_report(labels_test, pred_test))
78 |
79 | # save
80 | logger.info("save model to %s" % args.model_save)
81 | with open(args.model_save, "wb+") as model_fd:
82 | pickle.dump(model, model_fd)
83 |
84 |
85 | if __name__ == '__main__':
86 | logging.basicConfig(level=logging.INFO)
87 | arg_parser = argparse.ArgumentParser()
88 | # dataset parameters
89 | arg_parser.add_argument("data")
90 | arg_parser.add_argument("--ctb_dir")
91 | arg_parser.add_argument("--cache_dir")
92 | arg_parser.add_argument("-connectives", required=True)
93 | arg_parser.add_argument("--seed", default=21, type=int)
94 | arg_parser.add_argument("-model_save", required=True)
95 |
96 | main(arg_parser.parse_args())
97 |
--------------------------------------------------------------------------------
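
Note: gen_instances above labels a candidate punctuation position as positive when its character offset coincides with a gold EDU boundary offset. The offset bookkeeping in isolation, on a toy two-EDU sentence:

import re

edus = ["他来了,", "我走了。"]               # toy gold EDUs
sentence = "".join(edus)

segments = set()                             # first and last character offset of each EDU
offset = 0
for edu in edus:
    segments.add(offset)
    segments.add(offset + len(edu) - 1)
    offset += len(edu)

candidate_re = re.compile("[,,;]")          # same default candidate set as the classifier
labels = {}
for m in candidate_re.finditer(sentence):
    labels[m.start()] = 1 if m.start() in segments else 0

print(labels)                                # {3: 1} -> the comma at offset 3 closes an EDU
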
/structure/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | from .nodes import *
4 |
--------------------------------------------------------------------------------
/structure/vocab.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import logging
4 | from collections import defaultdict
5 | from gensim.models import KeyedVectors
6 | import numpy as np
7 | import torch
8 |
9 |
10 | logger = logging.getLogger(__name__)
11 | PAD_TAG = "<pad>"
12 | UNK_TAG = "<unk>"
13 |
14 |
15 | class Vocab:
16 | def __init__(self, name, counter, min_occur=1):
17 | self.name = name
18 | self.s2id = defaultdict(int)
19 | self.s2id[UNK_TAG] = 0
20 | self.id2s = [UNK_TAG]
21 |
22 | self.counter = counter
23 | for tag, freq in counter.items():
24 | if tag not in self.s2id and freq >= min_occur:
25 | self.s2id[tag] = len(self.s2id)
26 | self.id2s.append(tag)
27 | logger.info("%s vocabulary size %d" % (self.name, len(self.s2id)))
28 |
29 | def __getitem__(self, item):
30 | return self.s2id[item]
31 |
32 | def embedding(self, dim=None, pretrained=None, binary=False, freeze=False, use_gpu=False):
33 | if dim is not None and pretrained is not None:
34 | raise Warning("dim should not be given if pretrained weights are assigned")
35 |
36 | if dim is None and pretrained is None:
37 | raise Warning("one of dim or pretrained should be assigned")
38 |
39 | if pretrained:
40 | w2v = KeyedVectors.load_word2vec_format(pretrained, binary=binary)
41 | dim = w2v.vector_size
42 | scale = np.sqrt(3.0 / dim)
43 | weights = np.empty([len(self), dim], dtype=np.float32)
44 | oov_count = 0
45 | all_count = 0
46 | for tag, i in self.s2id.items():
47 | if tag in w2v.vocab:
48 | weights[i] = w2v[tag].astype(np.float32)
49 | else:
50 | oov_count += self.counter[tag]
51 | weights[i] = np.zeros(dim).astype(np.float32) if freeze else \
52 | np.random.uniform(-scale, scale, dim).astype(np.float32)
53 | all_count += self.counter[tag]
54 | logger.info("%s vocabulary pretrained OOV %d/%d, %.2f%%" %
55 | (self.name, oov_count, all_count, oov_count/all_count*100))
56 | else:
57 | scale = np.sqrt(3.0 / dim)
58 | weights = np.random.uniform(-scale, scale, [len(self), dim]).astype(np.float32)
59 | weights[0] = np.zeros(dim).astype(np.float32) if freeze else \
60 | np.random.uniform(-scale, scale, dim).astype(np.float32)
61 | weights = torch.from_numpy(weights)
62 | if use_gpu:
63 | weights = weights.cuda()
64 | embedding = torch.nn.Embedding.from_pretrained(weights, freeze=freeze)
65 | return embedding
66 |
67 | def __len__(self):
68 | return len(self.id2s)
69 |
70 |
71 | class Label:
72 | def __init__(self, name, counter, specials=None):
73 | self.name = name
74 | self.counter = counter.copy()
75 | self.label2id = {}
76 | self.id2label = []
77 | if specials:
78 | for label in specials:
79 | del self.counter[label]
80 | self.label2id[label] = len(self.label2id)
81 | self.id2label.append(label)
82 |
83 | for label, freq in self.counter.items():
84 | if label not in self.label2id:
85 | self.label2id[label] = len(self.label2id)
86 | self.id2label.append(label)
87 | logger.info("label %s size %d" % (name, len(self)))
88 |
89 | def __getitem__(self, item):
90 | return self.label2id[item]
91 |
92 | def __len__(self):
93 | return len(self.id2label)
94 |
--------------------------------------------------------------------------------
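
Note: Vocab and Label above are thin id lookups built from Counters; index 0 of a Vocab is reserved for unknown tokens, and embedding() returns an nn.Embedding over the vocabulary. A short usage sketch, assuming the repository root is on PYTHONPATH and no pretrained vectors are used:

from collections import Counter
from structure.vocab import Vocab, Label

word_vocab = Vocab("word", Counter(["他", "来", "了", "他"]))
tag_label = Label("tag", Counter(["O", "O", "B"]))

print(len(word_vocab))             # 4: the unknown token plus three observed words
print(word_vocab["没见过的词"])     # an unseen word falls back to id 0
print(len(tag_label), tag_label["B"])

emb = word_vocab.embedding(dim=8)  # randomly initialized [len(word_vocab), 8] embedding
print(emb.weight.shape)
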
/treebuilder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NLP-Discourse-SoochowU/t2d_discourseparser/ce552908b1907cf8b59db11802811a6468c9bfc9/treebuilder/__init__.py
--------------------------------------------------------------------------------
/treebuilder/partptr/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | from .parser import PartPtrParser
4 |
--------------------------------------------------------------------------------
/treebuilder/partptr/model.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
5 | import torch.nn.functional as F
6 |
7 |
8 | class MaskedGRU(nn.Module):
9 | def __init__(self, *args, **kwargs):
10 | super(MaskedGRU, self).__init__()
11 | self.rnn = nn.GRU(batch_first=True, *args, **kwargs)
12 | self.hidden_size = self.rnn.hidden_size
13 |
14 | def forward(self, padded, lengths, initial_state=None):
15 | # [batch*edu]
16 | zero_mask = lengths != 0
17 | lengths[lengths == 0] += 1 # in case zero length instance
18 | _, indices = lengths.sort(descending=True)
19 | _, rev_indices = indices.sort()
20 |
21 | # [batch*edu, max_word_seqlen, embedding]
22 | padded_sorted = padded[indices]
23 | lengths_sorted = lengths[indices]
24 | padded_packed = pack_padded_sequence(padded_sorted, lengths_sorted, batch_first=True)
25 | self.rnn.flatten_parameters()
26 | outputs_sorted_packed, hidden_sorted = self.rnn(padded_packed, initial_state)
27 | # [batch*edu, max_word_seqlen, output_size]
28 | outputs_sorted, _ = pad_packed_sequence(outputs_sorted_packed, batch_first=True)
29 | # [batch*edu, max_word_seqlen, output_size]
30 | outputs = outputs_sorted[rev_indices]
31 | # [batch*edu, output_size]
32 | hidden = hidden_sorted.transpose(1, 0).contiguous().view(outputs.size(0), -1)[rev_indices]
33 |
34 | outputs = outputs * zero_mask.view(-1, 1, 1).float()
35 | hidden = hidden * zero_mask.view(-1, 1).float()
36 | return outputs, hidden
37 |
38 |
39 | class BiGRUEDUEncoder(nn.Module):
40 | def __init__(self, input_size, hidden_size):
41 | super(BiGRUEDUEncoder, self).__init__()
42 | self.hidden_size = hidden_size
43 | self.input_size = input_size
44 | self.rnn = MaskedGRU(input_size, hidden_size//2, bidirectional=True)
45 | self.token_scorer = nn.Linear(hidden_size, 1)
46 | self.output_size = hidden_size
47 |
48 | def forward(self, inputs, masks):
49 | lengths = masks.sum(-1)
50 | outputs, hidden = self.rnn(inputs, lengths)
51 | token_score = self.token_scorer(outputs).squeeze(-1)
52 | token_score[masks == 0] = -1e8
53 | token_score = token_score.softmax(dim=-1) * masks.float()
54 | weighted_sum = (outputs * token_score.unsqueeze(-1)).sum(-2)
55 | return hidden + weighted_sum
56 |
57 |
58 | class Encoder(nn.Module):
59 | def __init__(self, input_size, hidden_size, dropout):
60 | super(Encoder, self).__init__()
61 | self.input_size = input_size
62 | self.hidden_size = hidden_size
63 | self.output_size = hidden_size
64 | self.input_dense = nn.Linear(input_size, hidden_size)
65 | self.edu_rnn = MaskedGRU(hidden_size, hidden_size//2, bidirectional=True)
66 | self.dropout = nn.Dropout(dropout)
67 | self.conv = nn.Sequential(
68 | nn.Conv1d(hidden_size, hidden_size, kernel_size=2, padding=1, bias=False),
69 | nn.ReLU(),
70 | nn.Dropout(dropout)
71 | )
72 | self.split_rnn = MaskedGRU(hidden_size, hidden_size//2, bidirectional=True)
73 |
74 | def forward(self, inputs, masks):
75 | inputs = self.input_dense(inputs)
76 | # edu rnn
77 | edus, _ = self.edu_rnn(inputs, masks.sum(-1))
78 | edus = inputs + self.dropout(edus)
79 | # cnn
80 | edus = edus.transpose(-2, -1)
81 | splits = self.conv(edus).transpose(-2, -1)
82 | masks = torch.cat([(masks.sum(-1, keepdim=True) > 0).type_as(masks), masks], dim=1)
83 | lengths = masks.sum(-1)
84 | # split rnn
85 | outputs, hidden = self.split_rnn(splits, lengths)
86 | outputs = splits + self.dropout(outputs)
87 | return outputs, masks, hidden
88 |
89 |
90 | class Decoder(nn.Module):
91 | def __init__(self, inputs_size, hidden_size):
92 | super(Decoder, self).__init__()
93 | self.input_dense = nn.Linear(inputs_size, hidden_size)
94 | self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
95 | self.output_size = hidden_size
96 |
97 | def forward(self, input, state):
98 | return self.run_step(input, state)
99 |
100 | def run_batch(self, inputs, init_states, masks):
101 | inputs = self.input_dense(inputs) * masks.unsqueeze(-1).float()
102 | outputs, _ = self.rnn(inputs, init_states.unsqueeze(0))
103 | outputs = outputs * masks.unsqueeze(-1).float()
104 | return outputs
105 |
106 | def run_step(self, input, state):
107 | input = self.input_dense(input)
108 | self.rnn.flatten_parameters()
109 | output, state = self.rnn(input, state)
110 | return output, state
111 |
112 |
113 | class BiaffineAttention(nn.Module):
114 | def __init__(self, encoder_size, decoder_size, num_labels, hidden_size):
115 | super(BiaffineAttention, self).__init__()
116 | self.encoder_size = encoder_size
117 | self.decoder_size = decoder_size
118 | self.num_labels = num_labels
119 | self.hidden_size = hidden_size
120 | self.e_mlp = nn.Sequential(
121 | nn.Linear(encoder_size, hidden_size),
122 | nn.ReLU()
123 | )
124 | self.d_mlp = nn.Sequential(
125 | nn.Linear(decoder_size, hidden_size),
126 | nn.ReLU()
127 | )
128 | self.W_e = nn.Parameter(torch.empty(num_labels, hidden_size, dtype=torch.float))
129 | self.W_d = nn.Parameter(torch.empty(num_labels, hidden_size, dtype=torch.float))
130 | self.U = nn.Parameter(torch.empty(num_labels, hidden_size, hidden_size, dtype=torch.float))
131 | self.b = nn.Parameter(torch.zeros(num_labels, 1, 1, dtype=torch.float))
132 | nn.init.xavier_normal_(self.W_e)
133 | nn.init.xavier_normal_(self.W_d)
134 | nn.init.xavier_normal_(self.U)
135 |
136 | def forward(self, e_outputs, d_outputs):
137 | # e_outputs [batch, length_encoder, encoder_size]
138 | # d_outputs [batch, length_decoder, decoder_size]
139 |
140 | # [batch, length_encoder, hidden_size]
141 | e_outputs = self.e_mlp(e_outputs)
142 | # [batch, length_decoder, hidden_size]
143 | d_outputs = self.d_mlp(d_outputs)
144 |
145 | # [batch, num_labels, 1, length_encoder]
146 | out_e = (self.W_e @ e_outputs.transpose(1, 2)).unsqueeze(2)
147 | # [batch, num_labels, length_decoder, 1]
148 | out_d = (self.W_d @ d_outputs.transpose(1, 2)).unsqueeze(3)
149 |
150 | # [batch, 1, length_decoder, hidden_size] @ [num_labels, hidden_size, hidden_size]
151 | # [batch, num_labels, length_decoder, hidden_size]
152 | out_u = d_outputs.unsqueeze(1) @ self.U
153 | # [batch, num_labels, length_decoder, hidden_size] * [batch, 1, hidden_size, length_encoder]
154 | # [batch, num_labels, length_decoder, length_encoder]
155 | out_u = out_u @ e_outputs.unsqueeze(1).transpose(2, 3)
156 | # [batch, length_decoder, length_encoder, num_labels]
157 | out = (out_e + out_d + out_u + self.b).permute(0, 2, 3, 1)
158 | return out
159 |
160 |
161 | class SplitAttention(nn.Module):
162 | def __init__(self, encoder_size, decoder_size, hidden_size):
163 | super(SplitAttention, self).__init__()
164 | self.biaffine = BiaffineAttention(encoder_size, decoder_size, 1, hidden_size)
165 |
166 | def forward(self, e_outputs, d_outputs, masks):
167 | biaffine = self.biaffine(e_outputs, d_outputs)
168 | attn = biaffine.squeeze(-1)
169 | attn[masks == 0] = -1e8
170 | return attn
171 |
172 |
173 | class PartitionPtr(nn.Module):
174 | def __init__(self, hidden_size, dropout, word_vocab, pos_vocab, nuc_label, rel_label,
175 | pretrained=None, w2v_size=None, w2v_freeze=False, pos_size=30,
176 | split_mlp_size=32, nuc_mlp_size=128, rel_mlp_size=128,
177 | use_gpu=False):
178 | super(PartitionPtr, self).__init__()
179 | self.use_gpu = use_gpu
180 | self.word_vocab = word_vocab
181 | self.pos_vocab = pos_vocab
182 | self.nuc_label = nuc_label
183 | self.rel_label = rel_label
184 | self.word_emb = word_vocab.embedding(pretrained=pretrained, dim=w2v_size, freeze=w2v_freeze, use_gpu=use_gpu)
185 | self.w2v_size = self.word_emb.weight.shape[-1]
186 | self.pos_emb = pos_vocab.embedding(dim=pos_size, use_gpu=use_gpu)
187 | self.pos_size = pos_size
188 | self.hidden_size = hidden_size
189 | self.dropout_p = dropout
190 |
191 | # component
192 | self.edu_encoder = BiGRUEDUEncoder(self.w2v_size+self.pos_size, hidden_size)
193 | self.encoder = Encoder(self.edu_encoder.output_size, hidden_size, dropout)
194 | self.context_dense = nn.Linear(self.encoder.output_size, hidden_size)
195 | self.decoder = Decoder(self.encoder.output_size*2, hidden_size)
196 | self.split_attention = SplitAttention(self.encoder.output_size, self.decoder.output_size, split_mlp_size)
197 | self.nuc_classifier = BiaffineAttention(self.encoder.output_size, self.decoder.output_size, len(self.nuc_label),
198 | nuc_mlp_size)
199 | self.rel_classifier = BiaffineAttention(self.encoder.output_size, self.decoder.output_size, len(self.rel_label),
200 | rel_mlp_size)
201 |
202 | def forward(self, left, right, memory, state):
203 | return self.decode(left, right, memory, state)
204 |
205 | def decode(self, left, right, memory, state):
206 | d_input = torch.cat([memory[0, left], memory[0, right]]).view(1, 1, -1)
207 | d_output, state = self.decoder(d_input, state)
208 | masks = torch.zeros(1, 1, memory.size(1), dtype=torch.uint8)
209 | masks[0, 0, left+1:right] = 1
210 | if self.use_gpu:
211 | masks = masks.cuda()
212 | split_scores = self.split_attention(memory, d_output, masks)
213 | split_scores = split_scores.softmax(dim=-1)
214 | nucs_score = self.nuc_classifier(memory, d_output).softmax(dim=-1) * masks.unsqueeze(-1).float()
215 | rels_score = self.rel_classifier(memory, d_output).softmax(dim=-1) * masks.unsqueeze(-1).float()
216 | split_scores = split_scores[0, 0].cpu().detach().numpy()
217 | nucs_score = nucs_score[0, 0].cpu().detach().numpy()
218 | rels_score = rels_score[0, 0].cpu().detach().numpy()
219 | return split_scores, nucs_score, rels_score, state
220 |
221 | def encode_edus(self, e_inputs):
222 | e_input_words, e_input_poses, e_masks = e_inputs
223 | batch_size, max_edu_seqlen, max_word_seqlen = e_input_words.size()
224 | # [batch_size, max_edu_seqlen, max_word_seqlen, embedding]
225 | word_embedd = self.word_emb(e_input_words)
226 | pos_embedd = self.pos_emb(e_input_poses)
227 | concat_embedd = torch.cat([word_embedd, pos_embedd], dim=-1) * e_masks.unsqueeze(-1).float()
228 | # encode edu
229 | # [batch_size*max_edu_seqlen, max_word_seqlen, embedding]
230 | inputs = concat_embedd.view(batch_size*max_edu_seqlen, max_word_seqlen, -1)
231 | # [batch_size*max_edu_seqlen, max_word_seqlen]
232 | masks = e_masks.view(batch_size*max_edu_seqlen, max_word_seqlen)
233 | edu_encoded = self.edu_encoder(inputs, masks)
234 | # [batch_size, max_edu_seqlen, edu_encoder_output_size]
235 | edu_encoded = edu_encoded.view(batch_size, max_edu_seqlen, self.edu_encoder.output_size)
236 | e_masks = (e_masks.sum(-1) > 0).int()
237 | return edu_encoded, e_masks
238 |
239 | def _decode_batch(self, e_outputs, e_contexts, d_inputs):
240 | d_inputs_indices, d_masks = d_inputs
241 | d_outputs_masks = (d_masks.sum(-1) > 0).type_as(d_masks)
242 |
243 | d_init_states = self.context_dense(e_contexts)
244 |
245 | d_inputs = e_outputs[torch.arange(e_outputs.size(0)), d_inputs_indices.permute(2, 1, 0)].permute(2, 1, 0, 3)
246 | d_inputs = d_inputs.contiguous().view(d_inputs.size(0), d_inputs.size(1), -1)
247 | d_inputs = d_inputs * d_outputs_masks.unsqueeze(-1).float()
248 |
249 | d_outputs = self.decoder.run_batch(d_inputs, d_init_states, d_outputs_masks)
250 | return d_outputs, d_outputs_masks, d_masks
251 |
252 | def loss(self, e_inputs, d_inputs, grounds):
253 | e_inputs, e_masks = self.encode_edus(e_inputs)
254 | e_outputs, e_outputs_masks, e_contexts = self.encoder(e_inputs, e_masks)
255 | d_outputs, d_outputs_masks, d_masks = self._decode_batch(e_outputs, e_contexts, d_inputs)
256 |
257 | splits_ground, nucs_ground, rels_ground = grounds
258 | # split loss
259 | splits_attn = self.split_attention(e_outputs, d_outputs, d_masks)
260 | splits_predict = splits_attn.log_softmax(dim=2)
261 | splits_ground = splits_ground.view(-1)
262 | splits_predict = splits_predict.view(splits_ground.size(0), -1)
263 | splits_masks = d_outputs_masks.view(-1).float()
264 | splits_loss = F.nll_loss(splits_predict, splits_ground, reduction="none")
265 | splits_loss = (splits_loss * splits_masks).sum() / splits_masks.sum()
266 | # nuclear loss
267 | nucs_score = self.nuc_classifier(e_outputs, d_outputs)
268 | nucs_score = nucs_score.log_softmax(dim=-1) * d_masks.unsqueeze(-1).float()
269 | nucs_score = nucs_score.view(nucs_score.size(0)*nucs_score.size(1), nucs_score.size(2), nucs_score.size(3))
270 | target_nucs_score = nucs_score[torch.arange(nucs_score.size(0)), splits_ground]
271 | target_nucs_ground = nucs_ground.view(-1)
272 | nucs_loss = F.nll_loss(target_nucs_score, target_nucs_ground)
273 |
274 | # relation loss
275 | rels_score = self.rel_classifier(e_outputs, d_outputs)
276 | rels_score = rels_score.log_softmax(dim=-1) * d_masks.unsqueeze(-1).float()
277 | rels_score = rels_score.view(rels_score.size(0)*rels_score.size(1), rels_score.size(2), rels_score.size(3))
278 | target_rels_score = rels_score[torch.arange(rels_score.size(0)), splits_ground]
279 | target_rels_ground = rels_ground.view(-1)
280 | rels_loss = F.nll_loss(target_rels_score, target_rels_ground)
281 |
282 | return splits_loss, nucs_loss, rels_loss
283 |
--------------------------------------------------------------------------------
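
Note: BiaffineAttention above yields one score per (decoder step, encoder position, label) triple. A shape-only sketch that reproduces its broadcasting with random tensors standing in for the MLP outputs (all sizes are arbitrary):

import torch

batch, len_e, len_d, hidden, num_labels = 2, 7, 5, 8, 4

e = torch.randn(batch, len_e, hidden)     # encoder states after e_mlp
d = torch.randn(batch, len_d, hidden)     # decoder states after d_mlp
W_e = torch.randn(num_labels, hidden)
W_d = torch.randn(num_labels, hidden)
U = torch.randn(num_labels, hidden, hidden)
b = torch.zeros(num_labels, 1, 1)

out_e = (W_e @ e.transpose(1, 2)).unsqueeze(2)                  # [batch, num_labels, 1, len_e]
out_d = (W_d @ d.transpose(1, 2)).unsqueeze(3)                  # [batch, num_labels, len_d, 1]
out_u = (d.unsqueeze(1) @ U) @ e.unsqueeze(1).transpose(2, 3)   # [batch, num_labels, len_d, len_e]
out = (out_e + out_d + out_u + b).permute(0, 2, 3, 1)

print(out.shape)   # torch.Size([2, 5, 7, 4]) = [batch, len_d, len_e, num_labels]
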
/treebuilder/partptr/parser.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import torch
3 | import numpy as np
4 | from structure.nodes import Paragraph, Relation, rev_relationmap
5 | import matplotlib
6 | import matplotlib.pyplot as plt
7 | from interface import ParserI
8 | from structure.nodes import EDU, TEXT
9 |
10 |
11 | class PartPtrParser(ParserI):
12 | def __init__(self, model):
13 | self.parser = PartitionPtrParser(model)
14 |
15 | def parse(self, para):
16 | edus = []
17 | for edu in para.edus():
18 | edu_copy = EDU([TEXT(edu.text)])
19 | setattr(edu_copy, "words", edu.words)
20 | setattr(edu_copy, "tags", edu.tags)
21 | edus.append(edu_copy)
22 | return self.parser.parse(edus)
23 |
24 |
25 | class PartitionPtrParser:
26 | def __init__(self, model):
27 | self.model = model
28 |
29 | def parse(self, edus, ret_session=False):
30 | if len(edus) < 2:
31 | return Paragraph(edus)
32 |
33 | # TODO implement beam search
34 | session = self.init_session(edus)
35 | while not session.terminate():
36 | split_scores, nucs_score, rels_score, state = self.decode(session)
37 | split = split_scores.argmax()
38 | nuclear_id = nucs_score[split].argmax()
39 | nuclear = self.model.nuc_label.id2label[nuclear_id]
40 | relation_id = rels_score[split].argmax()
41 | relation = self.model.rel_label.id2label[relation_id]
42 | session = session.forward(split_scores, state, split, nuclear, relation)
43 | # build tree by splits (left, split, right)
44 | root_relation = self.build_tree(edus, session.splits[:], session.nuclears[:], session.relations[:])
45 | discourse = Paragraph([root_relation])
46 | if ret_session:
47 | return discourse, session
48 | else:
49 | return discourse
50 |
51 | def init_session(self, edus):
52 | edu_words = [edu.words for edu in edus]
53 | edu_poses = [edu.tags for edu in edus]
54 | max_word_seqlen = max(len(words) for words in edu_words)
55 | edu_seqlen = len(edu_words)
56 |
57 | e_input_words = np.zeros((1, edu_seqlen, max_word_seqlen), dtype=np.long)
58 | e_input_poses = np.zeros_like(e_input_words)
59 | e_input_masks = np.zeros_like(e_input_words, dtype=np.uint8)
60 |
61 | for i, (words, poses) in enumerate(zip(edu_words, edu_poses)):
62 | e_input_words[0, i, :len(words)] = [self.model.word_vocab[word] for word in words]
63 | e_input_poses[0, i, :len(poses)] = [self.model.pos_vocab[pos] for pos in poses]
64 | e_input_masks[0, i, :len(words)] = 1
65 |
66 | e_input_words = torch.from_numpy(e_input_words).long()
67 | e_input_poses = torch.from_numpy(e_input_poses).long()
68 | e_input_masks = torch.from_numpy(e_input_masks).byte()
69 |
70 | if self.model.use_gpu:
71 | e_input_words = e_input_words.cuda()
72 | e_input_poses = e_input_poses.cuda()
73 | e_input_masks = e_input_masks.cuda()
74 |
75 | edu_encoded, e_masks = self.model.encode_edus((e_input_words, e_input_poses, e_input_masks))
76 | memory, _, context = self.model.encoder(edu_encoded, e_masks)
77 | state = self.model.context_dense(context).unsqueeze(0)
78 | return Session(memory, state)
79 |
80 | def decode(self, session):
81 | left, right = session.stack[-1]
82 | return self.model(left, right, session.memory, session.state)
83 |
84 | def build_tree(self, edus, splits, nuclears, relations):
85 | left, split, right = splits.pop(0)
86 | nuclear = nuclears.pop(0)
87 | ftype = relations.pop(0)
88 | ctype = rev_relationmap[ftype]
89 | if split - left == 1:
90 | left_node = edus[left]
91 | else:
92 | left_node = self.build_tree(edus, splits, nuclears, relations)
93 |
94 | if right - split == 1:
95 | right_node = edus[split]
96 | else:
97 | right_node = self.build_tree(edus, splits, nuclears, relations)
98 |
99 | relation = Relation([left_node, right_node], nuclear=nuclear, ftype=ftype, ctype=ctype)
100 | return relation
101 |
102 |
103 | class Session:
104 | def __init__(self, memory, state):
105 | self.n = memory.size(1) - 2
106 | self.step = 0
107 | self.memory = memory
108 | self.state = state
109 | self.stack = [(0, self.n + 1)]
110 | self.scores = np.zeros((self.n, self.n+2), dtype=np.float)
111 | self.splits = []
112 | self.nuclears = []
113 | self.relations = []
114 |
115 | def forward(self, score, state, split, nuclear, relation):
116 | left, right = self.stack.pop()
117 | if right - split > 1:
118 | self.stack.append((split, right))
119 | if split - left > 1:
120 | self.stack.append((left, split))
121 | self.splits.append((left, split, right))
122 | self.nuclears.append(nuclear)
123 | self.relations.append(relation)
124 | self.state = state
125 | self.scores[self.step] = score
126 | self.step += 1
127 | return self
128 |
129 | def terminate(self):
130 | return self.step >= self.n
131 |
132 | def draw_decision_hotmap(self):
133 | textcolors = ["black", "white"]
134 | cmap = "YlGn"
135 | ylabel = "split score"
136 | col_labels = ["split %d" % i for i in range(0, self.scores.shape[1])]
137 | row_labels = ["step %d" % i for i in range(1, self.scores.shape[0] + 1)]
138 | fig, ax = plt.subplots()
139 | im = ax.imshow(self.scores, cmap=cmap)
140 | cbar = ax.figure.colorbar(im, ax=ax)
141 | cbar.ax.set_ylabel(ylabel, rotation=-90, va="bottom")
142 | ax.set_xticks(np.arange(self.scores.shape[1]))
143 | ax.set_yticks(np.arange(self.scores.shape[0]))
144 | ax.set_xticklabels(col_labels)
145 | ax.set_yticklabels(row_labels)
146 | ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
147 | plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")
148 | for edge, spine in ax.spines.items():
149 | spine.set_visible(False)
150 | ax.set_xticks(np.arange(self.scores.shape[1] + 1) - .5, minor=True)
151 | ax.set_yticks(np.arange(self.scores.shape[0] + 1) - .5, minor=True)
152 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
153 | ax.tick_params(which="minor", bottom=False, left=False)
154 | threshold = im.norm(self.scores.max()) / 2.
155 | valfmt = matplotlib.ticker.StrMethodFormatter("{x:.2f}")
156 | texts = []
157 | kw = dict(horizontalalignment="center", verticalalignment="center")
158 | for i in range(self.scores.shape[0]):
159 | for j in range(self.scores.shape[1]):
160 | kw.update(color=textcolors[im.norm(self.scores[i, j]) > threshold])
161 | text = im.axes.text(j, i, valfmt(self.scores[i, j], None), **kw)
162 | texts.append(text)
163 | fig.tight_layout()
164 | plt.show()
165 |
166 | def __repr__(self):
167 | return "[step %d]memory size: %s, state size: %s\n stack:\n%s\n, scores:\n %s" % \
168 | (self.step, str(self.memory.size()), str(self.state.size()),
169 | "\n".join(map(str, self.stack)) or "[]",
170 | str(self.scores))
171 |
172 | def __str__(self):
173 | return repr(self)
174 |
--------------------------------------------------------------------------------
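
A minimal sketch of the stack discipline behind Session.forward above: the parser keeps a stack of unresolved (left, right) spans, pops one per step, records the chosen split, and pushes the right and then the left sub-span so the left half is decoded first. The split choices below are invented for illustration and the sketch does not depend on the model classes.

    def replay_splits(num_edus, chosen_splits):
        # Replay a sequence of split decisions over EDU boundaries 0..num_edus.
        stack = [(0, num_edus)]
        order = []
        for split in chosen_splits:
            left, right = stack.pop()
            order.append((left, split, right))
            if right - split > 1:
                stack.append((split, right))
            if split - left > 1:
                stack.append((left, split))
        return order

    # Four EDUs with split choices 2, 1, 3 yield the pre-order spans consumed by build_tree:
    print(replay_splits(4, [2, 1, 3]))  # [(0, 2, 4), (0, 1, 2), (2, 3, 4)]
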
/treebuilder/partptr/test.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import torch
3 | import logging
4 | from itertools import chain
5 |
6 | from tqdm import tqdm
7 |
8 | from treebuilder.partptr.parser import PartitionPtrParser
9 | from structure.nodes import EDU, TEXT
10 | from dataset import CDTB
11 | from util import eval
12 |
13 |
14 | def main():
15 | logging.basicConfig(level=logging.INFO)
16 | with open("data/models/treebuilder.partptr.model", "rb") as model_fd:
17 | model = torch.load(model_fd, map_location="cpu")
18 | model.eval()
19 | model.use_gpu = False
20 | parser = PartitionPtrParser(model)
21 | cdtb = CDTB("data/CDTB", "TRAIN", "VALIDATE", "TEST", ctb_dir="data/CTB", preprocess=True, cache_dir="data/cache")
22 | golds = list(filter(lambda d: d.root_relation(), chain(*cdtb.test)))
23 |
24 | import parse
25 | pipeline = parse.build_pipeline()
26 |
27 | strips = []
28 | for paragraph in golds:
29 | edus = []
30 | for edu in paragraph.edus():
31 | # edu_copy = EDU([TEXT(edu.text)])
32 | # setattr(edu_copy, "words", edu.words)
33 | # setattr(edu_copy, "tags", edu.tags)
34 | edus.append(edu.text)
35 | strips.append("".join(edus))
36 | # print(strips[-1])
37 | parses = []
38 | parse_sessions = []
39 | for edus in tqdm(strips):
40 | # parse, session = parser.parse(edus, ret_session=True)
41 | parse = pipeline(edus)
42 | parses.append(parse)
43 | # parse_sessions.append(session)
44 |
45 | # macro cdtb scores
46 | cdtb_macro_scores = eval.parse_eval(parses, golds, average="macro")
47 | logging.info("CDTB macro (strict) scores:")
48 | logging.info(eval.gen_parse_report(*cdtb_macro_scores))
49 | # micro cdtb scores
50 | cdtb_micro_scores = eval.parse_eval(parses, golds, average="micro")
51 | logging.info("CDTB micro (strict) scores:")
52 | logging.info(eval.gen_parse_report(*cdtb_micro_scores))
53 |
54 | # micro rst scores
55 | rst_scores = eval.rst_parse_eval(parses, golds)
56 | logging.info("RST styled scores:")
57 | logging.info(eval.gen_parse_report(*rst_scores))
58 |
59 | # nuclear scores
60 | nuclear_scores = eval.nuclear_eval(parses, golds)
61 | logging.info("nuclear scores:")
62 | logging.info(eval.gen_category_report(nuclear_scores))
63 |
64 | # relation scores
65 | ctype_scores, ftype_scores = eval.relation_eval(parses, golds)
66 | logging.info("coarse relation scores:")
67 | logging.info(eval.gen_category_report(ctype_scores))
68 | logging.info("fine relation scores:")
69 | logging.info(eval.gen_category_report(ftype_scores))
70 |
71 |     # draw gold and parse trees along with the decision heatmap (this loop is skipped unless parse_sessions is filled via parser.parse(..., ret_session=True) above)
72 | for gold, parse, session in zip(golds, parses, parse_sessions):
73 | gold.draw()
74 | session.draw_decision_hotmap()
75 | parse.draw()
76 |
77 |
78 | if __name__ == '__main__':
79 | main()
80 |
--------------------------------------------------------------------------------
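
The script above reports both macro and micro averages. As a toy illustration of the difference, independent of util.eval's exact implementation (the per-document counts are invented):

    import numpy as np

    corr = np.array([3.0, 9.0])    # correct spans per document
    gold = np.array([5.0, 10.0])   # gold spans per document
    pred = np.array([4.0, 12.0])   # predicted spans per document

    # Macro: compute precision/recall/F1 per document, then average the F1 values.
    p_doc, r_doc = corr / pred, corr / gold
    f1_doc = 2 * p_doc * r_doc / (p_doc + r_doc)
    print("macro F1:", f1_doc.mean())

    # Micro: pool the counts over all documents first, then compute a single F1.
    p, r = corr.sum() / pred.sum(), corr.sum() / gold.sum()
    print("micro F1:", 2 * p * r / (p + r))
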
/treebuilder/partptr/train.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import argparse
3 | import logging
4 | import random
5 | import torch
6 | import copy
7 | import numpy as np
8 | from dataset import CDTB
9 | from collections import Counter
10 | from itertools import chain
11 | from structure.vocab import Vocab, Label
12 | from structure.nodes import node_type_filter, EDU, Relation, Sentence, TEXT
13 | from treebuilder.partptr.model import PartitionPtr
14 | from treebuilder.partptr.parser import PartitionPtrParser
15 | import torch.optim as optim
16 | from util.eval import parse_eval, gen_parse_report
17 | from tensorboardX import SummaryWriter
18 |
19 |
20 | def build_vocab(dataset):
21 | word_freq = Counter()
22 | pos_freq = Counter()
23 | nuc_freq = Counter()
24 | rel_freq = Counter()
25 | for paragraph in chain(*dataset):
26 | for node in paragraph.iterfind(filter=node_type_filter([EDU, Relation])):
27 | if isinstance(node, EDU):
28 | word_freq.update(node.words)
29 | pos_freq.update(node.tags)
30 | elif isinstance(node, Relation):
31 | nuc_freq[node.nuclear] += 1
32 | rel_freq[node.ftype] += 1
33 |
34 | word_vocab = Vocab("word", word_freq)
35 | pos_vocab = Vocab("part of speech", pos_freq)
36 | nuc_label = Label("nuclear", nuc_freq)
37 | rel_label = Label("relation", rel_freq)
38 | return word_vocab, pos_vocab, nuc_label, rel_label
39 |
40 |
41 | def gen_decoder_data(root, edu2ids):
42 | # splits s0 s1 s2 s3 s4 s5 s6
43 | # edus s/ e0 e1 e2 e3 e4 e5 /s
44 |     splits = []  # [(start, split, end, nuclear, relation), e.g. (0, 3, 6, NS, rel), ...]
45 | child_edus = [] # [edus]
46 |
47 | if isinstance(root, EDU):
48 | child_edus.append(root)
49 | elif isinstance(root, Sentence):
50 | for child in root:
51 | _child_edus, _splits = gen_decoder_data(child, edu2ids)
52 | child_edus.extend(_child_edus)
53 | splits.extend(_splits)
54 | elif isinstance(root, Relation):
55 | children = [gen_decoder_data(child, edu2ids) for child in root]
56 | if len(children) < 2:
57 | raise ValueError("relation node should have at least 2 children")
58 |
59 | while children:
60 | left_child_edus, left_child_splits = children.pop(0)
61 | if children:
62 | last_child_edus, _ = children[-1]
63 | start = edu2ids[left_child_edus[0]]
64 | split = edu2ids[left_child_edus[-1]] + 1
65 | end = edu2ids[last_child_edus[-1]] + 1
66 | nuc = root.nuclear
67 | rel = root.ftype
68 | splits.append((start, split, end, nuc, rel))
69 | child_edus.extend(left_child_edus)
70 | splits.extend(left_child_splits)
71 | return child_edus, splits
72 |
73 |
74 | def numericalize(dataset, word_vocab, pos_vocab, nuc_label, rel_label):
75 | instances = []
76 | for paragraph in filter(lambda d: d.root_relation(), chain(*dataset)):
77 | encoder_inputs = []
78 | decoder_inputs = []
79 | pred_splits = []
80 | pred_nucs = []
81 | pred_rels = []
82 | edus = list(paragraph.edus())
83 | for edu in edus:
84 | edu_word_ids = [word_vocab[word] for word in edu.words]
85 | edu_pos_ids = [pos_vocab[pos] for pos in edu.tags]
86 | encoder_inputs.append((edu_word_ids, edu_pos_ids))
87 | edu2ids = {edu: i for i, edu in enumerate(edus)}
88 | _, splits = gen_decoder_data(paragraph.root_relation(), edu2ids)
89 | for start, split, end, nuc, rel in splits:
90 | decoder_inputs.append((start, end))
91 | pred_splits.append(split)
92 | pred_nucs.append(nuc_label[nuc])
93 | pred_rels.append(rel_label[rel])
94 | instances.append((encoder_inputs, decoder_inputs, pred_splits, pred_nucs, pred_rels))
95 | return instances
96 |
97 |
98 | def gen_batch_iter(instances, batch_size, use_gpu=False):
99 | random_instances = np.random.permutation(instances)
100 | num_instances = len(instances)
101 | offset = 0
102 | while offset < num_instances:
103 | batch = random_instances[offset: min(num_instances, offset+batch_size)]
104 |
105 | # find out max seqlen of edus and words of edus
106 | num_batch = batch.shape[0]
107 | max_edu_seqlen = 0
108 | max_word_seqlen = 0
109 | for encoder_inputs, decoder_inputs, pred_splits, pred_nucs, pred_rels in batch:
110 | max_edu_seqlen = max_edu_seqlen if max_edu_seqlen >= len(encoder_inputs) else len(encoder_inputs)
111 | for edu_word_ids, edu_pos_ids in encoder_inputs:
112 | max_word_seqlen = max_word_seqlen if max_word_seqlen >= len(edu_word_ids) else len(edu_word_ids)
113 |
114 | # batch to numpy
115 |         e_input_words = np.zeros([num_batch, max_edu_seqlen, max_word_seqlen], dtype=np.int64)
116 |         e_input_poses = np.zeros([num_batch, max_edu_seqlen, max_word_seqlen], dtype=np.int64)
117 | e_masks = np.zeros([num_batch, max_edu_seqlen, max_word_seqlen], dtype=np.uint8)
118 |
119 |         d_inputs = np.zeros([num_batch, max_edu_seqlen-1, 2], dtype=np.int64)
120 |         d_outputs = np.zeros([num_batch, max_edu_seqlen-1], dtype=np.int64)
121 |         d_output_nucs = np.zeros([num_batch, max_edu_seqlen-1], dtype=np.int64)
122 |         d_output_rels = np.zeros([num_batch, max_edu_seqlen - 1], dtype=np.int64)
123 | d_masks = np.zeros([num_batch, max_edu_seqlen-1, max_edu_seqlen+1], dtype=np.uint8)
124 |
125 | for batchi, (encoder_inputs, decoder_inputs, pred_splits, pred_nucs, pred_rels) in enumerate(batch):
126 | for edui, (edu_word_ids, edu_pos_ids) in enumerate(encoder_inputs):
127 | word_seqlen = len(edu_word_ids)
128 | e_input_words[batchi][edui][:word_seqlen] = edu_word_ids
129 | e_input_poses[batchi][edui][:word_seqlen] = edu_pos_ids
130 | e_masks[batchi][edui][:word_seqlen] = 1
131 |
132 | for di, decoder_input in enumerate(decoder_inputs):
133 | d_inputs[batchi][di] = decoder_input
134 | d_masks[batchi][di][decoder_input[0]+1: decoder_input[1]] = 1
135 | d_outputs[batchi][:len(pred_splits)] = pred_splits
136 | d_output_nucs[batchi][:len(pred_nucs)] = pred_nucs
137 | d_output_rels[batchi][:len(pred_rels)] = pred_rels
138 |
139 | # numpy to torch
140 | e_input_words = torch.from_numpy(e_input_words).long()
141 | e_input_poses = torch.from_numpy(e_input_poses).long()
142 | e_masks = torch.from_numpy(e_masks).byte()
143 | d_inputs = torch.from_numpy(d_inputs).long()
144 | d_outputs = torch.from_numpy(d_outputs).long()
145 | d_output_nucs = torch.from_numpy(d_output_nucs).long()
146 | d_output_rels = torch.from_numpy(d_output_rels).long()
147 | d_masks = torch.from_numpy(d_masks).byte()
148 |
149 | if use_gpu:
150 | e_input_words = e_input_words.cuda()
151 | e_input_poses = e_input_poses.cuda()
152 | e_masks = e_masks.cuda()
153 | d_inputs = d_inputs.cuda()
154 | d_outputs = d_outputs.cuda()
155 | d_output_nucs = d_output_nucs.cuda()
156 | d_output_rels = d_output_rels.cuda()
157 | d_masks = d_masks.cuda()
158 |
159 | yield (e_input_words, e_input_poses, e_masks), (d_inputs, d_masks), (d_outputs, d_output_nucs, d_output_rels)
160 | offset = offset + batch_size
161 |
162 |
163 | def parse_and_eval(dataset, model):
164 | model.eval()
165 | parser = PartitionPtrParser(model)
166 | golds = list(filter(lambda d: d.root_relation(), chain(*dataset)))
167 | num_instances = len(golds)
168 | strips = []
169 | for paragraph in golds:
170 | edus = []
171 | for edu in paragraph.edus():
172 | edu_copy = EDU([TEXT(edu.text)])
173 | setattr(edu_copy, "words", edu.words)
174 | setattr(edu_copy, "tags", edu.tags)
175 | edus.append(edu_copy)
176 | strips.append(edus)
177 | parses = []
178 | for edus in strips:
179 | parse = parser.parse(edus)
180 | parses.append(parse)
181 | return num_instances, parse_eval(parses, golds)
182 |
183 |
184 | def model_score(scores):
185 | eval_score = sum(score[2] for score in scores)
186 | return eval_score
187 |
188 |
189 | def main(args):
190 | # set seed for reproducibility
191 | random.seed(args.seed)
192 | torch.manual_seed(args.seed)
193 | np.random.seed(args.seed)
194 |
195 | # load dataset
196 | cdtb = CDTB(args.data, "TRAIN", "VALIDATE", "TEST", ctb_dir=args.ctb_dir, preprocess=True, cache_dir=args.cache_dir)
197 | # build vocabulary
198 | word_vocab, pos_vocab, nuc_label, rel_label = build_vocab(cdtb.train)
199 |
200 | trainset = numericalize(cdtb.train, word_vocab, pos_vocab, nuc_label, rel_label)
201 |     logging.info("number of instances in trainset: %d" % len(trainset))
202 | logging.info("args: %s" % str(args))
203 | # build model
204 | model = PartitionPtr(hidden_size=args.hidden_size, dropout=args.dropout,
205 | word_vocab=word_vocab, pos_vocab=pos_vocab, nuc_label=nuc_label, rel_label=rel_label,
206 | pretrained=args.pretrained, w2v_size=args.w2v_size, w2v_freeze=args.w2v_freeze,
207 | pos_size=args.pos_size,
208 | split_mlp_size=args.split_mlp_size, nuc_mlp_size=args.nuc_mlp_size,
209 | rel_mlp_size=args.rel_mlp_size,
210 | use_gpu=args.use_gpu)
211 | if args.use_gpu:
212 | model.cuda()
213 | logging.info("model:\n%s" % str(model))
214 |
215 | # train and evaluate
216 | niter = 0
217 | log_splits_loss = 0.
218 | log_nucs_loss = 0.
219 | log_rels_loss = 0.
220 | log_loss = 0.
221 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)
222 | writer = SummaryWriter(args.log_dir)
223 | logging.info("hint: run 'tensorboard --logdir %s' to observe training status" % args.log_dir)
224 | best_model = None
225 | best_model_score = 0.
226 | for nepoch in range(1, args.epoch + 1):
227 | batch_iter = gen_batch_iter(trainset, args.batch_size, args.use_gpu)
228 | for nbatch, (e_inputs, d_inputs, grounds) in enumerate(batch_iter, start=1):
229 | niter += 1
230 | model.train()
231 | optimizer.zero_grad()
232 | splits_loss, nucs_loss, rels_loss = model.loss(e_inputs, d_inputs, grounds)
233 | loss = args.a_split_loss * splits_loss + args.a_nuclear_loss * nucs_loss + args.a_relation_loss * rels_loss
234 | loss.backward()
235 | optimizer.step()
236 | log_splits_loss += splits_loss.item()
237 | log_nucs_loss += nucs_loss.item()
238 | log_rels_loss += rels_loss.item()
239 | log_loss += loss.item()
240 | if niter % args.log_every == 0:
241 | logging.info("[iter %-6d]epoch: %-3d, batch %-5d,"
242 | "train splits loss:%.5f, nuclear loss %.5f, relation loss %.5f, loss %.5f" %
243 | (niter, nepoch, nbatch, log_splits_loss, log_nucs_loss, log_rels_loss, log_loss))
244 | writer.add_scalar("train/split_loss", log_splits_loss, niter)
245 | writer.add_scalar("train/nuclear_loss", log_nucs_loss, niter)
246 | writer.add_scalar("train/relation_loss", log_rels_loss, niter)
247 | writer.add_scalar("train/loss", log_loss, niter)
248 | log_splits_loss = 0.
249 | log_nucs_loss = 0.
250 | log_rels_loss = 0.
251 | log_loss = 0.
252 | if niter % args.validate_every == 0:
253 | num_instances, validate_scores = parse_and_eval(cdtb.validate, model)
254 | logging.info("validation on %d instances" % num_instances)
255 | logging.info(gen_parse_report(*validate_scores))
256 | writer.add_scalar("validate/span_f1", validate_scores[0][2], niter)
257 | writer.add_scalar("validate/nuclear_f1", validate_scores[1][2], niter)
258 | writer.add_scalar("validate/coarse_relation_f1", validate_scores[2][2], niter)
259 | writer.add_scalar("validate/fine_relation_f1", validate_scores[3][2], niter)
260 | new_model_score = model_score(validate_scores)
261 | if new_model_score > best_model_score:
262 | # test on testset with new best model
263 | best_model_score = new_model_score
264 | best_model = copy.deepcopy(model)
265 | logging.info("test on new best model")
266 | num_instances, test_scores = parse_and_eval(cdtb.test, best_model)
267 | logging.info("test on %d instances" % num_instances)
268 | logging.info(gen_parse_report(*test_scores))
269 | writer.add_scalar("test/span_f1", test_scores[0][2], niter)
270 | writer.add_scalar("test/nuclear_f1", test_scores[1][2], niter)
271 | writer.add_scalar("test/coarse_relation_f1", test_scores[2][2], niter)
272 | writer.add_scalar("test/fine_relation_f1", test_scores[3][2], niter)
273 | if best_model:
274 | # evaluation and save best model
275 | logging.info("final test result")
276 | num_instances, test_scores = parse_and_eval(cdtb.test, best_model)
277 | logging.info("test on %d instances" % num_instances)
278 | logging.info(gen_parse_report(*test_scores))
279 | logging.info("save best model to %s" % args.model_save)
280 | with open(args.model_save, "wb+") as model_fd:
281 | torch.save(best_model, model_fd)
282 | writer.close()
283 |
284 |
285 | if __name__ == '__main__':
286 | logging.basicConfig(level=logging.INFO)
287 | arg_parser = argparse.ArgumentParser()
288 |
289 | # dataset parameters
290 | arg_parser.add_argument("--data", default="data/CDTB")
291 | arg_parser.add_argument("--ctb_dir", default="data/CTB")
292 | arg_parser.add_argument("--cache_dir", default="data/cache")
293 |
294 | # model parameters
295 | arg_parser.add_argument("-hidden_size", default=512, type=int)
296 | arg_parser.add_argument("-dropout", default=0.33, type=float)
297 | # w2v_group = arg_parser.add_mutually_exclusive_group(required=True)
298 | arg_parser.add_argument("-pretrained", default="data/pretrained/sgns.renmin.word")
299 | arg_parser.add_argument("-w2v_size", type=int)
300 | arg_parser.add_argument("-pos_size", default=30, type=int)
301 | arg_parser.add_argument("-split_mlp_size", default=64, type=int)
302 | arg_parser.add_argument("-nuc_mlp_size", default=32, type=int)
303 | arg_parser.add_argument("-rel_mlp_size", default=128, type=int)
304 | arg_parser.add_argument("--w2v_freeze", dest="w2v_freeze", action="store_true")
305 | arg_parser.set_defaults(w2v_freeze=True)
306 |
307 | # train parameters
308 | arg_parser.add_argument("-epoch", default=20, type=int)
309 | arg_parser.add_argument("-batch_size", default=64, type=int)
310 | arg_parser.add_argument("-lr", default=0.001, type=float)
311 | arg_parser.add_argument("-l2", default=0.0, type=float)
312 | arg_parser.add_argument("-log_every", default=10, type=int)
313 | arg_parser.add_argument("-validate_every", default=10, type=int)
314 | arg_parser.add_argument("-a_split_loss", default=0.3, type=float)
315 | arg_parser.add_argument("-a_nuclear_loss", default=1.0, type=float)
316 | arg_parser.add_argument("-a_relation_loss", default=1.0, type=float)
317 | arg_parser.add_argument("-log_dir", default="data/log")
318 | arg_parser.add_argument("-model_save", default="data/models/treebuilder.partptr.model")
319 | arg_parser.add_argument("--seed", default=21, type=int)
320 | arg_parser.add_argument("--use_gpu", dest="use_gpu", action="store_true")
321 | arg_parser.set_defaults(use_gpu=True)
322 |
323 | main(arg_parser.parse_args())
324 |
--------------------------------------------------------------------------------
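
To make the decoder supervision above concrete: for a 4-EDU tree ((e0 e1)(e2 e3)), gen_decoder_data emits splits in pre-order, and gen_batch_iter masks out split positions that fall outside the span being decoded. The sketch below reproduces only the mask construction with plain numpy; the nuclearity/relation labels of the tree are omitted.

    import numpy as np

    # Pre-order (start, end) spans and gold split points for ((e0 e1)(e2 e3)).
    decoder_inputs = [(0, 4), (0, 2), (2, 4)]
    pred_splits = [2, 1, 3]

    # Only boundaries strictly inside (start, end) are valid candidates at each step,
    # mirroring d_masks[batchi][di][start + 1: end] = 1 in gen_batch_iter.
    max_edu_seqlen = 4
    d_masks = np.zeros([max_edu_seqlen - 1, max_edu_seqlen + 1], dtype=np.uint8)
    for di, (start, end) in enumerate(decoder_inputs):
        d_masks[di][start + 1: end] = 1
    print(d_masks)
    # [[0 1 1 1 0]
    #  [0 1 0 0 0]
    #  [0 0 0 1 0]]
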
/treebuilder/partptr/train_b.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import argparse
3 | import logging
4 | import random
5 | import torch
6 | import copy
7 | import numpy as np
8 | from dataset import CDTB
9 | from collections import Counter
10 | from itertools import chain
11 | from structure.vocab import Vocab, Label
12 | from structure.nodes import node_type_filter, EDU, Relation, Sentence, TEXT
13 | from treebuilder.partptr.model import PartitionPtr
14 | from treebuilder.partptr.parser import PartitionPtrParser
15 | import torch.optim as optim
16 | from util.eval import parse_eval, gen_parse_report
17 | from tensorboardX import SummaryWriter
18 |
19 |
20 | def build_vocab(dataset):
21 | word_freq = Counter()
22 | pos_freq = Counter()
23 | nuc_freq = Counter()
24 | rel_freq = Counter()
25 | for paragraph in chain(*dataset):
26 | for node in paragraph.iterfind(filter=node_type_filter([EDU, Relation])):
27 | if isinstance(node, EDU):
28 | word_freq.update(node.words)
29 | pos_freq.update(node.tags)
30 | elif isinstance(node, Relation):
31 | nuc_freq[node.nuclear] += 1
32 | rel_freq[node.ftype] += 1
33 |
34 | word_vocab = Vocab("word", word_freq)
35 | pos_vocab = Vocab("part of speech", pos_freq)
36 | nuc_label = Label("nuclear", nuc_freq)
37 | rel_label = Label("relation", rel_freq)
38 | return word_vocab, pos_vocab, nuc_label, rel_label
39 |
40 |
41 | def gen_decoder_data(root, edu2ids):
42 | # splits s0 s1 s2 s3 s4 s5 s6
43 | # edus s/ e0 e1 e2 e3 e4 e5 /s
44 |     splits = []  # [(start, split, end, nuclear, relation), e.g. (0, 3, 6, NS, rel), ...]
45 | child_edus = [] # [edus]
46 |
47 | if isinstance(root, EDU):
48 | child_edus.append(root)
49 | elif isinstance(root, Sentence):
50 | for child in root:
51 | _child_edus, _splits = gen_decoder_data(child, edu2ids)
52 | child_edus.extend(_child_edus)
53 | splits.extend(_splits)
54 | elif isinstance(root, Relation):
55 | children = [gen_decoder_data(child, edu2ids) for child in root]
56 | if len(children) < 2:
57 | raise ValueError("relation node should have at least 2 children")
58 |
59 | while children:
60 | left_child_edus, left_child_splits = children.pop(0)
61 | if children:
62 | last_child_edus, _ = children[-1]
63 | start = edu2ids[left_child_edus[0]]
64 | split = edu2ids[left_child_edus[-1]] + 1
65 | end = edu2ids[last_child_edus[-1]] + 1
66 | nuc = root.nuclear
67 | rel = root.ftype
68 | splits.append((start, split, end, nuc, rel))
69 | child_edus.extend(left_child_edus)
70 | splits.extend(left_child_splits)
71 | return child_edus, splits
72 |
73 |
74 | def numericalize(dataset, word_vocab, pos_vocab, nuc_label, rel_label):
75 | instances = []
76 | for paragraph in filter(lambda d: d.root_relation(), chain(*dataset)):
77 | encoder_inputs = []
78 | decoder_inputs = []
79 | pred_splits = []
80 | pred_nucs = []
81 | pred_rels = []
82 | edus = list(paragraph.edus())
83 | for edu in edus:
84 | edu_word_ids = [word_vocab[word] for word in edu.words]
85 | edu_pos_ids = [pos_vocab[pos] for pos in edu.tags]
86 | encoder_inputs.append((edu_word_ids, edu_pos_ids))
87 | edu2ids = {edu: i for i, edu in enumerate(edus)}
88 | _, splits = gen_decoder_data(paragraph.root_relation(), edu2ids)
89 | for start, split, end, nuc, rel in splits:
90 | decoder_inputs.append((start, end))
91 | pred_splits.append(split)
92 | pred_nucs.append(nuc_label[nuc])
93 | pred_rels.append(rel_label[rel])
94 | instances.append((encoder_inputs, decoder_inputs, pred_splits, pred_nucs, pred_rels))
95 | return instances
96 |
97 |
98 | def gen_batch_iter(instances, batch_size, use_gpu=False):
99 | random_instances = np.random.permutation(instances)
100 | num_instances = len(instances)
101 | offset = 0
102 | while offset < num_instances:
103 | batch = random_instances[offset: min(num_instances, offset+batch_size)]
104 |
105 | # find out max seqlen of edus and words of edus
106 | num_batch = batch.shape[0]
107 | max_edu_seqlen = 0
108 | max_word_seqlen = 0
109 | for encoder_inputs, decoder_inputs, pred_splits, pred_nucs, pred_rels in batch:
110 | max_edu_seqlen = max_edu_seqlen if max_edu_seqlen >= len(encoder_inputs) else len(encoder_inputs)
111 | for edu_word_ids, edu_pos_ids in encoder_inputs:
112 | max_word_seqlen = max_word_seqlen if max_word_seqlen >= len(edu_word_ids) else len(edu_word_ids)
113 |
114 | # batch to numpy
115 |         e_input_words = np.zeros([num_batch, max_edu_seqlen, max_word_seqlen], dtype=np.int64)
116 |         e_input_poses = np.zeros([num_batch, max_edu_seqlen, max_word_seqlen], dtype=np.int64)
117 | e_masks = np.zeros([num_batch, max_edu_seqlen, max_word_seqlen], dtype=np.uint8)
118 |
119 |         d_inputs = np.zeros([num_batch, max_edu_seqlen-1, 2], dtype=np.int64)
120 |         d_outputs = np.zeros([num_batch, max_edu_seqlen-1], dtype=np.int64)
121 |         d_output_nucs = np.zeros([num_batch, max_edu_seqlen-1], dtype=np.int64)
122 |         d_output_rels = np.zeros([num_batch, max_edu_seqlen - 1], dtype=np.int64)
123 | d_masks = np.zeros([num_batch, max_edu_seqlen-1, max_edu_seqlen+1], dtype=np.uint8)
124 |
125 | for batchi, (encoder_inputs, decoder_inputs, pred_splits, pred_nucs, pred_rels) in enumerate(batch):
126 | for edui, (edu_word_ids, edu_pos_ids) in enumerate(encoder_inputs):
127 | word_seqlen = len(edu_word_ids)
128 | e_input_words[batchi][edui][:word_seqlen] = edu_word_ids
129 | e_input_poses[batchi][edui][:word_seqlen] = edu_pos_ids
130 | e_masks[batchi][edui][:word_seqlen] = 1
131 |
132 | for di, decoder_input in enumerate(decoder_inputs):
133 | d_inputs[batchi][di] = decoder_input
134 | d_masks[batchi][di][decoder_input[0]+1: decoder_input[1]] = 1
135 | d_outputs[batchi][:len(pred_splits)] = pred_splits
136 | d_output_nucs[batchi][:len(pred_nucs)] = pred_nucs
137 | d_output_rels[batchi][:len(pred_rels)] = pred_rels
138 |
139 | # numpy to torch
140 | e_input_words = torch.from_numpy(e_input_words).long()
141 | e_input_poses = torch.from_numpy(e_input_poses).long()
142 | e_masks = torch.from_numpy(e_masks).byte()
143 | d_inputs = torch.from_numpy(d_inputs).long()
144 | d_outputs = torch.from_numpy(d_outputs).long()
145 | d_output_nucs = torch.from_numpy(d_output_nucs).long()
146 | d_output_rels = torch.from_numpy(d_output_rels).long()
147 | d_masks = torch.from_numpy(d_masks).byte()
148 |
149 | if use_gpu:
150 | e_input_words = e_input_words.cuda()
151 | e_input_poses = e_input_poses.cuda()
152 | e_masks = e_masks.cuda()
153 | d_inputs = d_inputs.cuda()
154 | d_outputs = d_outputs.cuda()
155 | d_output_nucs = d_output_nucs.cuda()
156 | d_output_rels = d_output_rels.cuda()
157 | d_masks = d_masks.cuda()
158 |
159 | yield (e_input_words, e_input_poses, e_masks), (d_inputs, d_masks), (d_outputs, d_output_nucs, d_output_rels)
160 | offset = offset + batch_size
161 |
162 |
163 | def parse_and_eval(dataset, model):
164 | model.eval()
165 | parser = PartitionPtrParser(model)
166 | golds = list(filter(lambda d: d.root_relation(), chain(*dataset)))
167 | num_instances = len(golds)
168 | strips = []
169 | for paragraph in golds:
170 | edus = []
171 | for edu in paragraph.edus():
172 | edu_copy = EDU([TEXT(edu.text)])
173 | setattr(edu_copy, "words", edu.words)
174 | setattr(edu_copy, "tags", edu.tags)
175 | edus.append(edu_copy)
176 | strips.append(edus)
177 | parses = []
178 | for edus in strips:
179 | parse = parser.parse(edus)
180 | parses.append(parse)
181 | return num_instances, parse_eval(parses, golds)
182 |
183 |
184 | def model_score(scores):
185 | eval_score = sum(score[2] for score in scores)
186 | return eval_score
187 |
188 |
189 | def main(args):
190 | # set seed for reproducibility
191 | random.seed(args.seed)
192 | torch.manual_seed(args.seed)
193 | np.random.seed(args.seed)
194 |
195 | # load dataset
196 | cdtb = CDTB(args.data, "TRAIN", "VALIDATE", "TEST", ctb_dir=args.ctb_dir, preprocess=True, cache_dir=args.cache_dir)
197 | # build vocabulary
198 | word_vocab, pos_vocab, nuc_label, rel_label = build_vocab(cdtb.train)
199 |
200 | trainset = numericalize(cdtb.train, word_vocab, pos_vocab, nuc_label, rel_label)
201 |     logging.info("number of instances in trainset: %d" % len(trainset))
202 | logging.info("args: %s" % str(args))
203 | # build model
204 | model = PartitionPtr(hidden_size=args.hidden_size, dropout=args.dropout,
205 | word_vocab=word_vocab, pos_vocab=pos_vocab, nuc_label=nuc_label, rel_label=rel_label,
206 | pretrained=args.pretrained, w2v_size=args.w2v_size, w2v_freeze=args.w2v_freeze,
207 | pos_size=args.pos_size,
208 | split_mlp_size=args.split_mlp_size, nuc_mlp_size=args.nuc_mlp_size,
209 | rel_mlp_size=args.rel_mlp_size,
210 | use_gpu=args.use_gpu)
211 | if args.use_gpu:
212 | model.cuda()
213 | logging.info("model:\n%s" % str(model))
214 |
215 | # train and evaluate
216 | niter = 0
217 | log_splits_loss = 0.
218 | log_nucs_loss = 0.
219 | log_rels_loss = 0.
220 | log_loss = 0.
221 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)
222 | writer = SummaryWriter(args.log_dir)
223 | logging.info("hint: run 'tensorboard --logdir %s' to observe training status" % args.log_dir)
224 | best_model = None
225 | best_model_score = 0.
226 | for nepoch in range(1, args.epoch + 1):
227 | batch_iter = gen_batch_iter(trainset, args.batch_size, args.use_gpu)
228 | for nbatch, (e_inputs, d_inputs, grounds) in enumerate(batch_iter, start=1):
229 | niter += 1
230 | model.train()
231 | optimizer.zero_grad()
232 | splits_loss, nucs_loss, rels_loss = model.loss(e_inputs, d_inputs, grounds)
233 | loss = args.a_split_loss * splits_loss + args.a_nuclear_loss * nucs_loss + args.a_relation_loss * rels_loss
234 | loss.backward()
235 | optimizer.step()
236 | log_splits_loss += splits_loss.item()
237 | log_nucs_loss += nucs_loss.item()
238 | log_rels_loss += rels_loss.item()
239 | log_loss += loss.item()
240 | if niter % args.log_every == 0:
241 | logging.info("[iter %-6d]epoch: %-3d, batch %-5d,"
242 | "train splits loss:%.5f, nuclear loss %.5f, relation loss %.5f, loss %.5f" %
243 | (niter, nepoch, nbatch, log_splits_loss, log_nucs_loss, log_rels_loss, log_loss))
244 | writer.add_scalar("train/split_loss", log_splits_loss, niter)
245 | writer.add_scalar("train/nuclear_loss", log_nucs_loss, niter)
246 | writer.add_scalar("train/relation_loss", log_rels_loss, niter)
247 | writer.add_scalar("train/loss", log_loss, niter)
248 | log_splits_loss = 0.
249 | log_nucs_loss = 0.
250 | log_rels_loss = 0.
251 | log_loss = 0.
252 | if niter % args.validate_every == 0:
253 | num_instances, validate_scores = parse_and_eval(cdtb.validate, model)
254 | logging.info("validation on %d instances" % num_instances)
255 | logging.info(gen_parse_report(*validate_scores))
256 | writer.add_scalar("validate/span_f1", validate_scores[0][2], niter)
257 | writer.add_scalar("validate/nuclear_f1", validate_scores[1][2], niter)
258 | writer.add_scalar("validate/coarse_relation_f1", validate_scores[2][2], niter)
259 | writer.add_scalar("validate/fine_relation_f1", validate_scores[3][2], niter)
260 | new_model_score = model_score(validate_scores)
261 | if new_model_score > best_model_score:
262 | # test on testset with new best model
263 | best_model_score = new_model_score
264 | best_model = copy.deepcopy(model)
265 | logging.info("test on new best model")
266 | num_instances, test_scores = parse_and_eval(cdtb.test, best_model)
267 | logging.info("test on %d instances" % num_instances)
268 | logging.info(gen_parse_report(*test_scores))
269 | writer.add_scalar("test/span_f1", test_scores[0][2], niter)
270 | writer.add_scalar("test/nuclear_f1", test_scores[1][2], niter)
271 | writer.add_scalar("test/coarse_relation_f1", test_scores[2][2], niter)
272 | writer.add_scalar("test/fine_relation_f1", test_scores[3][2], niter)
273 | if best_model:
274 | # evaluation and save best model
275 | logging.info("final test result")
276 | num_instances, test_scores = parse_and_eval(cdtb.test, best_model)
277 | logging.info("test on %d instances" % num_instances)
278 | logging.info(gen_parse_report(*test_scores))
279 | logging.info("save best model to %s" % args.model_save)
280 | with open(args.model_save, "wb+") as model_fd:
281 | torch.save(best_model, model_fd)
282 | writer.close()
283 |
284 |
285 | if __name__ == '__main__':
286 | logging.basicConfig(level=logging.INFO)
287 | arg_parser = argparse.ArgumentParser()
288 |
289 | # dataset parameters
290 | arg_parser.add_argument("data")
291 | arg_parser.add_argument("--ctb_dir")
292 | arg_parser.add_argument("--cache_dir")
293 |
294 | # model parameters
295 | arg_parser.add_argument("-hidden_size", default=512, type=int)
296 | arg_parser.add_argument("-dropout", default=0.33, type=float)
297 | w2v_group = arg_parser.add_mutually_exclusive_group(required=True)
298 | w2v_group.add_argument("-pretrained")
299 | w2v_group.add_argument("-w2v_size", type=int)
300 | arg_parser.add_argument("-pos_size", default=30, type=int)
301 | arg_parser.add_argument("-split_mlp_size", default=64, type=int)
302 | arg_parser.add_argument("-nuc_mlp_size", default=32, type=int)
303 | arg_parser.add_argument("-rel_mlp_size", default=128, type=int)
304 | arg_parser.add_argument("--w2v_freeze", dest="w2v_freeze", action="store_true")
305 | arg_parser.set_defaults(w2v_freeze=False)
306 |
307 | # train parameters
308 | arg_parser.add_argument("-epoch", default=20, type=int)
309 | arg_parser.add_argument("-batch_size", default=64, type=int)
310 | arg_parser.add_argument("-lr", default=0.001, type=float)
311 | arg_parser.add_argument("-l2", default=0.0, type=float)
312 | arg_parser.add_argument("-log_every", default=10, type=int)
313 | arg_parser.add_argument("-validate_every", default=10, type=int)
314 | arg_parser.add_argument("-a_split_loss", default=0.3, type=float)
315 | arg_parser.add_argument("-a_nuclear_loss", default=1.0, type=float)
316 | arg_parser.add_argument("-a_relation_loss", default=1.0, type=float)
317 | arg_parser.add_argument("-log_dir", default="data/log")
318 | arg_parser.add_argument("-model_save", default="data/models/treebuilder.partptr.model")
319 | arg_parser.add_argument("--seed", default=21, type=int)
320 | arg_parser.add_argument("--use_gpu", dest="use_gpu", action="store_true")
321 | arg_parser.set_defaults(use_gpu=False)
322 |
323 | main(arg_parser.parse_args())
324 |
--------------------------------------------------------------------------------
/treebuilder/shiftreduce/__init__.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | from .parser import ShiftReduceParser
4 |
--------------------------------------------------------------------------------
/treebuilder/shiftreduce/model.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from collections import deque
6 |
7 |
8 | SHIFT = "SHIFT"
9 | REDUCE = "REDUCE"
10 |
11 |
12 | class ShiftReduceState:
13 | def __init__(self, stack, buffer, tracking):
14 | self.stack = stack
15 | self.buffer = buffer
16 | self.tracking = tracking
17 |
18 | def __copy__(self):
19 | stack = [(hs.clone(), cs.clone()) for hs, cs in self.stack]
20 | buffer = deque([(hb.clone(), cb.clone()) for hb, cb in self.buffer])
21 | h, c = self.tracking
22 | tracking = h.clone(), c.clone()
23 | return ShiftReduceState(stack, buffer, tracking)
24 |
25 |
26 | class Reducer(nn.Module):
27 | def __init__(self, hidden_size):
28 | nn.Module.__init__(self)
29 | self.hidden_size = hidden_size
30 | self.comp = nn.Linear(self.hidden_size * 3, self.hidden_size * 5)
31 |
32 | def forward(self, state):
33 | (h1, c1), (h2, c2) = state.stack[-1], state.stack[-2]
34 | tracking_h = state.tracking[0].view(-1)
35 | a, i, f1, f2, o = self.comp(torch.cat([h1, h2, tracking_h])).chunk(5)
36 | c = a.tanh() * i.sigmoid() + f1.sigmoid() * c1 + f2.sigmoid() * c2
37 | h = o.sigmoid() * c.tanh()
38 | return h, c
39 |
40 |
41 | class MLP(nn.Module):
42 | def __init__(self, hidden_size, num_layers, dropout_p, num_classes):
43 | nn.Module.__init__(self)
44 | self.linears = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_layers-1)])
45 | self.activations = nn.ModuleList([nn.ReLU() for _ in range(num_layers - 1)])
46 | self.dropouts = nn.ModuleList([nn.Dropout(p=dropout_p) for _ in range(num_layers - 1)])
47 | self.logits = nn.Linear(hidden_size, num_classes)
48 |
49 | def forward(self, hidden):
50 | for linear, dropout, activation in zip(self.linears, self.dropouts, self.activations):
51 | hidden = linear(hidden)
52 | hidden = activation(hidden)
53 | hidden = dropout(hidden)
54 | return self.logits(hidden)
55 |
56 |
57 | class ShiftReduceModel(nn.Module):
58 | def __init__(self, hidden_size, dropout, cnn_filters, word_vocab, pos_vocab, trans_label,
59 | pretrained=None, w2v_size=None, w2v_freeze=False, pos_size=30, mlp_layers=1,
60 | use_gpu=False):
61 | super(ShiftReduceModel, self).__init__()
62 | self.word_vocab = word_vocab
63 | self.pos_vocab = pos_vocab
64 | self.trans_label = trans_label
65 | self.word_emb = word_vocab.embedding(pretrained=pretrained, dim=w2v_size, freeze=w2v_freeze, use_gpu=use_gpu)
66 | self.w2v_size = self.word_emb.weight.shape[-1]
67 | self.pos_emb = pos_vocab.embedding(dim=pos_size, use_gpu=use_gpu)
68 | self.pos_size = pos_size
69 |
70 | self.hidden_size = hidden_size
71 | self.dropout_p = dropout
72 | self.use_gpu = use_gpu
73 |
74 | # components
75 | cnn_input_width = self.w2v_size + self.pos_size
76 | unigram_filter_num, bigram_filter_num, trigram_filter_num = cnn_filters
77 | self.edu_unigram_cnn = nn.Conv2d(1, unigram_filter_num, (1, cnn_input_width), padding=(0, 0))
78 | self.edu_bigram_cnn = nn.Conv2d(1, bigram_filter_num, (2, cnn_input_width), padding=(1, 0))
79 | self.edu_trigram_cnn = nn.Conv2d(1, trigram_filter_num, (3, cnn_input_width), padding=(2, 0))
80 |
81 | self.edu_proj = nn.Linear(self.w2v_size * 2 + self.pos_size + sum(cnn_filters), self.hidden_size * 2)
82 | self.tracker = nn.LSTMCell(hidden_size * 3, hidden_size)
83 | self.reducer = Reducer(hidden_size)
84 | self.scorer = MLP(hidden_size, mlp_layers, dropout, len(trans_label))
85 |
86 | def forward(self, state):
87 | return self.scorer(state.tracking[0].view(-1))
88 |
89 | def shift(self, state):
90 | assert len(state.buffer) > 2
91 | b1 = state.buffer.popleft()
92 | state.stack.append(b1)
93 | return self.update_tracking(state)
94 |
95 | def reduce(self, state):
96 | assert len(state.stack) >= 4
97 | reduced = self.reducer(state)
98 | state.stack.pop()
99 | state.stack.pop()
100 | state.stack.append(reduced)
101 | return self.update_tracking(state)
102 |
103 | def init_state(self, edu_words, edu_poses):
104 | edu_encoded = self.encode_edus(edu_words, edu_poses)
105 | placeholder = torch.zeros(self.hidden_size)
106 | if self.use_gpu:
107 | placeholder = placeholder.cuda()
108 | stack = [(placeholder.clone(), placeholder.clone()),
109 | (placeholder.clone(), placeholder.clone())]
110 | buffer = deque(edu_encoded +
111 | [(placeholder.clone(), placeholder.clone()), (placeholder.clone(), placeholder.clone())])
112 | tracking = placeholder.clone().view(1, -1), placeholder.clone().view(1, -1)
113 | state = ShiftReduceState(stack, buffer, tracking)
114 | return self.update_tracking(state)
115 |
116 | def update_tracking(self, state):
117 | (s1, _), (s2, _) = state.stack[-1], state.stack[-2]
118 | b1, _ = state.buffer[0]
119 | cell_input = torch.cat([s1, s2, b1], dim=0).view(1, -1)
120 | new_tracking = self.tracker(cell_input, state.tracking)
121 | state.tracking = new_tracking
122 | return state
123 |
124 | def loss(self, edu_words, edu_poses, transes):
125 | state = self.init_state(edu_words, edu_poses)
126 | pred_trans_logits = []
127 | for trans_id in transes:
128 | trans, _, _ = self.trans_label.id2label[trans_id]
129 | logits = self(state)
130 | pred_trans_logits.append(logits)
131 | if trans == SHIFT:
132 | state = self.shift(state)
133 | elif trans == REDUCE:
134 | state = self.reduce(state)
135 | else:
136 |                 raise ValueError("Unknown transition")
137 |
138 | pred = torch.stack(pred_trans_logits, dim=0)
139 | gold = torch.tensor(transes).long()
140 | if self.use_gpu:
141 | gold = gold.cuda()
142 | loss = F.cross_entropy(pred, gold)
143 | return loss
144 |
145 | def encode_edus(self, edu_words, edu_poses):
146 | encoded = []
147 | for words, poses in zip(edu_words, edu_poses):
148 | word_ids = torch.tensor(words or [0]).long()
149 | pos_ids = torch.tensor(poses or [0]).long()
150 | if self.use_gpu:
151 | word_ids = word_ids.cuda()
152 | pos_ids = pos_ids.cuda()
153 | word_embs = self.word_emb(word_ids)
154 | pos_embs = self.pos_emb(pos_ids)
155 | # basic
156 | w1, w_1 = word_embs[0], word_embs[-1]
157 | p1 = pos_embs[0]
158 | # cnn
159 | cnn_input = torch.cat([word_embs, pos_embs], dim=1)
160 | cnn_input = cnn_input.view(1, 1, cnn_input.size(0), cnn_input.size(1))
161 | unigram_output = F.relu(self.edu_unigram_cnn(cnn_input)).squeeze(-1)
162 | unigram_feats = F.max_pool1d(unigram_output, kernel_size=unigram_output.size(2)).view(-1)
163 | bigram_output = F.relu(self.edu_bigram_cnn(cnn_input)).squeeze(-1)
164 | bigram_feats = F.max_pool1d(bigram_output, kernel_size=bigram_output.size(2)).view(-1)
165 | trigram_output = F.relu(self.edu_trigram_cnn(cnn_input)).squeeze(-1)
166 | trigram_feats = F.max_pool1d(trigram_output, kernel_size=trigram_output.size(2)).view(-1)
167 | cnn_feats = torch.cat([unigram_feats, bigram_feats, trigram_feats], dim=0)
168 | # proj
169 | h, c = self.edu_proj(torch.cat([w1, w_1, p1, cnn_feats], dim=0)).chunk(2)
170 | encoded.append((h, c))
171 | return encoded
172 |
--------------------------------------------------------------------------------
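
The Reducer above combines the two topmost stack elements with a TreeLSTM-style composition: one candidate cell gated by an input gate plus a separate forget gate per child cell. A standalone sketch of the same arithmetic with arbitrary sizes and random values:

    import torch

    hidden_size = 8
    h1, c1 = torch.randn(hidden_size), torch.randn(hidden_size)  # top of the stack
    h2, c2 = torch.randn(hidden_size), torch.randn(hidden_size)  # second element
    tracking_h = torch.randn(hidden_size)                        # tracker hidden state

    comp = torch.nn.Linear(hidden_size * 3, hidden_size * 5)
    a, i, f1, f2, o = comp(torch.cat([h1, h2, tracking_h])).chunk(5)

    # Candidate cell gated by the input gate, plus one forget gate per child cell.
    c = a.tanh() * i.sigmoid() + f1.sigmoid() * c1 + f2.sigmoid() * c2
    h = o.sigmoid() * c.tanh()
    print(h.shape, c.shape)  # torch.Size([8]) torch.Size([8])
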
/treebuilder/shiftreduce/parser.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | from interface import ParserI
4 | from structure import Paragraph, EDU, TEXT, Relation, rev_relationmap
5 | from treebuilder.shiftreduce.model import SHIFT, REDUCE
6 | from collections import deque
7 |
8 |
9 | INF = 1e8
10 |
11 |
12 | class ShiftReduceParser(ParserI):
13 | def __init__(self, model):
14 | self.model = model
15 | model.eval()
16 |
17 | def parse(self, para: Paragraph) -> Paragraph:
18 | edus = []
19 | for edu in para.edus():
20 | edu_copy = EDU([TEXT(edu.text)])
21 | setattr(edu_copy, "words", edu.words)
22 | setattr(edu_copy, "tags", edu.tags)
23 | edus.append(edu_copy)
24 | if len(edus) < 2:
25 | return para
26 |
27 | trans_probs = []
28 | state = self.init_state(edus)
29 | while not self.terminate(state):
30 | logits = self.model(state)
31 | valid = self.valid_trans(state)
32 | for i, (trans, _, _) in enumerate(self.model.trans_label.id2label):
33 | if trans not in valid:
34 | logits[i] = -INF
35 | probs = logits.softmax(dim=0)
36 | trans_probs.append(probs)
37 | next_trans, _, _ = self.model.trans_label.id2label[probs.argmax(dim=0)]
38 | if next_trans == SHIFT:
39 | state = self.model.shift(state)
40 | elif next_trans == REDUCE:
41 | state = self.model.reduce(state)
42 | else:
43 |                 raise ValueError("unexpected transition occurred")
44 | parsed = self.build_tree(edus, trans_probs)
45 | return parsed
46 |
47 | def build_tree(self, edus, trans_probs):
48 | buffer = deque(edus)
49 | stack = []
50 | for prob in trans_probs:
51 | trans, nuclear, ftype = self.model.trans_label.id2label[prob.argmax()]
52 | ctype = rev_relationmap[ftype] if ftype is not None else None
53 | if trans == SHIFT:
54 | stack.append(buffer.popleft())
55 | elif trans == REDUCE:
56 | right = stack.pop()
57 | left = stack.pop()
58 | comp = Relation([left, right], nuclear=nuclear, ftype=ftype, ctype=ctype)
59 | stack.append(comp)
60 | assert len(stack) == 1
61 | return Paragraph([stack[0]])
62 |
63 | def init_state(self, edus):
64 | word_ids = [[self.model.word_vocab[word] for word in edu.words] for edu in edus]
65 | pos_ids = [[self.model.pos_vocab[pos] for pos in edu.tags] for edu in edus]
66 | state = self.model.init_state(word_ids, pos_ids)
67 | return state
68 |
69 | def valid_trans(self, state):
70 | valid = []
71 | if len(state.buffer) > 2:
72 | valid.append(SHIFT)
73 | if len(state.stack) >= 4:
74 | valid.append(REDUCE)
75 | return valid
76 |
77 | def terminate(self, state):
78 | return not self.valid_trans(state)
79 |
--------------------------------------------------------------------------------
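
parse() above keeps the greedy decoding valid by pushing the logits of disallowed transitions to a large negative value before the softmax, so their probability becomes effectively zero. A small self-contained sketch of that masking step (the label inventory and scores are invented):

    import torch

    INF = 1e8
    labels = ["SHIFT", "REDUCE-NS-rel", "REDUCE-SN-rel"]  # invented label inventory
    logits = torch.tensor([0.2, 1.5, -0.3])
    valid = {"SHIFT"}  # e.g. the stack is still too small to REDUCE

    masked = logits.clone()
    for i, label in enumerate(labels):
        if label.split("-")[0] not in valid:
            masked[i] = -INF
    probs = masked.softmax(dim=0)
    print(probs)  # essentially all probability mass on SHIFT
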
/treebuilder/shiftreduce/train.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import logging
3 | import argparse
4 | import random
5 | import copy
6 | import torch
7 | import torch.optim as optim
8 | import numpy as np
9 | from itertools import chain
10 | from dataset import CDTB
11 | from collections import Counter
12 | from structure import EDU, Sentence, Relation, node_type_filter, TEXT, Paragraph
13 | from structure.vocab import Label, Vocab
14 | from treebuilder.shiftreduce.model import ShiftReduceModel, SHIFT, REDUCE
15 | from treebuilder.shiftreduce.parser import ShiftReduceParser
16 | from util.eval import parse_eval, gen_parse_report
17 |
18 |
19 | def oracle(tree):
20 | if tree.root_relation() is None:
21 | raise ValueError("Can not conduct transitions from forest")
22 |
23 | def _oracle(root):
24 | trans, children = [], []
25 | if isinstance(root, EDU):
26 | trans.append((SHIFT, None, None))
27 | children.append(root)
28 | elif isinstance(root, Sentence):
29 | for node in root:
30 | _trans, _children = _oracle(node)
31 | trans.extend(_trans)
32 | children.extend(_children)
33 | elif isinstance(root, Relation):
34 | rel_children = []
35 | for node in root:
36 | _trans, _children = _oracle(node)
37 | trans.extend(_trans)
38 | rel_children.extend(_children)
39 | while len(rel_children) > 1:
40 | rel_children.pop()
41 | trans.append((REDUCE, root.nuclear, root.ftype))
42 | children.append(root)
43 | else:
44 |             raise ValueError("unhandled node type %s" % repr(type(root)))
45 | return trans, children
46 |
47 | transitions, _ = _oracle(tree.root_relation())
48 | return transitions
49 |
50 |
51 | def gen_instances(trees):
52 | instances = []
53 | for tree in trees:
54 | root = tree.root_relation()
55 | if root is not None:
56 | words = []
57 | poses = []
58 | for edu in root.iterfind(node_type_filter(EDU)):
59 | words.append(edu.words)
60 | poses.append(edu.tags)
61 | trans = oracle(tree)
62 | instances.append((words, poses, trans))
63 | return instances
64 |
65 |
66 | def build_vocab(instances):
67 | words_counter = Counter()
68 | poses_counter = Counter()
69 | trans_counter = Counter()
70 | for words, poses, trans in instances:
71 | words_counter.update(chain(*words))
72 | poses_counter.update(chain(*poses))
73 | trans_counter.update(trans)
74 | word_vocab = Vocab("word", words_counter)
75 | pos_vocab = Vocab("part of speech", poses_counter)
76 | trans_label = Label("transition", trans_counter)
77 | return word_vocab, pos_vocab, trans_label
78 |
79 |
80 | def numericalize(instances, word_vocab, pos_vocab, trans_label):
81 | ids = []
82 | for edu_words, edu_poses, transes in instances:
83 | word_ids = [[word_vocab[word] for word in edu] for edu in edu_words]
84 | pos_ids = [[pos_vocab[pos] for pos in edu] for edu in edu_poses]
85 | trans_ids = [trans_label[trans] for trans in transes]
86 | ids.append((word_ids, pos_ids, trans_ids))
87 | return ids
88 |
89 |
90 | def gen_batch(dataset, batch_size):
91 | offset = 0
92 | while offset < len(dataset):
93 | _offset = offset + batch_size if offset + batch_size < len(dataset) else len(dataset)
94 | yield dataset[offset: _offset]
95 | offset = _offset
96 |
97 |
98 | def parse_and_eval(dataset, model):
99 | parser = ShiftReduceParser(model)
100 | golds = list(filter(lambda d: d.root_relation(), chain(*dataset)))
101 | num_instances = len(golds)
102 | strips = []
103 | for paragraph in golds:
104 | edus = []
105 | for edu in paragraph.edus():
106 | edu_copy = EDU([TEXT(edu.text)])
107 | setattr(edu_copy, "words", edu.words)
108 | setattr(edu_copy, "tags", edu.tags)
109 | edus.append(edu_copy)
110 | strips.append(Paragraph(edus))
111 |
112 | parses = []
113 | for strip in strips:
114 | parses.append(parser.parse(strip))
115 | return num_instances, parse_eval(parses, golds)
116 |
117 |
118 | def model_score(scores):
119 | eval_score = sum(score[2] for score in scores)
120 | return eval_score
121 |
122 |
123 | def main(args):
124 | # set seed for reproducibility
125 | random.seed(args.seed)
126 | torch.manual_seed(args.seed)
127 | np.random.seed(args.seed)
128 | if torch.cuda.is_available():
129 | torch.cuda.manual_seed(args.seed)
130 |
131 | # load dataset
132 | cdtb = CDTB(args.data, "TRAIN", "VALIDATE", "TEST", ctb_dir=args.ctb_dir, preprocess=True, cache_dir=args.cache_dir)
133 |
134 | trainset = gen_instances(chain(*cdtb.train))
135 | logging.info("generate %d instances from trainset" % len(trainset))
136 | word_vocab, pos_vocab, trans_label = build_vocab(trainset)
137 | trainset = numericalize(trainset, word_vocab, pos_vocab, trans_label)
138 |
139 | model = ShiftReduceModel(hidden_size=args.hidden_size, dropout=args.dropout, cnn_filters=args.cnn_filters,
140 | word_vocab=word_vocab, pos_vocab=pos_vocab, trans_label=trans_label,
141 | pretrained=args.pretrained, w2v_size=args.w2v_size, w2v_freeze=args.w2v_freeze,
142 | pos_size=args.pos_size, mlp_layers=args.mlp_layers,
143 | use_gpu=args.use_gpu)
144 | if args.use_gpu:
145 | model.cuda()
146 | logging.info("model:\n" + str(model))
147 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.l2)
148 | dataset = np.array(trainset)
149 | niter = 0
150 | best_model = None
151 | best_model_score = 0.
152 | for nepoch in range(1, args.epoch + 1):
153 | np.random.shuffle(dataset)
154 | batch_iter = gen_batch(dataset, args.batch_size)
155 | for nbatch, batch in enumerate(batch_iter):
156 | niter += 1
157 | model.train()
158 | optimizer.zero_grad()
159 | loss = 0.
160 | for word_ids, pos_ids, trans_ids in batch:
161 | batch_loss = model.loss(word_ids, pos_ids, trans_ids)
162 | loss += batch_loss
163 | loss = loss / len(batch)
164 | loss.backward()
165 | optimizer.step()
166 | if niter % args.log_every == 0:
167 | logging.info("[iter %-6d]epoch: %-3d, batch %-5d, train loss %.5f" %
168 | (niter, nepoch, nbatch, loss.item()))
169 |
170 | if niter % args.validate_every == 0:
171 | model.eval()
172 | num_instances, validate_scores = parse_and_eval(cdtb.validate, model)
173 | logging.info("validation on %d instances" % num_instances)
174 | logging.info(gen_parse_report(*validate_scores))
175 | new_model_score = model_score(validate_scores)
176 | if new_model_score > best_model_score:
177 | # test on testset with new best model
178 | best_model_score = new_model_score
179 | best_model = copy.deepcopy(model)
180 | logging.info("test on new best model")
181 | num_instances, test_scores = parse_and_eval(cdtb.test, best_model)
182 | logging.info("test on %d instances" % num_instances)
183 | logging.info(gen_parse_report(*test_scores))
184 | if best_model:
185 | # evaluation and save best model
186 | logging.info("final test result")
187 | num_instances, test_scores = parse_and_eval(cdtb.test, best_model)
188 | logging.info("test on %d instances" % num_instances)
189 | logging.info(gen_parse_report(*test_scores))
190 | logging.info("save best model to %s" % args.model_save)
191 | with open(args.model_save, "wb+") as model_fd:
192 | torch.save(best_model, model_fd)
193 |
194 |
195 | if __name__ == '__main__':
196 | logging.basicConfig(level=logging.INFO)
197 | arg_parser = argparse.ArgumentParser()
198 |
199 | # dataset parameters
200 | arg_parser.add_argument("data")
201 | arg_parser.add_argument("--ctb_dir")
202 | arg_parser.add_argument("--cache_dir")
203 |
204 | # model parameters
205 | arg_parser.add_argument("-hidden_size", default=256, type=int)
206 | arg_parser.add_argument("-dropout", default=0.33, type=float)
207 | w2v_group = arg_parser.add_mutually_exclusive_group(required=True)
208 | w2v_group.add_argument("-pretrained")
209 | w2v_group.add_argument("-w2v_size", type=int)
210 | arg_parser.add_argument("-pos_size", default=30, type=int)
211 | arg_parser.add_argument("--w2v_freeze", dest="w2v_freeze", action="store_true")
212 | arg_parser.add_argument("-cnn_filters", nargs=3, default=[60, 30, 10], type=int)
213 | arg_parser.add_argument("-mlp_layers", default=2, type=int)
214 | arg_parser.set_defaults(w2v_freeze=False)
215 |
216 | # train parameters
217 | arg_parser.add_argument("--seed", default=21, type=int)
218 | arg_parser.add_argument("--use_gpu", dest="use_gpu", action="store_true")
219 | arg_parser.set_defaults(use_gpu=False)
220 | arg_parser.add_argument("--epoch", default=20, type=int)
221 |     arg_parser.add_argument("--batch_size", default=32, type=int)
222 | arg_parser.add_argument("-lr", default=0.001, type=float)
223 | arg_parser.add_argument("-l2", default=0.00001, type=float)
224 | arg_parser.add_argument("-log_every", default=3, type=int)
225 | arg_parser.add_argument("-validate_every", default=10, type=int)
226 | arg_parser.add_argument("-model_save", default="data/models/treebuilder.shiftreduce.model")
227 | main(arg_parser.parse_args())
228 |
--------------------------------------------------------------------------------
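
The oracle above linearizes a gold tree into SHIFT/REDUCE actions: every EDU contributes one SHIFT, and a relation with k children contributes k - 1 REDUCE actions after its children. A self-contained sketch of the same idea on a plain nested-tuple tree (strings stand in for EDUs, so the repository's node classes and nuclearity/relation labels are not used):

    SHIFT, REDUCE = "SHIFT", "REDUCE"

    def toy_oracle(node):
        if isinstance(node, str):          # a leaf EDU
            return [SHIFT]
        actions = []
        for child in node:                 # an internal relation node
            actions.extend(toy_oracle(child))
        actions.extend([REDUCE] * (len(node) - 1))
        return actions

    tree = (("e0", "e1"), "e2")
    print(toy_oracle(tree))  # ['SHIFT', 'SHIFT', 'REDUCE', 'SHIFT', 'REDUCE']
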
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NLP-Discourse-SoochowU/t2d_discourseparser/ce552908b1907cf8b59db11802811a6468c9bfc9/util/__init__.py
--------------------------------------------------------------------------------
/util/berkely.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | import subprocess
3 | import thulac
4 | import threading
5 | import os
6 | from nltk.tree import Tree
7 |
8 |
9 | BERKELEY_JAR = "berkeleyparser/BerkeleyParser-1.7.jar"
10 | BERKELEY_GRAMMAR = "berkeleyparser/chn_sm5.gr"
11 |
12 |
13 | class BerkeleyParser(object):
14 | def __init__(self):
15 | self.tokenizer = thulac.thulac()
16 | self.cmd = ['java', '-Xmx1024m', '-jar', BERKELEY_JAR, '-gr', BERKELEY_GRAMMAR]
17 | self.process = self.start()
18 |
19 | def start(self):
20 | return subprocess.Popen(self.cmd, env=dict(os.environ), universal_newlines=True, shell=False, bufsize=0,
21 | stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, errors='ignore')
22 |
23 | def stop(self):
24 | if self.process:
25 | self.process.terminate()
26 |
27 | def restart(self):
28 | self.stop()
29 | self.process = self.start()
30 |
31 | def parse_thread(self, text, results):
32 | text = text.replace("(", '-LRB-')
33 | text = text.replace(")", '-RRB-')
34 | self.process.stdin.write(text + '\n')
35 | self.process.stdin.flush()
36 | ret = self.process.stdout.readline().strip()
37 | results.append(ret)
38 |
39 | def parse(self, words, timeout=20000):
40 | # words, _ = list(zip(*self.tokenizer.cut(text)))
41 | results = []
42 | t = threading.Thread(target=self.parse_thread, kwargs={'text': " ".join(words), 'results': results})
43 |         t.daemon = True
44 | t.start()
45 | t.join(timeout)
46 |
47 | if not results:
48 | self.restart()
49 | raise TimeoutError()
50 | else:
51 | return Tree.fromstring(results[0])
52 |
--------------------------------------------------------------------------------
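
BerkeleyParser.parse above bounds a potentially hanging subprocess call by running it in a daemon worker thread and waiting with Thread.join, which interprets its timeout in seconds. A stripped-down sketch of that pattern without the parser process (the sleep stands in for the blocked pipe read):

    import threading
    import time

    def slow_call(results):
        time.sleep(2)  # stand-in for writing to / reading from the parser process
        results.append("(ROOT ...)")

    results = []
    t = threading.Thread(target=slow_call, kwargs={"results": results})
    t.daemon = True            # do not keep the interpreter alive for a stuck worker
    t.start()
    t.join(timeout=1)          # seconds; give up if the worker has not finished

    if not results:
        print("timed out; the caller would restart the parser process here")
    else:
        print("parsed:", results[0])
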
/util/eval.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 | from collections import defaultdict, Counter
3 | from itertools import chain
4 | import numpy as np
5 | from structure.nodes import EDU, Sentence, Relation
6 |
7 |
8 | def edu_eval(segs, golds):
9 | num_corr = 0
10 | num_gold = 0
11 | num_pred = 0
12 | for seg, gold in zip(segs, golds):
13 | seg_spans = set()
14 | gold_spans = set()
15 | seg_offset = 0
16 | for edu in seg.edus():
17 | seg_spans.add((seg_offset, seg_offset+len(edu.text)-1))
18 | seg_offset += len(edu.text)
19 | gold_offset = 0
20 | for edu in gold.edus():
21 | gold_spans.add((gold_offset, gold_offset+len(edu.text)-1))
22 | gold_offset += len(edu.text)
23 | num_corr += len(seg_spans & gold_spans)
24 | num_gold += len(gold_spans)
25 | num_pred += len(seg_spans)
26 | precision = num_corr / num_pred if num_pred > 0 else 0
27 | recall = num_corr / num_gold if num_gold > 0 else 0
28 | if precision + recall == 0:
29 | f1 = 0.
30 | else:
31 | f1 = 2. * precision * recall / (precision + recall)
32 | return num_gold, num_pred, num_corr, precision, recall, f1
33 |
34 |
35 | def gen_edu_report(score):
36 |     # num_gold, num_pred, num_corr, precision, recall, f1
37 | report = '\n'
38 | report += 'gold pred corr precision recall f1\n'
39 | report += '----------------------------------------------------------\n'
40 | report += '%-4d %-4d %-4d %-.3f %-.3f %-.3f\n' % score
41 | return report
42 |
43 |
44 | def rst_parse_eval(parses, golds):
45 | return parse_eval(parses, golds, average="micro", strict=False, binarize=True)
46 |
47 |
48 | def parse_eval(parses, golds, average="macro", strict=True, binarize=True):
49 | """
50 |     :param parses: list of parsed trees
51 |     :param golds: list of gold standard trees
52 |     :param average: "micro"|"macro"
53 |         micro average means scores are computed globally, with each relation counted as one instance
54 |         macro average means the score of each tree is computed first and the mean of these scores is returned
55 |     :param strict: if set to True, an inner node counts as correct only if all of its children have correct
56 |         boundaries; otherwise only the start and end offsets of its own boundary are considered
57 |     :param binarize: binarize parse and gold trees before evaluation
58 |     :return: scores for span, nuclear, coarse relation, fine relation, coarse full and fine full
59 |         (span + nuclear + relation all correct); each score is a triple (recall, precision, f1)
60 | """
61 | if len(parses) != len(golds):
62 | raise ValueError("number of parsed trees should equal to gold standards!")
63 |
64 |     num_parse = np.zeros(len(parses), dtype=np.float64)
65 |     num_gold = np.zeros(len(golds), dtype=np.float64)
66 |     num_corr_span = np.zeros(len(parses), dtype=np.float64)
67 |     num_corr_nuc = np.zeros(len(parses), dtype=np.float64)
68 |     num_corr_ctype = np.zeros(len(parses), dtype=np.float64)
69 |     num_corr_ftype = np.zeros(len(parses), dtype=np.float64)
70 |     num_corr_cfull = np.zeros(len(parses), dtype=np.float64)
71 |     num_corr_ffull = np.zeros(len(parses), dtype=np.float64)
72 |
73 | for i, (parse, gold) in enumerate(zip(parses, golds)):
74 | parse_quads = factorize_tree(parse, strict, binarize)
75 | gold_quads = factorize_tree(gold, strict, binarize)
76 | parse_dict = {quad[0]: quad for quad in parse_quads}
77 | gold_dict = {quad[0]: quad for quad in gold_quads}
78 |
79 | num_parse[i] = len(parse_quads)
80 | num_gold[i] = len(gold_quads)
81 |
82 | parse_spans = set(parse_dict.keys())
83 | gold_spans = set(gold_dict.keys())
84 | corr_spans = gold_spans & parse_spans
85 | num_corr_span[i] = len(corr_spans)
86 |
87 | for span in corr_spans:
88 | # nuclear
89 | if parse_dict[span][1] == gold_dict[span][1]:
90 | num_corr_nuc[i] += 1
91 | # coarse relation
92 | if parse_dict[span][2] == gold_dict[span][2]:
93 | num_corr_ctype[i] += 1
94 | # fine relation
95 | if parse_dict[span][3] == gold_dict[span][3]:
96 | num_corr_ftype[i] += 1
97 | # both nuclear and coarse relation
98 | if parse_dict[span][1] == gold_dict[span][1] and parse_dict[span][2] == gold_dict[span][2]:
99 | num_corr_cfull[i] += 1
100 | # both nuclear and fine relation
101 | if parse_dict[span][1] == gold_dict[span][1] and parse_dict[span][3] == gold_dict[span][3]:
102 | num_corr_ffull[i] += 1
103 |
104 | span_score = f1_score(num_corr_span, num_gold, num_parse, average=average)
105 | nuc_score = f1_score(num_corr_nuc, num_gold, num_parse, average=average)
106 | ctype_score = f1_score(num_corr_ctype, num_gold, num_parse, average=average)
107 | ftype_score = f1_score(num_corr_ftype, num_gold, num_parse, average=average)
108 | cfull_score = f1_score(num_corr_cfull, num_gold, num_parse, average=average)
109 | ffull_score = f1_score(num_corr_ffull, num_gold, num_parse, average=average)
110 |
111 | return span_score, nuc_score, ctype_score, ftype_score, cfull_score, ffull_score
112 |
113 |
114 | def gen_parse_report(span_score, nuc_score, ctype_score, ftype_score, cfull_score, ffull_score):
115 |     report = '\n'
116 |     report += ' precision recall f1\n'
117 |     report += '---------------------------------------------\n'
118 |     report += 'span %5.3f %5.3f %5.3f\n' % span_score
119 |     report += 'nuclear %5.3f %5.3f %5.3f\n' % nuc_score
120 |     report += 'ctype %5.3f %5.3f %5.3f\n' % ctype_score
121 |     report += 'cfull %5.3f %5.3f %5.3f\n' % cfull_score
122 |     report += 'ftype %5.3f %5.3f %5.3f\n' % ftype_score
123 |     report += 'ffull %5.3f %5.3f %5.3f\n' % ffull_score
124 |     report += '\n'
125 |     return report
126 |
127 |
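# A minimal end-to-end sketch (hypothetical variable names; `parses` and `golds`
# are assumed to be parallel lists of discourse trees):
#
#     scores = rst_parse_eval(parses, golds)
#     print(gen_parse_report(*scores))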
128 | def nuclear_eval(parses, golds, strict=True, binarize=True):
129 |     if len(parses) != len(golds):
130 |         raise ValueError("number of parsed trees must equal the number of gold standard trees!")
131 |
132 |     num_nuc_parse = defaultdict(int)
133 |     num_nuc_gold = defaultdict(int)
134 |     num_nuc_corr = defaultdict(int)
135 |
136 |     for parse, gold in zip(parses, golds):
137 |         parse_quads = factorize_tree(parse, strict, binarize)
138 |         gold_quads = factorize_tree(gold, strict, binarize)
139 |         parse_dict = {quad[0]: quad for quad in parse_quads}
140 |         gold_dict = {quad[0]: quad for quad in gold_quads}
141 |
142 |         parse_spans = set(parse_dict.keys())
143 |         gold_spans = set(gold_dict.keys())
144 |         corr_spans = gold_spans & parse_spans
145 |
146 |         for quad in parse_quads:
147 |             num_nuc_parse[quad[1]] += 1
148 |         for quad in gold_quads:
149 |             num_nuc_gold[quad[1]] += 1
150 |         for span in corr_spans:
151 |             if parse_dict[span][1] == gold_dict[span][1]:
152 |                 num_nuc_corr[parse_dict[span][1]] += 1
153 |
154 |     scores = []
155 |     for nuc_type in set(chain(num_nuc_parse.keys(), num_nuc_gold.keys(), num_nuc_corr.keys())):
156 |         corr = num_nuc_corr[nuc_type]
157 |         pred = num_nuc_parse[nuc_type]
158 |         gold = num_nuc_gold[nuc_type]
159 |         precision = corr / pred if pred > 0 else 0
160 |         recall = corr / gold if gold > 0 else 0
161 |         if precision + recall == 0:
162 |             f1 = 0.
163 |         else:
164 |             f1 = 2. * precision * recall / (precision + recall)
165 |         scores.append((nuc_type, gold, pred, corr, precision, recall, f1))
166 |     return scores
167 |
168 |
169 | def relation_eval(parses, golds, strict=True, binarize=True):
170 |     if len(parses) != len(golds):
171 |         raise ValueError("number of parsed trees must equal the number of gold standard trees!")
172 |
173 |     num_crel_parse = defaultdict(int)
174 |     num_crel_gold = defaultdict(int)
175 |     num_crel_corr = defaultdict(int)
176 |     num_frel_parse = defaultdict(int)
177 |     num_frel_gold = defaultdict(int)
178 |     num_frel_corr = defaultdict(int)
179 |
180 |     for parse, gold in zip(parses, golds):
181 |         parse_quads = factorize_tree(parse, strict, binarize)
182 |         gold_quads = factorize_tree(gold, strict, binarize)
183 |         parse_dict = {quad[0]: quad for quad in parse_quads}
184 |         gold_dict = {quad[0]: quad for quad in gold_quads}
185 |
186 |         parse_spans = set(parse_dict.keys())
187 |         gold_spans = set(gold_dict.keys())
188 |         corr_spans = gold_spans & parse_spans
189 |
190 |         for quad in parse_quads:
191 |             num_crel_parse[quad[2]] += 1
192 |             num_frel_parse[quad[3]] += 1
193 |         for quad in gold_quads:
194 |             num_crel_gold[quad[2]] += 1
195 |             num_frel_gold[quad[3]] += 1
196 |         for span in corr_spans:
197 |             if parse_dict[span][2] == gold_dict[span][2]:
198 |                 num_crel_corr[parse_dict[span][2]] += 1
199 |             if parse_dict[span][3] == gold_dict[span][3]:
200 |                 num_frel_corr[parse_dict[span][3]] += 1
201 |
202 |     crel_scores = []
203 |     frel_scores = []
204 |     for rel_type in set(chain(num_crel_parse.keys(), num_crel_gold.keys(), num_crel_corr.keys())):
205 |         corr = num_crel_corr[rel_type]
206 |         pred = num_crel_parse[rel_type]
207 |         gold = num_crel_gold[rel_type]
208 |         precision = corr / pred if pred > 0 else 0
209 |         recall = corr / gold if gold > 0 else 0
210 |         if precision + recall == 0:
211 |             f1 = 0.
212 |         else:
213 |             f1 = 2. * precision * recall / (precision + recall)
214 |         crel_scores.append((rel_type, gold, pred, corr, precision, recall, f1))
215 |     for rel_type in set(chain(num_frel_parse.keys(), num_frel_gold.keys(), num_frel_corr.keys())):
216 |         corr = num_frel_corr[rel_type]
217 |         pred = num_frel_parse[rel_type]
218 |         gold = num_frel_gold[rel_type]
219 |         precision = corr / pred if pred > 0 else 0
220 |         recall = corr / gold if gold > 0 else 0
221 |         if precision + recall == 0:
222 |             f1 = 0.
223 |         else:
224 |             f1 = 2. * precision * recall / (precision + recall)
225 |         frel_scores.append((rel_type, gold, pred, corr, precision, recall, f1))
226 |     return crel_scores, frel_scores
227 |
228 |
229 | def gen_category_report(scores):
230 |     report = '\n'
231 |     report += 'type gold pred corr precision recall f1\n'
232 |     report += '--------------------------------------------------------------------\n'
233 |     for score in scores:
234 |         report += '%-4s %5d %5d %5d %5.3f %5.3f %5.3f\n' % score
235 |     return report
236 |
237 |
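# Example wiring for the per-category breakdowns above (hypothetical variable
# names, reusing the same parse/gold tree lists):
#
#     nuc_scores = nuclear_eval(parses, golds)
#     crel_scores, frel_scores = relation_eval(parses, golds)
#     print(gen_category_report(nuc_scores))
#     print(gen_category_report(crel_scores))
#     print(gen_category_report(frel_scores))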
238 | def factorize_tree(tree, strict=False, binarize=True):
239 |     quads = set()  # (span, nuclear, coarse relation, fine relation)
240 |
241 |     def factorize(root, offset=0):
242 |         if isinstance(root, EDU):
243 |             return [(offset, offset+len(root.text))]
244 |         elif isinstance(root, Sentence):
245 |             children_spans = []
246 |             for child in root:
247 |                 spans = factorize(child, offset)
248 |                 children_spans.extend(spans)
249 |                 offset = spans[-1][1]
250 |             return children_spans
251 |         elif isinstance(root, Relation):
252 |             children_spans = []
253 |             for child in root:
254 |                 spans = factorize(child, offset)
255 |                 children_spans.extend(spans)
256 |                 offset = spans[-1][1]
257 |             if binarize:
258 |                 while len(children_spans) >= 2:
259 |                     right = children_spans.pop()
260 |                     left = children_spans.pop()
261 |                     if strict:
262 |                         span = left, right
263 |                     else:
264 |                         span = left[0], right[1]
265 |                     quads.add((span, root.nuclear, root.ctype, root.ftype))
266 |                     children_spans.append((left[0], right[1]))
267 |             else:
268 |                 if strict:
269 |                     span = tuple(children_spans)  # tuple so the span is hashable
270 |                 else:
271 |                     span = children_spans[0][0], children_spans[-1][1]
272 |                 quads.add((span, root.nuclear, root.ctype, root.ftype))
273 |             return [(children_spans[0][0], children_spans[-1][1])]
274 |
275 |     factorize(tree.root_relation())
276 |     return quads
277 |
278 |
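# Each quad produced by factorize_tree above has the shape
#     ((start_offset, end_offset), nuclear, coarse_relation, fine_relation)
# in non-strict mode, where the offsets are character positions relative to the
# start of the tree; for example, a relation covering two EDUs of lengths 5 and 4
# starting at offset 0 yields a quad whose span is (0, 9). In strict mode the first
# element records the child spans themselves rather than the merged boundary.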
279 | def height_eval(parses, golds, strict=True, binarize=True):
280 |     gold_heights = []
281 |     corr_heights = []
282 |
283 |     for parse, gold in zip(parses, golds):
284 |         parse_quads = factorize_tree(parse, strict, binarize)
285 |         gold_quads = factorize_tree(gold, strict, binarize)
286 |         parse_dict = {quad[0]: quad for quad in parse_quads}
287 |         gold_dict = {quad[0]: quad for quad in gold_quads}
288 |
289 |         parse_spans = set(parse_dict.keys())
290 |         gold_spans = set(gold_dict.keys())
291 |         corr_spans = gold_spans & parse_spans
292 |
293 |         span_heights = factorize_span_height(gold, strict=strict, binarize=binarize)
294 |         for span, height in span_heights.items():
295 |             gold_heights.append(height)
296 |             if span in corr_spans:
297 |                 corr_heights.append(height)
298 |
299 |     gold_count = Counter(gold_heights)
300 |     corr_count = Counter(corr_heights)
301 |     height_scores = []
302 |     for height in sorted(gold_count.keys()):
303 |         recall = float(corr_count[height]) / (float(gold_count[height]) + 1e-8)
304 |         height_scores.append((height, gold_count[height], corr_count[height], recall))
305 |     return height_scores
306 |
307 |
308 | def gen_height_report(scores):
309 |     report = '\n'
310 |     report += 'height gold corr recall\n'
311 |     report += '------------------------------------------\n'
312 |     for score in scores:
313 |         report += '%-4s %5d %5d %5.3f\n' % score
314 |     return report
315 |
316 |
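# Usage sketch (hypothetical variable names): a per-height breakdown of how many
# gold spans at each height of the gold tree were recovered by the parser.
#
#     height_scores = height_eval(parses, golds)
#     print(gen_height_report(height_scores))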
317 | def factorize_span_height(tree, strict=False, binarize=True):
318 |     span_height = {}
319 |
320 |     def factorize(root, offset=0):
321 |         if isinstance(root, EDU):
322 |             return 0, [(offset, offset + len(root.text))]  # height, child_spans
323 |         elif isinstance(root, Sentence):
324 |             children_spans = []
325 |             max_height = 0
326 |             for child in root:
327 |                 height, spans = factorize(child, offset)
328 |                 children_spans.extend(spans)
329 |                 offset = spans[-1][1]
330 |                 max_height = height if height > max_height else max_height
331 |             return max_height, children_spans
332 |         elif isinstance(root, Relation):
333 |             children_spans = []
334 |             max_height = 0
335 |             for child in root:
336 |                 height, spans = factorize(child, offset)
337 |                 children_spans.extend(spans)
338 |                 offset = spans[-1][1]
339 |                 max_height = height if height > max_height else max_height
340 |             if binarize:
341 |                 while len(children_spans) >= 2:
342 |                     right = children_spans.pop()
343 |                     left = children_spans.pop()
344 |                     if strict:
345 |                         span = left, right
346 |                     else:
347 |                         span = left[0], right[1]
348 |                     max_height += 1
349 |                     span_height[span] = max_height
350 |                     children_spans.append((left[0], right[1]))
351 |             else:
352 |                 if strict:
353 |                     span = tuple(children_spans)  # tuple so the span can be used as a dict key
354 |                 else:
355 |                     span = children_spans[0][0], children_spans[-1][1]
356 |                 max_height += 1
357 |                 span_height[span] = max_height
358 |             return max_height, [(children_spans[0][0], children_spans[-1][1])]
359 |
360 |     factorize(tree.root_relation())
361 |     return span_height
362 |
363 |
364 | def f1_score(num_corr, num_gold, num_pred, average="micro"):
365 |     if average == "micro":
366 |         precision = np.nan_to_num(num_corr.sum() / num_pred.sum())
367 |         recall = np.nan_to_num(num_corr.sum() / num_gold.sum())
368 |     elif average == "macro":
369 |         precision = (np.nan_to_num(num_corr / (num_pred + 1e-10))).mean()
370 |         recall = (np.nan_to_num(num_corr / (num_gold + 1e-10))).mean()
371 |     else:
372 |         raise ValueError("unsupported average mode '%s'" % average)
373 |     if precision + recall == 0:
374 |         f1 = 0.
375 |     else:
376 |         f1 = 2. * precision * recall / (precision + recall)
377 |     return precision, recall, f1
378 |
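# Worked example of the harmonic mean used throughout this module: with
# precision = 0.8 and recall = 0.5, f1 = 2 * 0.8 * 0.5 / (0.8 + 0.5) ≈ 0.615.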
--------------------------------------------------------------------------------
/util/ltp.py:
--------------------------------------------------------------------------------
1 | # coding: UTF-8
2 |
3 | import os
4 | import pyltp
5 |
6 | LTP_DATA_DIR = 'pub/pyltp_models'
7 | path_to_tagger = os.path.join(LTP_DATA_DIR, 'pos.model')
8 | path_to_parser = os.path.join(LTP_DATA_DIR, "parser.model")
9 |
10 |
11 | class LTPParser:
12 |     def __init__(self):
13 |         self.tagger = pyltp.Postagger()
14 |         self.parser = pyltp.Parser()
15 |         self.tagger.load(path_to_tagger)
16 |         self.parser.load(path_to_parser)
17 |
18 |     def __enter__(self):
19 |         # the models are already loaded in __init__,
20 |         # so avoid loading them a second time here
21 |         return self
22 |
23 |     def __exit__(self, exc_type, exc_val, exc_tb):
24 |         self.tagger.release()
25 |         self.parser.release()
26 |
27 |     def parse(self, words):
28 |         tags = self.tagger.postag(words)
29 |         parse = self.parser.parse(words, tags)
30 |         return parse
31 |
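# A minimal usage sketch (assumes the pyltp POS and dependency parser models are
# available under pub/pyltp_models; the sentence is an illustrative, pre-tokenized
# word list):
#
#     with LTPParser() as ltp:
#         arcs = ltp.parse(["他", "喜欢", "自然", "语言", "处理"])
#         for arc in arcs:
#             print(arc.head, arc.relation)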
--------------------------------------------------------------------------------