├── .flake8
├── .gitignore
├── .gitmodules
├── README.md
├── configs
│   └── roberta.yaml
├── ctc
│   ├── __init__.py
│   ├── metric.py
│   ├── model.py
│   ├── parser.py
│   ├── struct.py
│   └── transform.py
├── data
│   └── clang8.toy
├── pred.sh
├── recover.py
├── run.py
├── supar
├── tools
│   └── m2scorer
│       ├── LICENSE
│       ├── README
│       ├── example
│       │   ├── README
│       │   ├── source_gold
│       │   ├── system
│       │   └── system2
│       ├── m2scorer
│       └── scripts
│           ├── Tokenizer.py
│           ├── combiner.py
│           ├── levenshtein.py
│           ├── m2scorer.py
│           ├── nuclesgmlparser.py
│           ├── token_offsets.py
│           └── util.py
└── train.sh
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 127
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # data files
2 | data
3 |
4 | # bash scripts
5 | *.sh
6 |
7 | # docs
8 | docs/_build
9 |
10 | # intermediate files
11 | build
12 | dist
13 | *.egg-info
14 | *.pyc
15 |
16 | # experimental results
17 | exp
18 | results
19 | wandb
20 |
21 | # log and config files
22 | log.*
23 | *.log
24 | *.cfg
25 | *.ini
26 | *.yml
27 | *.yaml
28 |
29 | # pycache
30 | __pycache__
31 |
32 | # saved model
33 | *.pkl
34 | *.pt
35 |
36 | # hidden files
37 | .*
38 |
39 | # vscode
40 | .vscode
41 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "3rdparty/parser"]
2 |     path = 3rdparty/parser
3 |     url = https://github.com/yzhangcs/parser
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Non-autoregressive Text Editing with Copy-aware Latent Alignments
4 |
5 |
11 | 1 Soochow University, Suzhou, China
12 | 2 Tencent AI Lab
13 |
14 |
15 |
16 |
17 | [Paper (PDF)](https://yzhang.site/assets/pubs/emnlp/2023/ctc.pdf)
18 | [arXiv](https://arxiv.org/abs/2310.07821)
19 | [Semantic Scholar](https://www.semanticscholar.org/paper/Non-autoregressive-Text-Editing-with-Copy-aware-Zhang-Zhang/116277fd27c97d50bba2d8023d3c590c1ea8187b)
20 | 
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 | ## Citation
30 |
31 | If you are interested in our work, please cite:
32 | ```bib
33 | @inproceedings{zhang-etal-2023-ctc,
34 | title = {Non-autoregressive Text Editing with Copy-aware Latent Alignments},
35 | author = {Zhang, Yu and
36 | Zhang, Yue and
37 | Cui, Leyang and
38 | Fu, Guohong},
39 | booktitle = {Proceedings of EMNLP},
40 | year = {2023},
41 | address = {Singapore}
42 | }
43 | ```
44 |
45 | ## Setup
46 |
47 | The following packages should be installed:
48 | * [`PyTorch`](https://github.com/pytorch/pytorch): >= 2.0
49 | * [`Transformers`](https://github.com/huggingface/transformers)
50 | * [`Errant`](https://github.com/chrisjbryant/errant)
51 |
52 | Clone this repo recursively:
53 | ```sh
54 | git clone https://github.com/yzhangcs/ctc-copy.git --recursive
55 | ```
56 |
57 | You can follow this [repo](https://github.com/HillZhang1999/SynGEC) to obtain the 3-stage train/dev/test data for training an English GEC model.
58 | The multilingual datasets are available [here](https://github.com/google-research-datasets/clang8).
59 |
60 | Before running, preprocess each sentence pair into the format `SRC:\t[src]\nTGT:\t[tgt]\n`, where `src` and `tgt` are the source and target sentences, respectively. Consecutive sentence pairs are separated by a blank line.
61 | See [`data/clang8.toy`](data/clang8.toy) for examples.
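
A minimal sketch of this preprocessing step, assuming the sentence pairs are already available as Python tuples. The helper name `format_pairs` and the toy sentences are hypothetical and not part of this repo:

```python
from typing import Iterable, Tuple


def format_pairs(pairs: Iterable[Tuple[str, str]]) -> str:
    """Render (src, tgt) pairs as SRC/TGT blocks separated by blank lines."""
    # Each block is "SRC:<tab>src<newline>TGT:<tab>tgt<newline>".
    blocks = [f"SRC:\t{src.strip()}\nTGT:\t{tgt.strip()}\n" for src, tgt in pairs]
    # Joining the blocks with a newline leaves one blank line between pairs.
    return "\n".join(blocks)


if __name__ == "__main__":
    toy = [("She go to school .", "She goes to school ."),
           ("I has a pen .", "I have a pen .")]
    print(format_pairs(toy), end="")
```

The returned string can be written to a file and compared against `data/clang8.toy` to confirm the layout matches.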
62 |
63 | ## Run
64 |
65 | Run the following command to train a 3-stage English model:
66 | ```sh
67 | bash train.sh
68 | ```
69 | To make predictions and evaluate them:
70 | ```sh
71 | bash pred.sh
72 | ```
73 |
74 | ## Contact
75 |
76 | If you have any questions, please feel free to [email](mailto:yzhang.cs@outlook.com) me.
77 |
--------------------------------------------------------------------------------
/configs/roberta.yaml:
--------------------------------------------------------------------------------
1 | encoder: bert
2 | bert: roberta-large
3 | upsampling: 4
4 | beam_size: 12
5 | dropout: .1
6 | token_dropout: .1
7 | n_decoder_layers: 2
8 | find_unused_parameters: 0
9 | topk: 1
10 | label_smoothing: 0
11 | lr: 5e-05
12 | lr_rate: 10
13 | mu: .9
14 | nu: .9
15 | eps: 1e-12
16 | weight_decay: .01
17 | clip: 5.0
18 | min_freq: 2
19 | fix_len: 20
20 | epochs: 64
21 | patience: 10
22 | batch_size: 100000
23 | eval_batch_size: 10000
24 | warmup_steps: 1000
25 | update_steps: 25
26 | max_len: 64
27 |
--------------------------------------------------------------------------------
/ctc/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .parser import CTCParser
4 |
5 | __all__ = ['CTCParser']
6 |
--------------------------------------------------------------------------------
/ctc/metric.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import annotations
4 |
5 | import math
6 | import os
7 | import tempfile
8 | from collections import Counter
9 | from typing import Any, List, Optional, Set, Tuple
10 |
11 | import torch
12 | from errant import Annotator
13 |
14 | from supar.structs.fn import levenshtein
15 | from supar.utils.metric import Metric
16 |
17 |
18 | class PerplexityMetric(Metric):
19 |
20 |     def __init__(
21 |         self,
22 |         loss: Optional[float] = None,
23 |         preds: Optional[Tuple[torch.Tensor, List, List]] = None,
24 |         golds: Optional[Tuple[torch.Tensor, List, List]] = None,
25 |         mask: Optional[torch.BoolTensor] = None,
26 |         annotator: Annotator = None,
27 |         reverse: bool = False,
28 |         eps: float = 1e-12
29 |     ) -> PerplexityMetric:
30 |         super().__init__(reverse=reverse, eps=eps)
31 |
32 |         self.n_tokens = 0.
33 |
34 |         self.tp = 0.0
35 |         self.pred = 0.0
36 |         self.gold = 0.0
37 |         self.total_loss = 0.
38 |
39 |         if loss is not None:
40 |             self(loss, preds, golds, mask, annotator)
41 |
42 |     def __repr__(self):
43 |         s = f"loss: {self.loss:.4f} PPL: {self.ppl:.4f}"
44 |         if self.tp > 0:
45 |             s += f" - TGT: P: {self.p:6.2%} R: {self.r:6.2%} F0.5: {self.f:6.2%}"
46 |         return s
47 |
48 |     def __call__(
49 |         self,
50 |         loss: float,
51 |         preds: Tuple[torch.Tensor, List, List],
52 |         golds: Tuple[torch.Tensor, List, List],
53 |         mask: torch.BoolTensor,
54 |         annotator: Any
55 |     ) -> PerplexityMetric:
56 |         n_tokens = mask.sum().item()
57 |         self.n += len(mask)
58 |         self.count += 1
59 |         self.n_tokens += n_tokens
60 |         self.total_loss += float(loss) * n_tokens
61 |
62 |         if preds is not None:
63 |             if annotator is not None:
64 |                 with tempfile.TemporaryDirectory() as t:
65 |                     fsrc, fpred, fgold = os.path.join(t, 'src'), os.path.join(t, 'pred'), os.path.join(t, 'gold')
66 |                     pred_m2, gold_m2 = os.path.join(t, 'pred.m2'), os.path.join(t, 'gold.m2')
67 |                     with open(fsrc, 'w') as fs, open(fpred, 'w') as f:
68 |                         for s, i, *_ in preds:
69 |                             fs.write(s + '\n')
70 |                             f.write(i + '\n')
71 |                     with open(fgold, 'w') as f:
72 |                         for _, i, *_ in golds:
73 |                             f.write(i + '\n')
74 |                     self.errant_parallel(fsrc, fpred, pred_m2, annotator)
75 |                     self.errant_parallel(fsrc, fgold, gold_m2, annotator)
76 |                     out = self.errant_compare(pred_m2, gold_m2)
77 |                     tp, fp, fn = out['tp'], out['fp'], out['fn']
78 |                     self.tp += tp
79 |                     self.pred += tp + fp
80 |                     self.gold += tp + fn
81 |             else:
82 |                 for p, g in zip(preds, golds):
83 |                     e_p = self.compare(p[2], p[3])
84 |                     e_g = self.compare(g[2], g[3])
85 |                     self.tp += len(e_p & e_g)
86 |                     self.pred += len(e_p)
87 |                     self.gold += len(e_g)
88 |         return self
89 |
90 |     def __add__(self, other: PerplexityMetric) -> PerplexityMetric:
91 |         metric = PerplexityMetric(eps=self.eps)
92 |         metric.n = self.n + other.n
93 |         metric.count = self.count + other.count
94 |         metric.n_tokens = self.n_tokens + other.n_tokens
95 |         metric.total_loss = self.total_loss + other.total_loss
96 |
97 |         metric.tp = self.tp + other.tp
98 |         metric.pred = self.pred + other.pred
99 |         metric.gold = self.gold + other.gold
100 |         metric.reverse = self.reverse or other.reverse
101 |         return metric
102 |
103 |     @property
104 |     def score(self):
105 |         return self.f
106 |
107 |     @property
108 |     def loss(self):
109 |         return self.total_loss / self.n_tokens
110 |
111 |     @property
112 |     def ppl(self):
113 |         return math.pow(2, (self.loss / math.log(2)))
114 |
115 |     @property
116 |     def p(self):
117 |         return self.tp / (self.pred + self.eps)
118 |
119 |     @property
120 |     def r(self):
121 |         return self.tp / (self.gold + self.eps)
122 |
123 |     @property
124 |     def f(self):
125 |         return (1 + 0.5**2) * self.p * self.r / (0.5**2 * self.p + self.r + self.eps)
126 |
127 |     @property
128 |     def values(self):
129 |         return {'P': self.p,
130 |                 'R': self.r,
131 |                 'F0.5': self.f}
132 |
133 |     def compare(self, s, t) -> Set:
134 |         return {(i, edit) for i, _, edit in levenshtein(s, t, align=True)[1] if edit != 0}
135 |
136 | def errant_parallel(self, forig: str, fcor: str, fout: str, annotator: Any) -> None:
137 | from contextlib import ExitStack
138 |
139 | def noop_edit(id=0):
140 | return "A -1 -1|||noop|||-NONE-|||REQUIRED|||-NONE-|||"+str(id)
141 | with ExitStack() as stack, open(fout, "w") as out_m2:
142 | in_files = [stack.enter_context(open(i)) for i in [forig]+[fcor]]
143 | # Process each line of all input files
144 | for line in zip(*in_files):
145 | # Get the original and all the corrected texts
146 | orig = line[0].strip()
147 | cors = line[1:]
148 | # Skip the line if orig is empty
149 | if not orig:
150 | continue
151 | # Parse orig with spacy
152 | orig = annotator.parse(orig)
153 | # Write orig to the output m2 file
154 | out_m2.write(" ".join(["S"]+[token.text for token in orig])+"\n")
155 | # Loop through the corrected texts
156 | for cor_id, cor in enumerate(cors):
157 | cor = cor.strip()
158 | # If the texts are the same, write a noop edit
159 | if orig.text.strip() == cor:
160 | out_m2.write(noop_edit(cor_id)+"\n")
161 | # Otherwise, do extra processing
162 | else:
163 | # Parse cor with spacy
164 | cor = annotator.parse(cor)
165 | # Align the texts and extract and classify the edits
166 | edits = annotator.annotate(orig, cor)
167 | # Loop through the edits
168 | for edit in edits:
169 | # Write the edit to the output m2 file
170 | out_m2.write(edit.to_m2(cor_id)+"\n")
171 | # Write a newline when we have processed all corrections for each line
172 | out_m2.write("\n")
173 |
174 | def errant_compare(self, fhyp: str, fref: str):
175 | from argparse import Namespace
176 |
177 | # Input: An m2 format sentence with edits.
178 | # Output: A list of lists. Each edit: [start, end, cat, cor, coder]
179 |
180 | def simplify_edits(sent):
181 | out_edits = []
182 | # Get the edit lines from an m2 block.
183 | edits = sent.split("\n")[1:]
184 | # Loop through the edits
185 | for edit in edits:
186 | # Preprocessing
187 | edit = edit[2:].split("|||") # Ignore "A " then split.
188 | span = edit[0].split()
189 | start = int(span[0])
190 | end = int(span[1])
191 | cat = edit[1]
192 | cor = edit[2]
193 | coder = int(edit[-1])
194 | out_edit = [start, end, cat, cor, coder]
195 | out_edits.append(out_edit)
196 | return out_edits
197 |
198 | # Input 1: A list of edits. Each edit: [start, end, cat, cor, coder]
199 | # Output: A dict; key is coder, value is edit dict.
200 | def process_edits(edits, args):
201 | coder_dict = {}
202 | # Add an explicit noop edit if there are no edits.
203 | if not edits:
204 | edits = [[-1, -1, "noop", "-NONE-", 0]]
205 | # Loop through the edits
206 | for edit in edits:
207 | # Name the edit elements for clarity
208 | start = edit[0]
209 | end = edit[1]
210 | cat = edit[2]
211 | cor = edit[3]
212 | coder = edit[4]
213 | # Add the coder to the coder_dict if necessary
214 | if coder not in coder_dict:
215 | coder_dict[coder] = {}
216 |
217 | # Optionally apply filters based on args
218 | # 1. UNK type edits are only useful for detection, not correction.
219 | if not args.dt and not args.ds and cat == "UNK":
220 | continue
221 | # 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1
222 | if args.single and (end-start >= 2 or len(cor.split()) >= 2):
223 | continue
224 | # 3. Only evaluate multi token edits; i.e. 2+:n or n:2+
225 | if args.multi and end-start < 2 and len(cor.split()) < 2:
226 | continue
227 | # 4. If there is a filter, ignore the specified error types
228 | if args.filt and cat in args.filt:
229 | continue
230 |
231 | # Token Based Detection
232 | if args.dt:
233 | # Preserve noop edits.
234 | if start == -1:
235 | if (start, start) in coder_dict[coder].keys():
236 | coder_dict[coder][(start, start)].append(cat)
237 | else:
238 | coder_dict[coder][(start, start)] = [cat]
239 | # Insertions defined as affecting the token on the right
240 | elif start == end and start >= 0:
241 | if (start, start+1) in coder_dict[coder].keys():
242 | coder_dict[coder][(start, start+1)].append(cat)
243 | else:
244 | coder_dict[coder][(start, start+1)] = [cat]
245 | # Edit spans are split for each token in the range.
246 | else:
247 | for tok_id in range(start, end):
248 | if (tok_id, tok_id+1) in coder_dict[coder].keys():
249 | coder_dict[coder][(tok_id, tok_id+1)].append(cat)
250 | else:
251 | coder_dict[coder][(tok_id, tok_id+1)] = [cat]
252 |
253 | # Span Based Detection
254 | elif args.ds:
255 | if (start, end) in coder_dict[coder].keys():
256 | coder_dict[coder][(start, end)].append(cat)
257 | else:
258 | coder_dict[coder][(start, end)] = [cat]
259 |
260 | # Span Based Correction
261 | else:
262 | # With error type classification
263 | if args.cse:
264 | if (start, end, cat, cor) in coder_dict[coder].keys():
265 | coder_dict[coder][(start, end, cat, cor)].append(cat)
266 | else:
267 | coder_dict[coder][(start, end, cat, cor)] = [cat]
268 | # Without error type classification
269 | else:
270 | if (start, end, cor) in coder_dict[coder].keys():
271 | coder_dict[coder][(start, end, cor)].append(cat)
272 | else:
273 | coder_dict[coder][(start, end, cor)] = [cat]
274 | return coder_dict
275 |
276 | # Input 1-3: True positives, false positives, false negatives
277 | # Input 4: Value of beta in F-score.
278 | # Output 1-3: Precision, Recall and F-score rounded to 4dp.
279 |
280 | def computeFScore(tp, fp, fn, beta):
281 | p = float(tp)/(tp+fp) if fp else 1.0
282 | r = float(tp)/(tp+fn) if fn else 1.0
283 | f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0
284 | return round(p, 4), round(r, 4), round(f, 4)
285 | # Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits.
286 | # Input 2: A ref dict; key is coder_id, value is dict of processed ref edits.
287 | # Input 3: A dictionary of the best corpus level TP, FP and FN counts so far.
288 | # Input 4: Sentence ID (for verbose output only)
289 | # Input 5: Command line args
290 | # Output 1: A dict of the best corpus level TP, FP and FN for the input sentence.
291 | # Output 2: The corresponding error type dict for the above dict.
292 |
293 | # Input 1: A dictionary of hypothesis edits for a single system.
294 | # Input 2: A dictionary of reference edits for a single annotator.
295 | # Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator.
296 | # Output 4: A dictionary of the error type counts.
297 | def compareEdits(hyp_edits, ref_edits):
298 | tp = 0 # True Positives
299 | fp = 0 # False Positives
300 | fn = 0 # False Negatives
301 | cat_dict = {} # {cat: [tp, fp, fn], ...}
302 |
303 | for h_edit, h_cats in hyp_edits.items():
304 | # noop hyp edits cannot be TP or FP
305 | if h_cats[0] == "noop":
306 | continue
307 | # TRUE POSITIVES
308 | if h_edit in ref_edits.keys():
309 | # On occasion, multiple tokens at same span.
310 | for h_cat in ref_edits[h_edit]: # Use ref dict for TP
311 | tp += 1
312 | # Each dict value [TP, FP, FN]
313 | if h_cat in cat_dict.keys():
314 | cat_dict[h_cat][0] += 1
315 | else:
316 | cat_dict[h_cat] = [1, 0, 0]
317 | # FALSE POSITIVES
318 | else:
319 | # On occasion, multiple tokens at same span.
320 | for h_cat in h_cats:
321 | fp += 1
322 | # Each dict value [TP, FP, FN]
323 | if h_cat in cat_dict.keys():
324 | cat_dict[h_cat][1] += 1
325 | else:
326 | cat_dict[h_cat] = [0, 1, 0]
327 | for r_edit, r_cats in ref_edits.items():
328 | # noop ref edits cannot be FN
329 | if r_cats[0] == "noop":
330 | continue
331 | # FALSE NEGATIVES
332 | if r_edit not in hyp_edits.keys():
333 | # On occasion, multiple tokens at same span.
334 | for r_cat in r_cats:
335 | fn += 1
336 | # Each dict value [TP, FP, FN]
337 | if r_cat in cat_dict.keys():
338 | cat_dict[r_cat][2] += 1
339 | else:
340 | cat_dict[r_cat] = [0, 0, 1]
341 | return tp, fp, fn, cat_dict
342 |
343 | def evaluate_edits(hyp_dict, ref_dict, best, sent_id, original_sentence, args):
344 | # Store the best sentence level scores and hyp+ref combination IDs
345 | # best_f is initialised as -1 cause 0 is a valid result.
346 | best_tp, best_fp, best_fn, best_f, _, _ = 0, 0, 0, -1, 0, 0
347 | best_cat = {}
348 | # Compare each hyp and ref combination
349 | for hyp_id in hyp_dict.keys():
350 | for ref_id in ref_dict.keys():
351 | # Get the local counts for the current combination.
352 | tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id])
353 | # Compute the local sentence scores (for verbose output only)
354 | loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta)
355 | # Compute the global sentence scores
356 | p, r, f = computeFScore(
357 | tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta)
358 | # Save the scores if they are better in terms of:
359 | # 1. Higher F-score
360 | # 2. Same F-score, higher TP
361 | # 3. Same F-score and TP, lower FP
362 | # 4. Same F-score, TP and FP, lower FN
363 | if (f > best_f) or \
364 | (f == best_f and tp > best_tp) or \
365 | (f == best_f and tp == best_tp and fp < best_fp) or \
366 | (f == best_f and tp == best_tp and fp == best_fp and fn < best_fn):
367 | best_tp, best_fp, best_fn = tp, fp, fn
368 | best_f, _, _ = f, hyp_id, ref_id
369 | best_cat = cat_dict
370 | # Save the best TP, FP and FNs as a dict, and return this and the best_cat dict
371 | best_dict = {"tp": best_tp, "fp": best_fp, "fn": best_fn}
372 | return best_dict, best_cat
373 |
374 | def merge_dict(dict1, dict2):
375 | for cat, stats in dict2.items():
376 | if cat in dict1.keys():
377 | dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)]
378 | else:
379 | dict1[cat] = stats
380 | return dict1
381 | args = Namespace(beta=0.5,
382 | dt=False,
383 | ds=False,
384 | cs=False,
385 | cse=False,
386 | single=False,
387 | multi=False,
388 | filt=[],
389 | cat=1)
390 | # Open hypothesis and reference m2 files and split into chunks
391 | with open(fhyp) as fhyp, open(fref) as fref:
392 | hyp_m2 = fhyp.read().strip().split("\n\n")
393 | ref_m2 = fref.read().strip().split("\n\n")
394 | # Make sure they have the same number of sentences
395 | assert len(hyp_m2) == len(ref_m2)
396 |
397 | # Store global corpus level best counts here
398 | best_dict = Counter({"tp": 0, "fp": 0, "fn": 0})
399 | best_cats = {}
400 | # Process each sentence
401 | sents = zip(hyp_m2, ref_m2)
402 | for sent_id, sent in enumerate(sents):
403 | # Simplify the edits into lists of lists
404 | hyp_edits = simplify_edits(sent[0])
405 | ref_edits = simplify_edits(sent[1])
406 | # Process the edits for detection/correction based on args
407 | hyp_dict = process_edits(hyp_edits, args)
408 | ref_dict = process_edits(ref_edits, args)
409 | # original sentence for logging
410 | original_sentence = sent[0][2:].split("\nA")[0]
411 | # Evaluate edits and get best TP, FP, FN hyp+ref combo.
412 | count_dict, cat_dict = evaluate_edits(
413 | hyp_dict, ref_dict, best_dict, sent_id, original_sentence, args)
414 | # Merge these dicts with best_dict and best_cats
415 | best_dict += Counter(count_dict)
416 | best_cats = merge_dict(best_cats, cat_dict)
417 | return best_dict
418 |
419 |
420 | class ExactMatchMetric(Metric):
421 |
422 |     def __init__(
423 |         self,
424 |         loss: Optional[float] = None,
425 |         preds: Optional[Tuple[torch.Tensor, List, List]] = None,
426 |         golds: Optional[Tuple[torch.Tensor, List, List]] = None,
427 |         mask: Optional[torch.BoolTensor] = None,
428 |         reverse: bool = True,
429 |         eps: float = 1e-12
430 |     ) -> ExactMatchMetric:
431 |         super().__init__(reverse=reverse, eps=eps)
432 |
433 |         self.n_tokens = 0.
434 |
435 |         self.tp = 0.0
436 |         self.total = 0.0
437 |         self.total_loss = 0.
438 |
439 |         if loss is not None:
440 |             self(loss, preds, golds, mask)
441 |
442 |     def __repr__(self):
443 |         return f"loss: {self.loss:.4f} EM: {self.em:6.2%}"
444 |
445 |     def __call__(
446 |         self,
447 |         loss: float,
448 |         preds: Tuple[torch.Tensor, List, List],
449 |         golds: Tuple[torch.Tensor, List, List],
450 |         mask: torch.BoolTensor
451 |     ) -> ExactMatchMetric:
452 |         n_tokens = mask.sum().item()
453 |         self.n += len(mask)
454 |         self.count += 1
455 |         self.n_tokens += n_tokens
456 |         self.total_loss += float(loss) * n_tokens
457 |
458 |         if preds is not None:
459 |             self.tp += sum([p[3].equal(g[3]) for p, g in zip(preds, golds)])
460 |             self.total += len(preds)
461 |         return self
462 |
463 |     def __add__(self, other: ExactMatchMetric) -> ExactMatchMetric:
464 |         metric = ExactMatchMetric(eps=self.eps)
465 |         metric.n = self.n + other.n
466 |         metric.count = self.count + other.count
467 |         metric.n_tokens = self.n_tokens + other.n_tokens
468 |         metric.total_loss = self.total_loss + other.total_loss
469 |
470 |         metric.tp = self.tp + other.tp
471 |         metric.total = self.total + other.total
472 |         metric.reverse = self.reverse or other.reverse
473 |         return metric
474 |
475 |     @property
476 |     def score(self):
477 |         return self.em
478 |
479 |     @property
480 |     def loss(self):
481 |         return self.total_loss / self.n_tokens
482 |
483 |     @property
484 |     def em(self):
485 |         return self.tp / (self.total + self.eps)
486 |
487 |     @property
488 |     def values(self):
489 |         return {'EM': self.em}
490 |
--------------------------------------------------------------------------------
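As a worked illustration of two quantities reported by `PerplexityMetric` in `ctc/metric.py`: the `ppl` property evaluates 2^(loss / ln 2), which is the same as exp(loss) for a natural-log loss, and `f` is an F0.5 score smoothed by `eps`. The sketch below re-implements both formulas with the standard library only. The function names `perplexity` and `f_beta` are hypothetical and not part of this repo, and the counts in the usage example are made up for illustration:

```python
import math


def perplexity(loss: float) -> float:
    # Same expression as the `ppl` property; equal to math.exp(loss).
    return math.pow(2, loss / math.log(2))


def f_beta(tp: float, pred: float, gold: float, beta: float = 0.5, eps: float = 1e-12) -> float:
    # Precision and recall with the same eps smoothing as the `p` and `r` properties.
    p = tp / (pred + eps)
    r = tp / (gold + eps)
    # beta < 1 weights precision more heavily than recall.
    return (1 + beta**2) * p * r / (beta**2 * p + r + eps)


if __name__ == "__main__":
    print(perplexity(1.0))
    # 6 correct edits out of 8 predicted and 12 gold edits.
    print(f_beta(tp=6, pred=8, gold=12))
```

With `beta=0.5`, a system with precision 0.75 and recall 0.5 scores higher than one with the two values swapped, which is the usual motivation for F0.5 in grammatical error correction.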
/ctc/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from supar.model import Model
7 | from supar.modules import TokenDropout
8 | from supar.modules.transformer import (TransformerDecoder,
9 | TransformerDecoderLayer)
10 | from supar.config import Config
11 | from supar.utils.common import INF, MIN
12 | from supar.utils.fn import pad
13 |
14 |
15 | class CTCModel(Model):
16 | r"""
17 | The implementation of CTC Parser.
18 |
19 | Args:
20 | n_words (int):
21 | The size of the word vocabulary.
22 | n_tags (int):
23 | The number of POS tags, required if POS tag embeddings are used. Default: ``None``.
24 | n_chars (int):
25 | The number of characters, required if character-level representations are used. Default: ``None``.
26 | n_lemmas (int):
27 | The number of lemmas, required if lemma embeddings are used. Default: ``None``.
28 | encoder (str):
29 | Encoder to use.
30 | ``'lstm'``: BiLSTM encoder.
31 | ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``.
32 | Default: ``'lstm'``.
33 | n_embed (int):
34 | The size of word embeddings. Default: 100.
35 | n_pretrained (int):
36 | The size of pretrained word embeddings. Default: 100.
37 | n_feat_embed (int):
38 | The size of feature representations. Default: 100.
39 | n_char_embed (int):
40 | The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50.
41 | n_char_hidden (int):
42 | The size of hidden states of CharLSTM, required if using CharLSTM. Default: 100.
43 | char_pad_index (int):
44 | The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0.
45 | elmo (str):
46 | Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``.
47 | elmo_bos_eos (tuple[bool]):
48 | A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs.
49 | Default: ``(True, False)``.
50 | bert (str):
51 | Specifies which kind of language model to use, e.g., ``'bert-base-cased'``.
52 | This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_.
53 | Default: ``None``.
54 | n_bert_layers (int):
55 | Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features.
56 | The final outputs would be the weighted sum of the hidden states of these layers.
57 | Default: 4.
58 | mix_dropout (float):
59 | The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0.
60 | bert_pooling (str):
61 | Pooling way to get token embeddings.
62 | ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all.
63 | Default: ``mean``.
64 | bert_pad_index (int):
65 | The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features.
66 | Default: 0.
67 | freeze (bool):
68 | If ``True``, freezes BERT parameters, required if using BERT features. Default: ``True``.
69 | embed_dropout (float):
70 | The dropout ratio of input embeddings. Default: .33.
71 | n_encoder_hidden (int):
72 | The size of encoder hidden states. Default: 512.
73 | n_encoder_layers (int):
74 | The number of LSTM layers. Default: 3.
75 | encoder_dropout (float):
76 | The dropout ratio of encoder layer. Default: .1.
77 | dropout (float):
78 | The dropout ratio applied inside the pretrained encoder and the decoder layers. Default: .1.
79 | pad_index (int):
80 | The index of the padding token in the word vocabulary. Default: 0.
81 | unk_index (int):
82 | The index of the unknown token in the word vocabulary. Default: 1.
83 |
84 | .. _transformers:
85 | https://github.com/huggingface/transformers
86 | """
87 |
88 | def __init__(self,
89 | n_words,
90 | n_tags=None,
91 | n_chars=None,
92 | n_lemmas=None,
93 | encoder='lstm',
94 | n_embed=100,
95 | n_pretrained=100,
96 | n_feat_embed=100,
97 | n_char_embed=50,
98 | n_char_hidden=100,
99 | char_pad_index=0,
100 | char_dropout=0,
101 | elmo='original_5b',
102 | elmo_bos_eos=(True, False),
103 | bert=None,
104 | n_bert_layers=4,
105 | mix_dropout=.0,
106 | bert_pooling='mean',
107 | bert_pad_index=0,
108 | freeze=True,
109 | embed_dropout=.33,
110 | n_encoder_hidden=512,
111 | n_encoder_layers=3,
112 | encoder_dropout=.1,
113 | dropout=.1,
114 | pad_index=0,
115 | unk_index=1,
116 | **kwargs):
117 | super().__init__(**Config().update(locals()))
118 |
119 | from transformers import AutoModel
120 | self.encoder = AutoModel.from_pretrained(self.args.bert,
121 | add_pooling_layer=False,
122 | attention_probs_dropout_prob=self.args.dropout,
123 | hidden_dropout_prob=self.args.dropout)
124 | if self.args.vocab:
125 | self.encoder.resize_token_embeddings(self.args.n_words)
126 | self.token_dropout = TokenDropout(self.args.get('token_dropout', 0))
127 | self.proj = nn.Linear(self.args.n_encoder_hidden, self.args.upsampling * self.args.n_encoder_hidden)
128 | self.decoder = TransformerDecoder(layer=TransformerDecoderLayer(n_model=self.args.n_encoder_hidden,
129 | dropout=self.args.dropout),
130 | n_layers=self.args.n_decoder_layers)
131 | self.classifier = nn.Linear(self.args.n_encoder_hidden, self.args.n_words)
132 |
133 |     def forward(self, words):
134 |         r"""
135 |         Args:
136 |             words (~torch.LongTensor): ``[batch_size, seq_len]``.
137 |                 Word indices.
138 |
139 |         Returns:
140 |             ~torch.Tensor:
141 |                 Representations for the src sentences of the shape ``[batch_size, seq_len, n_model]``.
142 |         """
143 |         x = self.encoder(inputs_embeds=self.token_dropout(self.encoder.embeddings.word_embeddings(words)),
144 |                          attention_mask=words.ne(self.args.pad_index))[0]
145 |         return self.encoder_dropout(x)
146 |
147 |     def resize(self, x):
148 |         batch_size, seq_len, *_, upsampling = x.shape
149 |         resized = x.new_zeros(batch_size, seq_len * upsampling, *_)
150 |         for i, j in enumerate(x.unbind(-1)):
151 |             resized[:, i::upsampling] = j
152 |         return resized
153 |
154 | def loss(self, x, src, tgt, src_mask, tgt_mask, ratio=0):
155 | x_tgt, glat_mask = self.resize(self.proj(x).view(*x.shape, self.args.upsampling)), None
156 | if ratio > 0:
157 | with torch.no_grad(), torch.cuda.amp.autocast(enabled=False):
158 | mask = self.resize(src_mask.unsqueeze(-1).repeat(1, 1, self.args.upsampling))
159 | preds, s_x = self.decode(x, src, src_mask, True)
160 | align = self.align(s_x.log_softmax(2).transpose(0, 1), src, tgt, src_mask, tgt_mask)
161 | probs = ((align.ne(preds) & mask).sum(-1) / mask.sum(-1) * ratio).clamp_(0, 1)
162 | glat_mask = (src.new_zeros(mask.shape) + probs.unsqueeze(-1)).bernoulli().bool()
163 | e_tgt = self.encoder.embeddings(torch.where(align.ge(self.args.n_words-2), self.args.mask_index, align))
164 | x_tgt = torch.where(glat_mask.unsqueeze(-1), e_tgt, x_tgt)
165 | x = self.decoder(x_tgt=x_tgt,
166 | x_src=x,
167 | tgt_mask=self.resize(src_mask.unsqueeze(-1).repeat(1, 1, self.args.upsampling)),
168 | src_mask=src_mask)
169 | # [tgt_len, batch_size, n_words]
170 | s_x = self.classifier(x).log_softmax(2).transpose(0, 1)
171 | return self.ctc(s_x, src, tgt, src_mask, tgt_mask)
172 |
173 |     def ctc(self, s_x, src, tgt, src_mask, tgt_mask, glat_mask=None):
174 |         src = self.resize(src.unsqueeze(-1).repeat(1, 1, self.args.upsampling))
175 |         # [tgt_len, batch_size]
176 |         s_k, s_b = s_x[..., self.args.keep_index], s_x[..., self.args.nul_index]
177 |         # [tgt_len, seq_len, batch_size]
178 |         s_x = s_x.gather(-1, tgt.repeat(s_b.shape[0], 1, 1)).transpose(1, 2)
179 |         s_x = torch.where(src.unsqueeze(-1).eq(tgt.unsqueeze(1)).movedim(0, -1), s_k.unsqueeze(1), s_x)
180 |         if glat_mask is not None:
181 |             glat_mask = glat_mask.t()
182 |             s_b = s_b.masked_fill(glat_mask, 0)
183 |             s_x = s_x.masked_fill(glat_mask.unsqueeze(1), 0)
184 |         src_lens, tgt_lens = src_mask.sum(-1) * self.args.upsampling, tgt_mask.sum(-1)
185 |         tgt_len, seq_len, batch_size = s_x.shape
186 |         # [tgt_len, 2, seq_len + 1, batch_size]
187 |         s = s_x.new_full((tgt_len, 2, seq_len + 1, batch_size), MIN)
188 |         s[0, 0, 0], s[0, 1, 0] = s_b[0], s_x[0, 0]
189 |         for t in range(1, tgt_len):
190 |             s0 = torch.cat((torch.full_like(s[0, 0, :1], MIN), s[t-1, 1, :-1]))
191 |             s1 = s[t-1, 0]
192 |             s2 = s[t-1, 1]
193 |             s[t, 0] = torch.stack((s0, s1)).logsumexp(0) + s_b[t]
194 |             s[t, 1, :-1] = torch.stack((s0[:-1], s1[:-1], s2[:-1])).logsumexp(0) + s_x[t]
195 |         s = s[src_lens - 1, 0, tgt_lens, range(batch_size)].logaddexp(s[src_lens - 1, 1, tgt_lens - 1, range(batch_size)])
196 |         return -s.sum() / tgt_lens.sum()
197 |
198 | def decode(self, x, src, src_mask, score=False):
199 | batch_size, *_ = x.shape
200 | beam_size, n_words = self.args.beam_size, self.args.n_words
201 | keep_index, nul_index, pad_index = self.args.keep_index, self.args.nul_index, self.args.pad_index
202 | indices = src.new_tensor(range(batch_size)).unsqueeze(1).repeat(1, beam_size).view(-1)
203 | x = self.decoder(x_tgt=self.resize(self.proj(x).view(*x.shape, self.args.upsampling)),
204 | x_src=x,
205 | tgt_mask=self.resize(src_mask.unsqueeze(-1).repeat(1, 1, self.args.upsampling)),
206 | src_mask=src_mask)
207 | src = self.resize(src.unsqueeze(-1).repeat(1, 1, self.args.upsampling))
208 | src_mask = self.resize(src_mask.unsqueeze(-1).repeat(1, 1, self.args.upsampling))
209 |
210 | if not self.args.prefix:
211 | s_x = self.classifier(x)
212 | # [batch_size, tgt_len, topk]
213 | tgt = s_x.topk(self.args.topk, -1)[1]
214 | tgt = torch.where(tgt.eq(keep_index), src.unsqueeze(-1), tgt)
215 | # [batch_size, topk, tgt_len]
216 | tgt = tgt.masked_fill_(~src_mask.unsqueeze(2), self.args.pad_index).transpose(1, 2)
217 | if score:
218 | return tgt[:, 0], s_x
219 | # [batch_size, topk, tgt_len]
220 | tgt = [[j.unique_consecutive() for j in i.unbind(0)] for i in tgt.unbind(0)]
221 | tgt = pad([pad([j[j.ne(nul_index)] for j in i], pad_index) for i in tgt], pad_index)
222 | return tgt
223 |
224 | # [batch_size * beam_size, tgt_len, ...]
225 | x, src, src_mask = x[indices], src[indices], src_mask[indices]
226 | # [batch_size * beam_size, max_len]
227 | tgt = x.new_full((batch_size * beam_size, x.shape[1]), nul_index, dtype=torch.long)
228 | lens = tgt.new_full((tgt.shape[0],), 0)
229 | # [batch_size]
230 | batches = tgt.new_tensor(range(batch_size)) * beam_size
231 | # accumulated scores
232 | # [2, batch_size * beam_size]
233 | s = torch.stack((x.new_full((batch_size, beam_size), -INF).index_fill_(-1, tgt.new_tensor(0), 0).view(-1),
234 | x.new_full((batch_size * beam_size,), -INF)))
235 |
236 | def merge(s_b, s_n, tgt, lens, ends):
237 | # merge the prefixes that have grown in the new step
238 | s_n = s_n.view(batch_size, beam_size, -1)
239 | tgt, lens, ends = tgt.view(batch_size, beam_size, -1), lens.view(batch_size, -1), ends.view(batch_size, -1)
240 | # [batch_size, beam_size, beam_size]
241 | mask = tgt.scatter(-1, (lens.clamp(1) - 1).unsqueeze(-1), nul_index).unsqueeze(2).eq(tgt.unsqueeze(1)).all(-1)
242 | mask = mask & lens.gt(0).unsqueeze(2)
243 | s_g = s_n.gather(-1, ends.unsqueeze(2))
244 | s_n[..., nul_index] = s_n[..., nul_index].logaddexp(s_g.transpose(1, 2).masked_fill(~mask, -INF).logsumexp(2))
245 | s_n.scatter_(-1, ends.unsqueeze(2), torch.where(mask.any(2, True), -INF, s_g))
246 | s_n = s_n.view(batch_size * beam_size, -1)
247 | return s_b, s_n
248 |
249 | for t in range(x.shape[1]):
250 | # [batch_size * beam_size]
251 | mask = src_mask[:, t]
252 | # the past prefixes
253 | ends = tgt[range(tgt.shape[0]), lens - 1]
254 | # [batch_size * beam_size, n_words]
255 | s_t = self.classifier(x[:, t]).log_softmax(1)
256 | s_k = s_t.gather(-1, src[:, t].unsqueeze(-1)).logaddexp(s_t[:, keep_index].unsqueeze(-1))
257 | s_t = s_t.scatter_(-1, src[:, t].unsqueeze(-1), s_k)
258 | s_t[:, keep_index] = -INF
259 | s_e = s_t.gather(1, ends.unsqueeze(1))
260 | s_p = s.logsumexp(0).unsqueeze(-1)
261 | # [batch_size * beam_size]
262 | # the positions for blanks are used to store prefixes that are kept unchanged
263 | # *a - -> *a
264 | s_b = s_p + s_t.masked_fill(tgt.new_tensor(range(n_words)).ne(nul_index).unsqueeze(0), -INF)
265 | # *a b -> *ab
266 | s_n = s_p + s_t
267 | # *a- a -> *aa
268 | s_n = s_n.scatter_(1, ends.unsqueeze(1), s[0].unsqueeze(1) + s_e)
269 | # *a a -> *a
270 | s_n[:, nul_index] = s[1] + s_e.squeeze(1)
271 | # [2, batch_size * beam_size, n_words]
272 | s = torch.stack((merge(s_b, s_n, tgt, lens, ends)))
273 | # [batch_size, beam_size]
274 | cands = s.logsumexp(0).view(batch_size, -1).topk(beam_size, -1)[1]
275 | # [2, batch_size * beam_size]
276 | s = s.view(2, batch_size, -1).gather(-1, cands.repeat(2, 1, 1)).view(2, -1)
277 | # beams, tokens = cands // n_words, cands % n_words
278 | beams, tokens = cands.div(n_words, rounding_mode='floor'), (cands % n_words).view(-1, 1)
279 | indices = (batches.unsqueeze(-1) + beams).view(-1)
280 | lens[mask] = lens[indices[mask]]
281 | # [batch_size * beam_size, max_len]
282 | tgt[mask] = tgt[indices[mask]].scatter_(1, lens[mask].unsqueeze(1), tokens[mask])
283 | lens += tokens.ne(nul_index).squeeze(1) & mask
284 | cands = s.logsumexp(0).view(batch_size, -1).topk(self.args.topk, -1)[1]
285 | tgt = tgt[(batches.unsqueeze(-1) + cands).view(-1)].view(batch_size, self.args.topk, -1)
286 | tgt = pad([pad([j[j.ne(nul_index)] for j in i], pad_index) for i in tgt], pad_index)
287 | return tgt
288 |
289 | def align(self, s_x, src, tgt, src_mask, tgt_mask):
290 | src = self.resize(src.unsqueeze(-1).repeat(1, 1, self.args.upsampling))
291 | # [tgt_len, batch_size]
292 | s_k, s_b = s_x[..., self.args.keep_index], s_x[..., self.args.nul_index]
293 | # [tgt_len, seq_len, batch_size]
294 | s_x = s_x.gather(-1, tgt.repeat(s_b.shape[0], 1, 1)).transpose(1, 2)
295 | s_x = torch.where(src.unsqueeze(-1).eq(tgt.unsqueeze(1)).movedim(0, -1), s_k.unsqueeze(1), s_x)
296 | src_lens, tgt_lens = src_mask.sum(-1) * self.args.upsampling, tgt_mask.sum(-1)
297 | tgt_len, seq_len, batch_size = s_x.shape
298 | # [tgt_len, 2, seq_len + 1, batch_size]
299 | s = s_x.new_full((tgt_len, 2, seq_len + 1, batch_size), -INF)
300 | p = tgt.new_full((tgt_len, 2, seq_len + 1, batch_size), -1)
301 | s[0, 0, 0], s[0, 1, 0] = s_b[0], s_x[0, 0]
302 |
303 | for t in range(1, tgt_len):
304 | s0 = torch.cat((torch.full_like(s[0, 0, :1], -INF), s[t-1, 1, :-1]))
305 | s1 = s[t-1, 0]
306 | s2 = s[t-1, 1]
307 | s_t, p[t, 0] = torch.stack((s0, s1)).max(0)
308 | s[t, 0] = s_t + s_b[t]
309 | s_t, p[t, 1, :-1] = torch.stack((s0[:-1], s1[:-1], s2[:-1])).max(0)
310 | s[t, 1, :-1] = s_t + s_x[t]
311 | _, p_t = torch.stack((s[src_lens - 1, 0, tgt_lens, range(batch_size)],
312 | s[src_lens - 1, 1, tgt_lens - 1, range(batch_size)])).max(0)
313 |
314 | def backtrack(p, tgt, notnul):
315 | j, pred = [len(p[0][0])-1, len(p[0][0])-2], []
316 | for i in reversed(range(len(p))):
317 | prev = p[i][notnul][j[notnul]]
318 | pred.append(tgt[j[notnul]] if bool(notnul) else self.args.nul_index)
319 | if notnul == 0:
320 | if prev == 0:
321 | notnul = 1
322 | j[notnul] = j[1-notnul] - 1
323 | elif notnul == 1:
324 | if prev == 0:
325 | j[notnul] -= 1
326 | if prev == 1:
327 | notnul = 0
328 | j[notnul] = j[1-notnul]
329 | return tuple(reversed(pred))
330 | p_t, tgt, preds = p_t.tolist(), tgt.tolist(), torch.full_like(src, self.args.pad_index)
331 | for i, (src_len, tgt_len) in enumerate(zip(src_lens.tolist(), tgt_lens.tolist())):
332 | preds[i, :src_len] = src.new_tensor(backtrack(p[:src_len, :, :tgt_len+1, i].tolist(), tgt[i][:tgt_len], p_t[i]))
333 | return preds
334 |
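The non-prefix branch of `decode` above reduces a frame-level labelling to an output sequence in three steps: KEEP labels are replaced by the upsampled source token at the same frame, consecutive repeats are merged, and blanks are dropped. Below is a minimal pure-Python sketch of that reduction. The function name `collapse` and the `KEEP`/`NUL` ids are illustrative assumptions, not part of the repo, and the sketch assumes each source token is repeated `upsampling` times in place.

```python
from itertools import groupby

KEEP, NUL = 100, 101  # illustrative ids; the repo derives them from the vocab size


def collapse(labels, src, keep=KEEP, nul=NUL):
    """Turn frame-level labels into a token sequence, CTC-style."""
    # 1. a KEEP label copies the source token aligned to the same frame
    resolved = [s if a == keep else a for a, s in zip(labels, src)]
    # 2. merge runs of identical consecutive labels
    merged = [tok for tok, _ in groupby(resolved)]
    # 3. drop blanks
    return [tok for tok in merged if tok != nul]


# source tokens [5, 6, 8] upsampled by a factor of 2 (assumed layout)
src = [5, 5, 6, 6, 8, 8]
labels = [KEEP, KEEP, NUL, 7, NUL, KEEP]
print(collapse(labels, src))
```

Because KEEP is resolved before merging, two adjacent frames that copy the same source token collapse into one output token, while a blank between two identical labels keeps them apart.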
--------------------------------------------------------------------------------
/ctc/parser.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 | import tempfile
5 |
6 | import errant
7 | import torch
8 | import torch.distributed as dist
9 | from torch.optim import AdamW, Optimizer
10 |
11 | from supar.config import Config
12 | from supar.parser import Parser
13 | from supar.utils import Dataset
14 | from supar.utils.field import Field
15 | from supar.utils.logging import get_logger
16 | from supar.utils.parallel import gather, is_dist, is_master
17 | from supar.utils.tokenizer import TransformerTokenizer
18 | from supar.utils.transform import Batch
19 |
20 | from .metric import PerplexityMetric
21 | from .model import CTCModel
22 | from .transform import Text
23 |
24 | logger = get_logger(__name__)
25 |
26 |
27 | class CTCParser(Parser):
28 |
29 | NAME = 'ctc'
30 | MODEL = CTCModel
31 |
32 | def __init__(self, *args, **kwargs):
33 | super().__init__(*args, **kwargs)
34 |
35 | self.SRC = self.transform.SRC
36 | self.TGT = self.transform.TGT
37 | self.annotator = errant.load("en")
38 |
39 | def init_optimizer(self) -> Optimizer:
40 | return AdamW(params=[{'params': p, 'lr': self.args.lr * (1 if n.startswith('encoder') else self.args.lr_rate)}
41 | for n, p in self.model.named_parameters()],
42 | lr=self.args.lr,
43 | betas=(self.args.get('mu', 0.9), self.args.get('nu', 0.999)),
44 | eps=self.args.get('eps', 1e-8),
45 | weight_decay=self.args.get('weight_decay', 0))
46 |
47 | def train_step(self, batch: Batch) -> torch.Tensor:
48 | src, tgt = batch
49 | src_mask, tgt_mask = batch.mask, tgt.ne(self.args.pad_index)
50 | mask = tgt_mask.sum(-1).lt(src_mask.sum(-1) * self.args.upsampling)
51 | src, tgt, src_mask, tgt_mask = src[mask], tgt[mask], src_mask[mask], tgt_mask[mask]
52 | x = self.model(src)
53 | loss = self.model.loss(x, src, tgt, src_mask, tgt_mask, self.args.glat)
54 | return loss
55 |
56 | @torch.no_grad()
57 | def eval_step(self, batch: Batch) -> PerplexityMetric:
58 | src, tgt = batch
59 | src_mask, tgt_mask = batch.mask, tgt.ne(self.args.pad_index)
60 | mask = tgt_mask.sum(-1).lt(src_mask.sum(-1) * self.args.upsampling)
61 | src, tgt, src_mask, tgt_mask = src[mask], tgt[mask], src_mask[mask], tgt_mask[mask]
62 | x = self.model(src)
63 | loss = self.model.loss(x, src, tgt, src_mask, tgt_mask)
64 | preds = golds = None
65 | if self.args.eval_tgt:
66 | golds = [(s.values[0], s.values[1], s.fields['src'].tolist(), t.tolist())
67 | for s, t in zip(batch.sentences, tgt[tgt_mask].split(tgt_mask.sum(-1).tolist()))]
68 | preds = self.model.decode(x, src, batch.mask)[:, 0]
69 | pred_mask = preds.ne(self.args.pad_index)
70 | preds = [i.tolist() for i in preds[pred_mask].split(pred_mask.sum(-1).tolist())]
71 | preds = [(s.values[0], self.TGT.tokenize.decode(i), s.fields['src'].tolist(), i)
72 | for s, i in zip(batch.sentences, preds)]
73 | return PerplexityMetric(loss,
74 | preds,
75 | golds,
76 | tgt_mask,
77 | (None if self.args.lev else self.annotator),
78 | not self.args.eval_tgt)
79 |
80 | @torch.no_grad()
81 | def pred_step(self, batch: Batch) -> Batch:
82 | src, = batch
83 | mask = batch.mask
84 | for _ in range(self.args.iteration):
85 | x = self.model(src)
86 | tgt = self.model.decode(x, src, mask)
87 | src = tgt[:, 0]
88 | mask = src.ne(self.args.pad_index)
89 | batch.tgt = [[self.TGT.tokenize.decode(cand).strip() for cand in i] for i in tgt.tolist()]
90 | return batch
91 |
92 | @classmethod
93 | def build(cls, path, min_freq=2, fix_len=20, **kwargs):
94 | r"""
95 | Build a brand-new Parser, including initialization of all data fields and model parameters.
96 |
97 | Args:
98 | path (str):
99 | The path of the model to be saved.
100 | min_freq (int):
101 | The minimum frequency needed to include a token in the vocabulary. Default: 2.
102 | fix_len (int):
103 | The max length of all subword pieces. The excess part of each piece will be truncated.
104 | Required if using CharLSTM/BERT.
105 | Default: 20.
106 | kwargs (dict):
107 | A dict holding the unconsumed arguments.
108 | """
109 |
110 | args = Config(**locals())
111 | os.makedirs(os.path.dirname(path) or './', exist_ok=True)
112 | if os.path.exists(path) and not args.build:
113 | return cls.load(**args)
114 |
115 | logger.info("Building the fields")
116 | t = TransformerTokenizer(args.bert)
117 | SRC = Field('src', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, tokenize=t)
118 | TGT = Field('tgt', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, tokenize=t)
119 | transform = Text(SRC=SRC, TGT=TGT)
120 | if args.vocab:
121 | if is_master():
122 | t.extend(Dataset(transform, args.train, **args).src)
123 | if is_dist():
124 | with tempfile.TemporaryDirectory(dir='.') as td:
125 | td = gather(td)[0]
126 | if is_master():
127 | torch.save(t, f'{td}/t')
128 | dist.barrier()
129 | t = torch.load(f'{td}/t')
130 | SRC.vocab = TGT.vocab = t.vocab
131 |
132 | args.update({'n_words': len(SRC.vocab) + 2,
133 | 'pad_index': SRC.pad_index,
134 | 'unk_index': SRC.unk_index,
135 | 'bos_index': SRC.bos_index,
136 | 'eos_index': SRC.eos_index,
137 | 'mask_index': t.mask_token_id,
138 | 'keep_index': len(SRC.vocab),
139 | 'nul_index': len(SRC.vocab) + 1})
140 | logger.info(f"{transform}")
141 | logger.info("Building the model")
142 | model = cls.MODEL(**args)
143 | logger.info(f"{model}\n")
144 |
145 | parser = cls(args, model, transform)
146 | parser.model.to(parser.device)
147 | return parser
148 |
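Two conventions in `CTCParser` are easy to miss: `build` appends two extra labels (KEEP and NUL) after the tokenizer vocabulary, and `train_step`/`eval_step` drop every pair whose target is not strictly shorter than the upsampled source. The sketch below restates both rules in plain Python. The helper names `label_space` and `is_alignable` are hypothetical and used only for illustration.

```python
def label_space(vocab_size):
    """Label layout used by `build`: the vocab, then KEEP, then NUL."""
    return {
        'n_words': vocab_size + 2,
        'keep_index': vocab_size,
        'nul_index': vocab_size + 1,
    }


def is_alignable(src_len, tgt_len, upsampling):
    """Mirror of `tgt_mask.sum(-1).lt(src_mask.sum(-1) * upsampling)`."""
    return tgt_len < src_len * upsampling


space = label_space(1000)
print(space['keep_index'], space['nul_index'], space['n_words'])
print(is_alignable(src_len=3, tgt_len=5, upsampling=2))
print(is_alignable(src_len=3, tgt_len=6, upsampling=2))
```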
--------------------------------------------------------------------------------
/ctc/struct.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import annotations
4 |
5 | from typing import List, Optional
6 |
7 | import torch
8 | from torch.distributions.utils import lazy_property
9 |
10 | from supar.structs.dist import StructuredDistribution
11 | from supar.structs.semiring import LogSemiring, Semiring
12 |
13 |
14 | class Levenshtein(StructuredDistribution):
15 |
16 |     def __init__(
17 |         self,
18 |         scores: torch.Tensor,
19 |         lens: Optional[torch.LongTensor] = None
20 |     ) -> Levenshtein:
21 |         super().__init__(scores)
22 |
23 |         batch_size, _, seq_len, src_len = scores.shape[:4]
24 |         if lens is not None:
25 |             self.lens = lens
26 |         else:
27 |             self.lens = (scores.new_zeros(batch_size, 2) + scores.new_tensor([src_len, seq_len])).long()
28 |         self.src_lens, self.tgt_lens = self.lens.unbind(-1)
29 |         self.src_mask = self.src_lens.unsqueeze(-1).gt(self.lens.new_tensor(range(src_len)))
30 |         self.tgt_mask = self.tgt_lens.unsqueeze(-1).gt(self.lens.new_tensor(range(seq_len)))
31 |
32 |     def __add__(self, other):
33 |         return Levenshtein(torch.stack((self.scores, other.scores)), self.lens)
34 |
35 |     @lazy_property
36 |     def argmax(self):
37 |         margs = self.backward(self.max.sum())
38 |         margs, edits = margs.argmax(1).transpose(1, 2), [torch.where(i) for i in margs.sum(1).transpose(1, 2).unbind()]
39 |         return [torch.stack((e[0], e[1], m[e])).t().tolist() for e, m in zip(edits, margs)]
40 |
41 |     def score(self, value: List) -> torch.Tensor:
42 |         lens = self.lens.new_tensor([len(i) for i in value])
43 |         edit_mask = lens.unsqueeze(-1).gt(lens.new_tensor(range(max(lens))))
44 |         edits = list(self.lens.new_tensor([(i,) + span for i, spans in enumerate(value) for span in spans]).unbind(-1))
45 |         s_edit = self.scores[edits[0], edits[3], edits[2], edits[1]]
46 |         s = s_edit.new_full(edit_mask.shape, LogSemiring.one).masked_scatter_(edit_mask, s_edit)
47 |         return LogSemiring.prod(s)
48 |
49 |     def forward(self, semiring: Semiring) -> torch.Tensor:
50 |         # [4, seq_len, src_len, batch_size, ...]
51 |         s_edit = semiring.convert(self.scores.movedim(0, 3))
52 |
53 |         _, seq_len, src_len, batch_size = s_edit.shape[:4]
54 |         tgt_lens, src_lens, src_mask = self.tgt_lens, self.src_lens, self.src_mask.t()
55 |         # [seq_len, src_len, batch_size]
56 |         alpha = semiring.zeros_like(s_edit[0])
57 |         trans = semiring.cumprod(torch.cat((semiring.ones_like(s_edit[0, :, :1]), s_edit[0, :, 1:]), 1), 1)
58 |         # [seq_len, src_len, src_len, batch_size]
59 |         trans = trans.unsqueeze(2) - trans.unsqueeze(1)
60 |         trans_mask = src_mask.unsqueeze(0) & torch.ones_like(src_mask).unsqueeze(1)
61 |         # [src_len, src_len, batch_size]
62 |         trans_mask = trans_mask & src_mask.new_ones(src_len, src_len).tril(-1).unsqueeze(-1)
63 |
64 |         for t in range(seq_len):
65 |             s_a = alpha[t - 1] if t > 0 else semiring.ones_like(trans[0, 0])
66 |             # INSERT
67 |             s_i = semiring.mul(s_a, s_edit[1, t])
68 |             # KEEP
69 |             s_k = torch.cat((semiring.zeros_like(s_a[:1]), semiring.mul(s_a[:-1], s_edit[2, t, 1:])), 0)
70 |             # REPLACE
71 |             s_r = torch.cat((semiring.zeros_like(s_a[:1]), semiring.mul(s_a[:-1], s_edit[3, t, 1:])), 0)
72 |             # SWAP
73 |             s_s = torch.cat((semiring.zeros_like(s_a[:1]), semiring.mul(s_a[:-1], s_edit[4, t, 1:])), 0)
74 |             # [src_len, batch_size]
75 |             s_a = semiring.sum(torch.stack((s_i, s_k, s_r, s_s)), 0)
76 |             # DELETE
77 |             s_d = semiring.sum(semiring.zero_mask_(semiring.mul(trans[t], s_a.unsqueeze(0)), ~trans_mask), 1)
78 |             # [src_len, batch_size]
79 |             alpha[t] = semiring.add(s_d, s_a)
80 |         # the full input is consumed when the final output symbol is generated
81 |         return semiring.unconvert(alpha[tgt_lens - 1, src_lens - 1, range(batch_size)])
82 |
--------------------------------------------------------------------------------
/ctc/transform.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from __future__ import annotations
4 |
5 | import os
6 | import tempfile
7 | from contextlib import contextmanager
8 | from io import StringIO
9 | from typing import Iterable, List, Optional, Union
10 |
11 | import pathos.multiprocessing as mp
12 | import spacy
13 | import spacy.parts_of_speech as POS
14 | import torch.distributed as dist
15 | from rapidfuzz.distance import Indel
16 | from spacy.tokens import Doc
17 |
18 | from supar.utils import Field
19 | from supar.utils.fn import binarize, debinarize
20 | from supar.utils.logging import progress_bar
21 | from supar.utils.parallel import gather, is_dist, is_master
22 | from supar.utils.tokenizer import Tokenizer
23 | from supar.utils.transform import Sentence, Transform
24 |
25 |
26 | class Alignment:
27 | # Protected class resource
28 | _open_pos = {POS.ADJ, POS.ADV, POS.NOUN, POS.VERB}
29 |
30 | # Input 1: An original text string parsed by spacy
31 | # Input 2: A corrected text string parsed by spacy
32 | # Input 3: A flag for standard Levenshtein alignment
33 | def __init__(self, orig, cor, lev=False, nlp=None):
34 | # Set orig and cor
35 | self.nlp = nlp
36 | self.orig_toks, self.cor_toks = orig, cor
37 | self.orig = self.parse(orig)
38 | self.cor = self.parse(cor)
39 | # Align orig and cor and get the cost and op matrices
40 | self.cost_matrix, self.op_matrix = self.align(lev)
41 | # Get the cheapest align sequence from the op matrix
42 | self.align_seq = self.get_cheapest_align_seq()
43 |
44 | # Input: A flag for standard Levenshtein alignment
45 | # Output: The cost matrix and the operation matrix of the alignment
46 | def align(self, lev):
47 | # Sentence lengths
48 | o_len = len(self.orig)
49 | c_len = len(self.cor)
50 | # Lower case token IDs (for transpositions)
51 | # Create the cost_matrix and the op_matrix
52 | cost_matrix = [[0.0 for j in range(c_len+1)] for i in range(o_len+1)]
53 | op_matrix = [["O" for j in range(c_len+1)] for i in range(o_len+1)]
54 | # Fill in the edges
55 | for i in range(1, o_len+1):
56 | cost_matrix[i][0] = cost_matrix[i-1][0] + 1
57 | op_matrix[i][0] = "D"
58 | for j in range(1, c_len+1):
59 | cost_matrix[0][j] = cost_matrix[0][j-1] + 1
60 | op_matrix[0][j] = "I"
61 |
62 | # Loop through the cost_matrix
63 | for i in range(o_len):
64 | for j in range(c_len):
65 | # Matches
66 | if self.orig[i].orth == self.cor[j].orth and self.orig_toks[i] == self.cor_toks[j]:
67 | cost_matrix[i+1][j+1] = cost_matrix[i][j]
68 | op_matrix[i+1][j+1] = "M"
69 | # Non-matches
70 | else:
71 | del_cost = cost_matrix[i][j+1] + 1
72 | ins_cost = cost_matrix[i+1][j] + 1
73 | trans_cost = float("inf") # currently ignore swap/transpose
74 | k = 0
75 | # Standard Levenshtein (S = 1)
76 | if lev:
77 | sub_cost = cost_matrix[i][j] + 1
78 | # Linguistic Damerau-Levenshtein
79 | else:
80 | # Custom substitution
81 | sub_cost = cost_matrix[i][j] + self.get_sub_cost(self.orig[i], self.cor[j])
82 | # Costs
83 | costs = [trans_cost, sub_cost, ins_cost, del_cost]
84 | # Get the index of the cheapest (first cheapest if tied)
85 | l = costs.index(min(costs))
86 | # Save the cost and the op in the matrices
87 | cost_matrix[i+1][j+1] = costs[l]
88 | if l == 0:
89 | op_matrix[i+1][j+1] = "T"+str(k+1)
90 | elif l == 1:
91 | op_matrix[i+1][j+1] = "S"
92 | elif l == 2:
93 | op_matrix[i+1][j+1] = "I"
94 | else:
95 | op_matrix[i+1][j+1] = "D"
96 | # Return the matrices
97 | return cost_matrix, op_matrix
98 |
99 | # Input 1: A spacy orig Token
100 | # Input 2: A spacy cor Token
101 | # Output: A linguistic cost between 0 < x < 2
102 | def get_sub_cost(self, o, c):
103 | # Short circuit if the only difference is case
104 | if o.lower == c.lower:
105 | return 0
106 | # Lemma cost
107 | if o.lemma == c.lemma:
108 | lemma_cost = 0
109 | else:
110 | lemma_cost = 0.499
111 | # POS cost
112 | if o.pos == c.pos:
113 | pos_cost = 0
114 | elif o.pos in self._open_pos and c.pos in self._open_pos:
115 | pos_cost = 0.25
116 | else:
117 | pos_cost = 0.5
118 | # Char cost
119 | char_cost = Indel.normalized_distance(o.text, c.text)
120 | # Combine the costs
121 | return lemma_cost + pos_cost + char_cost
122 |
123 | # Get the cheapest alignment sequence and indices from the op matrix
124 | def get_cheapest_align_seq(self):
125 | i = len(self.op_matrix)-1
126 | j = len(self.op_matrix[0])-1
127 | op_set = {'D': 0, 'I': 1, 'M': 2, 'S': 3}
128 | align_seq = [(i, j, op_set['M'])]
129 | # Work backwards from bottom right until we hit top left
130 | while i + j != 0:
131 | # Get the edit operation in the current cell
132 | op = self.op_matrix[i][j]
133 | # Matches and substitutions
134 | if op in {"M", "S"}:
135 | i -= 1
136 | j -= 1
137 | # Deletions
138 | elif op == "D":
139 | i -= 1
140 | # Insertions
141 | elif op == "I":
142 | j -= 1
143 | align_seq.append((i, j, op_set[op]))
144 | # Reverse the list to go from left to right and return
145 | align_seq.reverse()
146 | return align_seq
147 |
148 | # Alignment object string representation
149 | def __str__(self):
150 | orig = " ".join(["Orig:"]+[tok.text for tok in self.orig])
151 | cor = " ".join(["Cor:"]+[tok.text for tok in self.cor])
152 | cost_matrix = "\n".join(["Cost Matrix:"]+[str(row) for row in self.cost_matrix])
153 | op_matrix = "\n".join(["Operation Matrix:"]+[str(row) for row in self.op_matrix])
154 | seq = "Best alignment: "+str(self.align_seq)
155 | return "\n".join([orig, cor, cost_matrix, op_matrix, seq])
156 |
157 | def parse(self, text):
158 | if isinstance(text, str):
159 | new_text = []
160 | for tok in text.split(): # remove bpe delimiter
161 | new_text.append(tok if tok[-4:] != "</w>" else tok[:-4])
162 | text = Doc(self.nlp.vocab, new_text)
163 | else:
164 | new_text = []
165 | for tok in text:
166 | new_text.append(tok if tok[-4:] != "</w>" else tok[:-4])
167 | text = Doc(self.nlp.vocab, new_text)
168 | self.nlp.tagger(text)
169 | self.nlp.parser(text)
170 | return text
171 |
172 |
173 | class Text(Transform):
174 |
175 | fields = ['SRC', 'TGT']
176 |
177 | def __init__(
178 | self,
179 | SRC: Optional[Union[Field, Iterable[Field]]] = None,
180 | TGT: Optional[Union[Field, Iterable[Field]]] = None,
181 | ) -> Text:
182 | super().__init__()
183 |
184 | self.SRC = SRC
185 | self.TGT = TGT
186 |
187 | @property
188 | def src(self):
189 | return self.SRC,
190 |
191 | @property
192 | def tgt(self):
193 | return self.TGT,
194 |
195 | def load(
196 | self,
197 | data: Union[str, Iterable],
198 | lang: Optional[str] = None,
199 | **kwargs
200 | ) -> Iterable[TextSentence]:
201 | r"""
202 | Loads the data in Text-X format.
203 | Also supports loading data from Text-U files with comments and non-integer IDs.
204 |
205 | Args:
206 | data (str or Iterable):
207 | A filename or a list of instances.
208 | lang (str):
209 | Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
210 | ``None`` if tokenization is not required.
211 | Default: ``None``.
212 |
213 | Returns:
214 | A list of :class:`TextSentence` instances.
215 | """
216 |
217 | if lang is not None:
218 | tokenizer = Tokenizer(lang)
219 | if isinstance(data, str) and os.path.exists(data):
220 | f = open(data)
221 | if data.endswith('.txt'):
222 | lines = (i
223 | for s in f
224 | if len(s) > 1
225 | for i in StringIO((s.split() if lang is None else tokenizer(s)) + '\n'))
226 | else:
227 | lines = f
228 | else:
229 | if lang is not None:
230 | data = [tokenizer(s) for s in ([data] if isinstance(data, str) else data)]
231 | else:
232 | data = [data] if isinstance(data[0], str) else data
233 | lines = (i for s in data for i in StringIO(s + '\n'))
234 |
235 | index, sentence, nlp = 0, [], spacy.load("en", disable=["ner"])
236 | for line in lines:
237 | line = line.strip()
238 | if len(line) == 0:
239 | yield TextSentence(self, sentence, index, nlp)
240 | index += 1
241 | sentence = []
242 | else:
243 | sentence.append(line)
244 |
245 |
246 | class TextSentence(Sentence):
247 |
248 | def __init__(self, transform: Text, lines: List[str], index: Optional[int] = None, nlp=None) -> TextSentence:
249 | super().__init__(transform, index)
250 |
251 | self.cands = [(line+'\t').split('\t')[1] for line in lines[1:]]
252 | src, tgt = lines[0].split('\t')[1], self.cands[0]
253 | self.values = [src, tgt]
254 |
255 | def __repr__(self):
256 | self.cands = self.values[1] if isinstance(self.values[1], list) else [self.values[1]]
257 | lines = ['S\t' + self.values[0]]
258 | lines.extend(['T\t' + i for i in self.cands])
259 | return '\n'.join(lines) + '\n'
260 |
261 | @classmethod
262 | def align(cls, src, tgt, nlp):
263 | return Alignment(src, tgt, nlp=nlp).align_seq
264 |
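The substitution cost in `Alignment.get_sub_cost` adds three terms: a lemma penalty (0 or 0.499), a part-of-speech penalty (0, 0.25 or 0.5) and a normalized character-level Indel distance. The sketch below reproduces that arithmetic without spaCy or rapidfuzz. It assumes plain strings for text, lemma and POS tag, and it re-implements the normalized Indel distance as `(len(a) + len(b) - 2 * LCS) / (len(a) + len(b))`, which is my understanding of what `Indel.normalized_distance` computes. All function names here are hypothetical.

```python
OPEN_POS = {"ADJ", "ADV", "NOUN", "VERB"}


def lcs_len(a, b):
    """Length of the longest common subsequence of two strings."""
    prev = [0] * (len(b) + 1)
    for x in a:
        cur = [0]
        for j, y in enumerate(b, 1):
            cur.append(prev[j - 1] + 1 if x == y else max(prev[j], cur[j - 1]))
        prev = cur
    return prev[-1]


def indel_normalized_distance(a, b):
    """Insertions plus deletions needed to turn a into b, scaled to [0, 1]."""
    total = len(a) + len(b)
    if total == 0:
        return 0.0
    return (total - 2 * lcs_len(a, b)) / total


def sub_cost(o_text, c_text, o_lemma, c_lemma, o_pos, c_pos):
    # a case-only difference is free
    if o_text.lower() == c_text.lower():
        return 0
    lemma_cost = 0 if o_lemma == c_lemma else 0.499
    if o_pos == c_pos:
        pos_cost = 0
    elif o_pos in OPEN_POS and c_pos in OPEN_POS:
        pos_cost = 0.25
    else:
        pos_cost = 0.5
    return lemma_cost + pos_cost + indel_normalized_distance(o_text, c_text)


print(sub_cost("foods", "food", "food", "food", "NOUN", "NOUN"))
print(sub_cost("i", "I", "i", "I", "PRON", "PRON"))
```

With a lemma penalty just under 0.5, a substitution always costs strictly less than 2, so it is never more expensive than a deletion followed by an insertion.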
--------------------------------------------------------------------------------
/data/clang8.toy:
--------------------------------------------------------------------------------
1 | S About winter
2 | T About winter
3 |
4 | S This is my second post .
5 | T This is my second post .
6 |
7 | S I will appreciate it if you correct my sentences .
8 | T I would appreciate it if you corrected my sentences .
9 |
10 | S It 's been getting colder these days here in Japan .
11 | T It 's been getting colder these days here in Japan .
12 |
13 | S The summer weather in Japan is not agreeable to me with its high humidity and temperature .
14 | T The summer weather in Japan is not agreeable to me with its high humidity and temperature .
15 |
16 | S So , as the winter is coming , I 'm getting to feel better .
17 | T So , as the winter is coming , I 'm getting to feel better .
18 |
19 | S Coldness is my energy .
20 | T Coldness is my energy .
21 |
22 | S And also , around the new year 's holidays , we will have a lot of enjoyable events
23 | T And also , around the new year 's holidays , we will have a lot of enjoyable events .
24 |
25 | S mostly with delicious foods , drinks , and good conversations .
26 | T Mostly with delicious food , drinks , and good conversation .
27 |
28 | S In addition , it is the time for skiing and snow boarding :)
29 | T In addition , it is the time for skiing and snowboarding :)
30 |
31 | S It is the very exciting season .
32 | T It is a very exciting season .
33 |
34 | S But , before enjoying those kind of happy time , I have to do a kind of boring ,
35 | T But , before enjoying those kinds of happy times , I have to do some kind of boring ,
36 |
37 | S customary practice .
38 | T customary practice .
39 |
40 | S Writing new year 's greeting cards is somehow a pain in the neck .
41 | T Writing new year 's greeting cards is somehow a pain in the neck .
42 |
43 | S Actually , I do n't have enough time to come up with an idea of the card 's design .
44 | T Actually , I did n't have enough time to come up with an idea for the card 's design .
45 |
46 | S I wish i could come across an good one in my mind .
47 | T I wish I could come across a good one in my mind .
48 |
49 | S Thank you for reading & thanks for your time .
50 | T Thank you for reading & thanks for your time .
51 |
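The toy file above follows the layout that `Text.load` and `TextSentence` expect: blank-line-separated blocks, one `S` line with the source and one or more `T` lines with corrections. The reader below is a minimal sketch of that parsing, assuming the label and the sentence are separated by a tab (as the `split('\t')` calls in `TextSentence` suggest). The function name `read_pairs` is hypothetical.

```python
def read_pairs(text):
    """Parse S/T blocks into (source, [candidates]) tuples."""
    pairs, block = [], []
    # the trailing '' flushes a final block that lacks a closing blank line
    for line in text.splitlines() + ['']:
        line = line.strip()
        if not line:
            if block:
                src = block[0].split('\t')[1]
                # appending a tab keeps an empty target from raising IndexError
                cands = [(t + '\t').split('\t')[1] for t in block[1:]]
                pairs.append((src, cands))
                block = []
        else:
            block.append(line)
    return pairs


sample = (
    "S\tIt is the very exciting season .\n"
    "T\tIt is a very exciting season .\n"
    "\n"
    "S\tAbout winter\n"
    "T\tAbout winter\n"
)
for src, cands in read_pairs(sample):
    print(src, '->', cands[0])
```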
--------------------------------------------------------------------------------
/pred.sh:
--------------------------------------------------------------------------------
1 | args=$@
2 | for arg in $args; do
3 | eval "$arg"
4 | done
5 |
6 | echo "config: ${config:=configs/roberta.yaml}"
7 | echo "path: ${path:=exp/ctc.roberta/model}"
8 | echo "data: ${data:=data/conll14.test}"
9 | echo "pred: ${pred:=$path.conll14.test.pred}"
10 | echo "input: ${input:=data/conll14.test.input}"
11 | echo "errant: ${errant:=data/conll14.test.errant.m2}"
12 | echo "devices: ${devices:=0}"
13 | echo "batch: ${batch:=10000}"
14 | echo "beam: ${beam:=12}"
15 | echo "iteration: ${iteration:=2}"
16 |
17 | (set -x; python -u run.py predict -d $devices -c $config -p $path --data $data --pred $pred --batch-size=$batch --beam-size=$beam --iteration $iteration
18 | CUDA_VISIBLE_DEVICES=$devices python recover.py --hyp $pred -o $pred.out -i $input -p $path -m 62)
19 |
20 | if ! conda env list | grep -q "^py27"; then
21 | echo "Creating the py27 environment..."; conda create -n py27 -y python=2.7
22 | fi
23 |
24 | source ~/anaconda3/etc/profile.d/conda.sh
25 | conda activate py27
26 | python tools/m2scorer/scripts/m2scorer.py -v $pred.out data/conll14.test.m2 > $pred.m2scorer.log
27 | tail -n 9 $pred.m2scorer.log
28 | conda deactivate
--------------------------------------------------------------------------------
/recover.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import argparse
4 | from ctc.parser import CTCParser
5 |
6 |
7 | def convert(file, fout, fin, fpath, max_len=64):
8 |     count, sentence = 0, []
9 |     tokenize_func = CTCParser.load(fpath).SRC.tokenize
10 |     with open(file) as f, open(fout, 'w') as fout, open(fin) as fin:
11 |         src_lines = [line.rstrip("\n") for line in fin]
12 |         tgt_lines = []
13 |         for line in f:
14 |             line = line.strip()
15 |             if len(line) == 0:
16 |                 tgt_lines.append((sentence[1]+'\t').split('\t')[1])
17 |                 sentence = []
18 |             else:
19 |                 sentence.append(line)
20 |         count = 0
21 |         for line in src_lines:
22 |             if len(tokenize_func(line)) >= max_len:
23 |                 fout.write(line + "\n")
24 |             else:
25 |                 fout.write(tgt_lines[count] + "\n")
26 |                 count += 1
27 |
28 |
29 | if __name__ == "__main__":
30 |     parser = argparse.ArgumentParser(description='Output files in line with m2scorer eval format.')
31 |     parser.add_argument('--path', '-p', help='path to the model file')
32 |     parser.add_argument('--input', '-i', help='path to the input file')
33 |     parser.add_argument('--hyp', help='path to the predicted file')
34 |     parser.add_argument('--fout', '-o', help='path to output file')
35 |     parser.add_argument('--max_len', '-m', help='max length')
36 |     args = parser.parse_args()
37 |     convert(args.hyp, args.fout, args.input, args.path, int(args.max_len))
38 |
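`convert` relies on the prediction file containing one entry for every input sentence shorter than `max_len` tokens, in order, and nothing for the longer ones; long inputs are copied through unchanged. The sketch below isolates that merge step so it can be tried without a trained model. The function name `recover` and the whitespace tokenizer are illustrative assumptions (the script itself uses the model's subword tokenizer).

```python
def recover(src_lines, tgt_lines, tokenize, max_len=64):
    """Interleave predictions with over-long inputs that were never predicted."""
    out, count = [], 0
    for line in src_lines:
        if len(tokenize(line)) >= max_len:
            # too long for the model: fall back to the uncorrected input
            out.append(line)
        else:
            out.append(tgt_lines[count])
            count += 1
    return out


src_lines = ["a b c d e", "x y", "p q"]
tgt_lines = ["X Y", "P Q"]  # predictions exist only for the two short inputs
print(recover(src_lines, tgt_lines, str.split, max_len=4))
```

If the tokenizer or threshold used here differs from the one used at prediction time, the two lists fall out of step, which is presumably why `pred.sh` passes an explicit `-m` value.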
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import argparse
4 |
5 | from ctc import CTCParser
6 | from supar.cmds.run import init
7 |
8 |
9 | def main():
10 |     parser = argparse.ArgumentParser(description='Create CTC Parser.')
11 |     parser.set_defaults(Parser=CTCParser)
12 |     parser.add_argument('--eval-tgt', action='store_true', help='whether to evaluate tgt')
13 |     parser.add_argument('--lev', action='store_true', help='whether to evaluate P/R/F using levenshtein')
14 |     parser.add_argument('--prefix', action='store_true', help='whether to perform prefix decoding')
15 |     parser.add_argument('--glat', type=float, default=0, help='GLAT sampling ratio')
16 |     parser.add_argument('--iteration', type=int, default=1, help='times of iterative decoding')
17 |     subparsers = parser.add_subparsers(title='Commands', dest='mode')
18 |     # train
19 |     subparser = subparsers.add_parser('train', help='Train a parser.')
20 |     subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
21 |     subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
22 |     subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='bert', help='encoder to use')
23 |     subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')
24 |     subparser.add_argument('--train', default='data/clang8.train', help='path to train file')
25 |     subparser.add_argument('--dev', default='data/bea19.dev', help='path to dev file')
26 |     subparser.add_argument('--test', default='data/conll14.test', help='path to test file')
27 |     subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`')
28 |     subparser.add_argument('--vocab', action='store_true', help='extend the vocab from new data')
29 |     # evaluate
30 |     subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.')
31 |     subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
32 |     subparser.add_argument('--data', default='data/conll14.test', help='path to dataset')
33 |     # predict
34 |     subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.')
35 |     subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use')
36 |     subparser.add_argument('--data', default='data/conll14.test', help='path to dataset')
37 |     subparser.add_argument('--pred', default='pred.txt', help='path to predicted result')
38 |     subparser.add_argument('--prob', action='store_true', help='whether to output probs')
39 |     init(parser)
40 |
41 |
42 | if __name__ == "__main__":
43 |     main()
44 |
--------------------------------------------------------------------------------
/supar:
--------------------------------------------------------------------------------
1 | 3rdparty/parser/supar
--------------------------------------------------------------------------------
/tools/m2scorer/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 2, June 1991
3 |
4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
6 | Everyone is permitted to copy and distribute verbatim copies
7 | of this license document, but changing it is not allowed.
8 |
9 | Preamble
10 |
11 | The licenses for most software are designed to take away your
12 | freedom to share and change it. By contrast, the GNU General Public
13 | License is intended to guarantee your freedom to share and change free
14 | software--to make sure the software is free for all its users. This
15 | General Public License applies to most of the Free Software
16 | Foundation's software and to any other program whose authors commit to
17 | using it. (Some other Free Software Foundation software is covered by
18 | the GNU Lesser General Public License instead.) You can apply it to
19 | your programs, too.
20 |
21 | When we speak of free software, we are referring to freedom, not
22 | price. Our General Public Licenses are designed to make sure that you
23 | have the freedom to distribute copies of free software (and charge for
24 | this service if you wish), that you receive source code or can get it
25 | if you want it, that you can change the software or use pieces of it
26 | in new free programs; and that you know you can do these things.
27 |
28 | To protect your rights, we need to make restrictions that forbid
29 | anyone to deny you these rights or to ask you to surrender the rights.
30 | These restrictions translate to certain responsibilities for you if you
31 | distribute copies of the software, or if you modify it.
32 |
33 | For example, if you distribute copies of such a program, whether
34 | gratis or for a fee, you must give the recipients all the rights that
35 | you have. You must make sure that they, too, receive or can get the
36 | source code. And you must show them these terms so they know their
37 | rights.
38 |
39 | We protect your rights with two steps: (1) copyright the software, and
40 | (2) offer you this license which gives you legal permission to copy,
41 | distribute and/or modify the software.
42 |
43 | Also, for each author's protection and ours, we want to make certain
44 | that everyone understands that there is no warranty for this free
45 | software. If the software is modified by someone else and passed on, we
46 | want its recipients to know that what they have is not the original, so
47 | that any problems introduced by others will not reflect on the original
48 | authors' reputations.
49 |
50 | Finally, any free program is threatened constantly by software
51 | patents. We wish to avoid the danger that redistributors of a free
52 | program will individually obtain patent licenses, in effect making the
53 | program proprietary. To prevent this, we have made it clear that any
54 | patent must be licensed for everyone's free use or not licensed at all.
55 |
56 | The precise terms and conditions for copying, distribution and
57 | modification follow.
58 |
59 | GNU GENERAL PUBLIC LICENSE
60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
61 |
62 | 0. This License applies to any program or other work which contains
63 | a notice placed by the copyright holder saying it may be distributed
64 | under the terms of this General Public License. The "Program", below,
65 | refers to any such program or work, and a "work based on the Program"
66 | means either the Program or any derivative work under copyright law:
67 | that is to say, a work containing the Program or a portion of it,
68 | either verbatim or with modifications and/or translated into another
69 | language. (Hereinafter, translation is included without limitation in
70 | the term "modification".) Each licensee is addressed as "you".
71 |
72 | Activities other than copying, distribution and modification are not
73 | covered by this License; they are outside its scope. The act of
74 | running the Program is not restricted, and the output from the Program
75 | is covered only if its contents constitute a work based on the
76 | Program (independent of having been made by running the Program).
77 | Whether that is true depends on what the Program does.
78 |
79 | 1. You may copy and distribute verbatim copies of the Program's
80 | source code as you receive it, in any medium, provided that you
81 | conspicuously and appropriately publish on each copy an appropriate
82 | copyright notice and disclaimer of warranty; keep intact all the
83 | notices that refer to this License and to the absence of any warranty;
84 | and give any other recipients of the Program a copy of this License
85 | along with the Program.
86 |
87 | You may charge a fee for the physical act of transferring a copy, and
88 | you may at your option offer warranty protection in exchange for a fee.
89 |
90 | 2. You may modify your copy or copies of the Program or any portion
91 | of it, thus forming a work based on the Program, and copy and
92 | distribute such modifications or work under the terms of Section 1
93 | above, provided that you also meet all of these conditions:
94 |
95 | a) You must cause the modified files to carry prominent notices
96 | stating that you changed the files and the date of any change.
97 |
98 | b) You must cause any work that you distribute or publish, that in
99 | whole or in part contains or is derived from the Program or any
100 | part thereof, to be licensed as a whole at no charge to all third
101 | parties under the terms of this License.
102 |
103 | c) If the modified program normally reads commands interactively
104 | when run, you must cause it, when started running for such
105 | interactive use in the most ordinary way, to print or display an
106 | announcement including an appropriate copyright notice and a
107 | notice that there is no warranty (or else, saying that you provide
108 | a warranty) and that users may redistribute the program under
109 | these conditions, and telling the user how to view a copy of this
110 | License. (Exception: if the Program itself is interactive but
111 | does not normally print such an announcement, your work based on
112 | the Program is not required to print an announcement.)
113 |
114 | These requirements apply to the modified work as a whole. If
115 | identifiable sections of that work are not derived from the Program,
116 | and can be reasonably considered independent and separate works in
117 | themselves, then this License, and its terms, do not apply to those
118 | sections when you distribute them as separate works. But when you
119 | distribute the same sections as part of a whole which is a work based
120 | on the Program, the distribution of the whole must be on the terms of
121 | this License, whose permissions for other licensees extend to the
122 | entire whole, and thus to each and every part regardless of who wrote it.
123 |
124 | Thus, it is not the intent of this section to claim rights or contest
125 | your rights to work written entirely by you; rather, the intent is to
126 | exercise the right to control the distribution of derivative or
127 | collective works based on the Program.
128 |
129 | In addition, mere aggregation of another work not based on the Program
130 | with the Program (or with a work based on the Program) on a volume of
131 | a storage or distribution medium does not bring the other work under
132 | the scope of this License.
133 |
134 | 3. You may copy and distribute the Program (or a work based on it,
135 | under Section 2) in object code or executable form under the terms of
136 | Sections 1 and 2 above provided that you also do one of the following:
137 |
138 | a) Accompany it with the complete corresponding machine-readable
139 | source code, which must be distributed under the terms of Sections
140 | 1 and 2 above on a medium customarily used for software interchange; or,
141 |
142 | b) Accompany it with a written offer, valid for at least three
143 | years, to give any third party, for a charge no more than your
144 | cost of physically performing source distribution, a complete
145 | machine-readable copy of the corresponding source code, to be
146 | distributed under the terms of Sections 1 and 2 above on a medium
147 | customarily used for software interchange; or,
148 |
149 | c) Accompany it with the information you received as to the offer
150 | to distribute corresponding source code. (This alternative is
151 | allowed only for noncommercial distribution and only if you
152 | received the program in object code or executable form with such
153 | an offer, in accord with Subsection b above.)
154 |
155 | The source code for a work means the preferred form of the work for
156 | making modifications to it. For an executable work, complete source
157 | code means all the source code for all modules it contains, plus any
158 | associated interface definition files, plus the scripts used to
159 | control compilation and installation of the executable. However, as a
160 | special exception, the source code distributed need not include
161 | anything that is normally distributed (in either source or binary
162 | form) with the major components (compiler, kernel, and so on) of the
163 | operating system on which the executable runs, unless that component
164 | itself accompanies the executable.
165 |
166 | If distribution of executable or object code is made by offering
167 | access to copy from a designated place, then offering equivalent
168 | access to copy the source code from the same place counts as
169 | distribution of the source code, even though third parties are not
170 | compelled to copy the source along with the object code.
171 |
172 | 4. You may not copy, modify, sublicense, or distribute the Program
173 | except as expressly provided under this License. Any attempt
174 | otherwise to copy, modify, sublicense or distribute the Program is
175 | void, and will automatically terminate your rights under this License.
176 | However, parties who have received copies, or rights, from you under
177 | this License will not have their licenses terminated so long as such
178 | parties remain in full compliance.
179 |
180 | 5. You are not required to accept this License, since you have not
181 | signed it. However, nothing else grants you permission to modify or
182 | distribute the Program or its derivative works. These actions are
183 | prohibited by law if you do not accept this License. Therefore, by
184 | modifying or distributing the Program (or any work based on the
185 | Program), you indicate your acceptance of this License to do so, and
186 | all its terms and conditions for copying, distributing or modifying
187 | the Program or works based on it.
188 |
189 | 6. Each time you redistribute the Program (or any work based on the
190 | Program), the recipient automatically receives a license from the
191 | original licensor to copy, distribute or modify the Program subject to
192 | these terms and conditions. You may not impose any further
193 | restrictions on the recipients' exercise of the rights granted herein.
194 | You are not responsible for enforcing compliance by third parties to
195 | this License.
196 |
197 | 7. If, as a consequence of a court judgment or allegation of patent
198 | infringement or for any other reason (not limited to patent issues),
199 | conditions are imposed on you (whether by court order, agreement or
200 | otherwise) that contradict the conditions of this License, they do not
201 | excuse you from the conditions of this License. If you cannot
202 | distribute so as to satisfy simultaneously your obligations under this
203 | License and any other pertinent obligations, then as a consequence you
204 | may not distribute the Program at all. For example, if a patent
205 | license would not permit royalty-free redistribution of the Program by
206 | all those who receive copies directly or indirectly through you, then
207 | the only way you could satisfy both it and this License would be to
208 | refrain entirely from distribution of the Program.
209 |
210 | If any portion of this section is held invalid or unenforceable under
211 | any particular circumstance, the balance of the section is intended to
212 | apply and the section as a whole is intended to apply in other
213 | circumstances.
214 |
215 | It is not the purpose of this section to induce you to infringe any
216 | patents or other property right claims or to contest validity of any
217 | such claims; this section has the sole purpose of protecting the
218 | integrity of the free software distribution system, which is
219 | implemented by public license practices. Many people have made
220 | generous contributions to the wide range of software distributed
221 | through that system in reliance on consistent application of that
222 | system; it is up to the author/donor to decide if he or she is willing
223 | to distribute software through any other system and a licensee cannot
224 | impose that choice.
225 |
226 | This section is intended to make thoroughly clear what is believed to
227 | be a consequence of the rest of this License.
228 |
229 | 8. If the distribution and/or use of the Program is restricted in
230 | certain countries either by patents or by copyrighted interfaces, the
231 | original copyright holder who places the Program under this License
232 | may add an explicit geographical distribution limitation excluding
233 | those countries, so that distribution is permitted only in or among
234 | countries not thus excluded. In such case, this License incorporates
235 | the limitation as if written in the body of this License.
236 |
237 | 9. The Free Software Foundation may publish revised and/or new versions
238 | of the General Public License from time to time. Such new versions will
239 | be similar in spirit to the present version, but may differ in detail to
240 | address new problems or concerns.
241 |
242 | Each version is given a distinguishing version number. If the Program
243 | specifies a version number of this License which applies to it and "any
244 | later version", you have the option of following the terms and conditions
245 | either of that version or of any later version published by the Free
246 | Software Foundation. If the Program does not specify a version number of
247 | this License, you may choose any version ever published by the Free Software
248 | Foundation.
249 |
250 | 10. If you wish to incorporate parts of the Program into other free
251 | programs whose distribution conditions are different, write to the author
252 | to ask for permission. For software which is copyrighted by the Free
253 | Software Foundation, write to the Free Software Foundation; we sometimes
254 | make exceptions for this. Our decision will be guided by the two goals
255 | of preserving the free status of all derivatives of our free software and
256 | of promoting the sharing and reuse of software generally.
257 |
258 | NO WARRANTY
259 |
260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
268 | REPAIR OR CORRECTION.
269 |
270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED
275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY
276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
278 | POSSIBILITY OF SUCH DAMAGES.
279 |
280 | END OF TERMS AND CONDITIONS
281 |
282 | How to Apply These Terms to Your New Programs
283 |
284 | If you develop a new program, and you want it to be of the greatest
285 | possible use to the public, the best way to achieve this is to make it
286 | free software which everyone can redistribute and change under these terms.
287 |
288 | To do so, attach the following notices to the program. It is safest
289 | to attach them to the start of each source file to most effectively
290 | convey the exclusion of warranty; and each file should have at least
291 | the "copyright" line and a pointer to where the full notice is found.
292 |
293 |