├── .flake8
├── .gitignore
├── .gitmodules
├── README.md
├── configs
│   └── roberta.yaml
├── ctc
│   ├── __init__.py
│   ├── metric.py
│   ├── model.py
│   ├── parser.py
│   ├── struct.py
│   └── transform.py
├── data
│   └── clang8.toy
├── pred.sh
├── recover.py
├── run.py
├── supar
├── tools
│   └── m2scorer
│       ├── LICENSE
│       ├── README
│       ├── example
│       │   ├── README
│       │   ├── source_gold
│       │   ├── system
│       │   └── system2
│       ├── m2scorer
│       └── scripts
│           ├── Tokenizer.py
│           ├── combiner.py
│           ├── levenshtein.py
│           ├── m2scorer.py
│           ├── nuclesgmlparser.py
│           ├── token_offsets.py
│           └── util.py
└── train.sh
/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 127 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # data files 2 | data 3 | 4 | # bash scripts 5 | *.sh 6 | 7 | # docs 8 | docs/_build 9 | 10 | # intermediate files 11 | build 12 | dist 13 | *.egg-info 14 | *.pyc 15 | 16 | # experimental results 17 | exp 18 | results 19 | wandb 20 | 21 | # log and config files 22 | log.* 23 | *.log 24 | *.cfg 25 | *.ini 26 | *.yml 27 | *.yaml 28 | 29 | # pycache 30 | __pycache__ 31 | 32 | # saved model 33 | *.pkl 34 | *.pt 35 | 36 | # hidden files 37 | .* 38 | 39 | # vscode 40 | .vscode 41 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "3rdparty/parser"] 2 | path = 3rdparty/parser 3 | url = https://github.com/yzhangcs/parser 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Non-autoregressive Text Editing with Copy-aware Latent Alignments 4 | 5 |
6 | Yu Zhang<sup>1*</sup>  7 | Yue Zhang<sup>1*</sup>  8 | Leyang Cui<sup>2</sup>  9 | Guohong Fu<sup>1</sup>  10 |
11 | <sup>1</sup>Soochow University, Suzhou, China
12 | <sup>2</sup>Tencent AI Lab
13 | 14 |
15 |

16 | 17 | [![conf](https://img.shields.io/badge/EMNLP%202023-orange?style=flat-square)](https://yzhang.site/assets/pubs/emnlp/2023/ctc.pdf) 18 | [![arxiv](https://img.shields.io/badge/arXiv-2310.07821-b31b1b.svg?style=flat-square)](https://arxiv.org/abs/2310.07821) 19 | [![citation](https://img.shields.io/badge/dynamic/json?label=citation&query=citationCount&url=https%3A%2F%2Fapi.semanticscholar.org%2Fgraph%2Fv1%2Fpaper%2F116277fd27c97d50bba2d8023d3c590c1ea8187b%3Ffields%3DcitationCount&style=flat-square)](https://www.semanticscholar.org/paper/Non-autoregressive-Text-Editing-with-Copy-aware-Zhang-Zhang/116277fd27c97d50bba2d8023d3c590c1ea8187b) 20 | ![python](https://img.shields.io/badge/python-%3E%3D%203.7-pybadges.svg?logo=python&style=flat-square) 21 | 22 |

23 |
24 | 25 | image 26 | 27 |
28 | 29 | ## Citation 30 | 31 | If you are interested in our work, please cite: 32 | ```bib 33 | @inproceedings{zhang-etal-2023-ctc, 34 | title = {Non-autoregressive Text Editing with Copy-aware Latent Alignments}, 35 | author = {Zhang, Yu and 36 | Zhang, Yue and 37 | Cui, Leyang and 38 | Fu, Guohong}, 39 | booktitle = {Proceedings of EMNLP}, 40 | year = {2023}, 41 | address = {Singapore} 42 | } 43 | ``` 44 | 45 | ## Setup 46 | 47 | The following packages should be installed: 48 | * [`PyTorch`](https://github.com/pytorch/pytorch): >= 2.0 49 | * [`Transformers`](https://github.com/huggingface/transformers) 50 | * [`Errant`](https://github.com/chrisjbryant/errant) 51 | 52 | Clone this repo recursively: 53 | ```sh 54 | git clone https://github.com/yzhangcs/ctc-copy.git --recursive 55 | ``` 56 | 57 | You can follow this [repo](https://github.com/HillZhang1999/SynGEC) to obtain the 3-stage train/dev/test data for training an English GEC model. 58 | The multilingual datasets are available [here](https://github.com/google-research-datasets/clang8). 59 | 60 | Before running, you are required to preprocess each sentence pair into the format of `S\t[src]\nT\t[tgt]\n`, where `src` and `tgt` are the source and target sentences, respectively. Each sentence pair is separated by a blank line. 61 | See [`data/clang8.toy`](data/clang8.toy) for examples. 62 | 63 | ## Run 64 | 65 | Try the following command to train a 3-stage English model: 66 | ```sh 67 | bash train.sh 68 | ``` 69 | To make predictions & evaluations: 70 | ```sh 71 | bash pred.sh 72 | ``` 73 | 74 | ## Contact 75 | 76 | If you have any questions, please feel free to [email](mailto:yzhang.cs@outlook.com) me. 77 | -------------------------------------------------------------------------------- /configs/roberta.yaml: -------------------------------------------------------------------------------- 1 | encoder: bert 2 | bert: roberta-large 3 | upsampling: 4 4 | beam_size: 12 5 | dropout: .1 6 | token_dropout: .1 7 | n_decoder_layers: 2 8 | find_unused_parameters: 0 9 | topk: 1 10 | label_smoothing: 0 11 | lr: 5e-05 12 | lr_rate: 10 13 | mu: .9 14 | nu: .9 15 | eps: 1e-12 16 | weight_decay: .01 17 | clip: 5.0 18 | min_freq: 2 19 | fix_len: 20 20 | epochs: 64 21 | patience: 10 22 | batch_size: 100000 23 | eval_batch_size: 10000 24 | warmup_steps: 1000 25 | update_steps: 25 26 | max_len: 64 27 | -------------------------------------------------------------------------------- /ctc/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .parser import CTCParser 4 | 5 | __all__ = ['CTCParser'] 6 | -------------------------------------------------------------------------------- /ctc/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | import math 6 | import os 7 | import tempfile 8 | from collections import Counter 9 | from typing import Any, List, Optional, Set, Tuple 10 | 11 | import torch 12 | from errant import Annotator 13 | 14 | from supar.structs.fn import levenshtein 15 | from supar.utils.metric import Metric 16 | 17 | 18 | class PerplexityMetric(Metric): 19 | 20 | def __init__( 21 | self, 22 | loss: Optional[float] = None, 23 | preds: Optional[Tuple[torch.Tensor, List, List]] = None, 24 | golds: Optional[Tuple[torch.Tensor, List, List]] = None, 25 | mask: Optional[torch.BoolTensor] = None, 26 | annotator: Annotator = None, 27 | reverse: bool = 
False, 28 | eps: float = 1e-12 29 | ) -> PerplexityMetric: 30 | super().__init__(reverse=reverse, eps=eps) 31 | 32 | self.n_tokens = 0. 33 | 34 | self.tp = 0.0 35 | self.pred = 0.0 36 | self.gold = 0.0 37 | self.total_loss = 0. 38 | 39 | if loss is not None: 40 | self(loss, preds, golds, mask, annotator) 41 | 42 | def __repr__(self): 43 | s = f"loss: {self.loss:.4f} PPL: {self.ppl:.4f}" 44 | if self.tp > 0: 45 | s += f" - TGT: P: {self.p:6.2%} R: {self.r:6.2%} F0.5: {self.f:6.2%}" 46 | return s 47 | 48 | def __call__( 49 | self, 50 | loss: float, 51 | preds: Tuple[torch.Tensor, List, List], 52 | golds: Tuple[torch.Tensor, List, List], 53 | mask: torch.BoolTensor, 54 | annotator: Any 55 | ) -> PerplexityMetric: 56 | n_tokens = mask.sum().item() 57 | self.n += len(mask) 58 | self.count += 1 59 | self.n_tokens += n_tokens 60 | self.total_loss += float(loss) * n_tokens 61 | 62 | if preds is not None: 63 | if annotator is not None: 64 | with tempfile.TemporaryDirectory() as t: 65 | fsrc, fpred, fgold = os.path.join(t, 'src'), os.path.join(t, 'pred'), os.path.join(t, 'gold') 66 | pred_m2, gold_m2 = os.path.join(t, 'pred.m2'), os.path.join(t, 'gold.m2') 67 | with open(fsrc, 'w') as fs, open(fpred, 'w') as f: 68 | for s, i, *_ in preds: 69 | fs.write(s + '\n') 70 | f.write(i + '\n') 71 | with open(fgold, 'w') as f: 72 | for _, i, *_ in golds: 73 | f.write(i + '\n') 74 | self.errant_parallel(fsrc, fpred, pred_m2, annotator) 75 | self.errant_parallel(fsrc, fgold, gold_m2, annotator) 76 | out = self.errant_compare(pred_m2, gold_m2) 77 | tp, fp, fn = out['tp'], out['fp'], out['fn'] 78 | self.tp += tp 79 | self.pred += tp + fp 80 | self.gold += tp + fn 81 | else: 82 | for p, g in zip(preds, golds): 83 | e_p = self.compare(p[2], p[3]) 84 | e_g = self.compare(g[2], g[3]) 85 | self.tp += len(e_p & e_g) 86 | self.pred += len(e_p) 87 | self.gold += len(e_g) 88 | return self 89 | 90 | def __add__(self, other: PerplexityMetric) -> PerplexityMetric: 91 | metric = PerplexityMetric(eps=self.eps) 92 | metric.n = self.n + other.n 93 | metric.count = self.count + other.count 94 | metric.n_tokens = self.n_tokens + other.n_tokens 95 | metric.total_loss = self.total_loss + other.total_loss 96 | 97 | metric.tp = self.tp + other.tp 98 | metric.pred = self.pred + other.pred 99 | metric.gold = self.gold + other.gold 100 | metric.reverse = self.reverse or other.reverse 101 | return metric 102 | 103 | @property 104 | def score(self): 105 | return self.f 106 | 107 | @property 108 | def loss(self): 109 | return self.total_loss / self.n_tokens 110 | 111 | @property 112 | def ppl(self): 113 | return math.pow(2, (self.loss / math.log(2))) 114 | 115 | @property 116 | def p(self): 117 | return self.tp / (self.pred + self.eps) 118 | 119 | @property 120 | def r(self): 121 | return self.tp / (self.gold + self.eps) 122 | 123 | @property 124 | def f(self): 125 | return (1 + 0.5**2) * self.p * self.r / (0.5**2 * self.p + self.r + self.eps) 126 | 127 | @property 128 | def values(self): 129 | return {'P': self.p, 130 | 'R': self.r, 131 | 'F0.5': self.f} 132 | 133 | def compare(self, s, t) -> Set: 134 | return {(i, edit) for i, _, edit in levenshtein(s, t, align=True)[1] if edit != 0} 135 | 136 | def errant_parallel(self, forig: str, fcor: str, fout: str, annotator: Any) -> None: 137 | from contextlib import ExitStack 138 | 139 | def noop_edit(id=0): 140 | return "A -1 -1|||noop|||-NONE-|||REQUIRED|||-NONE-|||"+str(id) 141 | with ExitStack() as stack, open(fout, "w") as out_m2: 142 | in_files = [stack.enter_context(open(i)) for i in 
[forig]+[fcor]] 143 | # Process each line of all input files 144 | for line in zip(*in_files): 145 | # Get the original and all the corrected texts 146 | orig = line[0].strip() 147 | cors = line[1:] 148 | # Skip the line if orig is empty 149 | if not orig: 150 | continue 151 | # Parse orig with spacy 152 | orig = annotator.parse(orig) 153 | # Write orig to the output m2 file 154 | out_m2.write(" ".join(["S"]+[token.text for token in orig])+"\n") 155 | # Loop through the corrected texts 156 | for cor_id, cor in enumerate(cors): 157 | cor = cor.strip() 158 | # If the texts are the same, write a noop edit 159 | if orig.text.strip() == cor: 160 | out_m2.write(noop_edit(cor_id)+"\n") 161 | # Otherwise, do extra processing 162 | else: 163 | # Parse cor with spacy 164 | cor = annotator.parse(cor) 165 | # Align the texts and extract and classify the edits 166 | edits = annotator.annotate(orig, cor) 167 | # Loop through the edits 168 | for edit in edits: 169 | # Write the edit to the output m2 file 170 | out_m2.write(edit.to_m2(cor_id)+"\n") 171 | # Write a newline when we have processed all corrections for each line 172 | out_m2.write("\n") 173 | 174 | def errant_compare(self, fhyp: str, fref: str): 175 | from argparse import Namespace 176 | 177 | # Input: An m2 format sentence with edits. 178 | # Output: A list of lists. Each edit: [start, end, cat, cor, coder] 179 | 180 | def simplify_edits(sent): 181 | out_edits = [] 182 | # Get the edit lines from an m2 block. 183 | edits = sent.split("\n")[1:] 184 | # Loop through the edits 185 | for edit in edits: 186 | # Preprocessing 187 | edit = edit[2:].split("|||") # Ignore "A " then split. 188 | span = edit[0].split() 189 | start = int(span[0]) 190 | end = int(span[1]) 191 | cat = edit[1] 192 | cor = edit[2] 193 | coder = int(edit[-1]) 194 | out_edit = [start, end, cat, cor, coder] 195 | out_edits.append(out_edit) 196 | return out_edits 197 | 198 | # Input 1: A list of edits. Each edit: [start, end, cat, cor, coder] 199 | # Output: A dict; key is coder, value is edit dict. 200 | def process_edits(edits, args): 201 | coder_dict = {} 202 | # Add an explicit noop edit if there are no edits. 203 | if not edits: 204 | edits = [[-1, -1, "noop", "-NONE-", 0]] 205 | # Loop through the edits 206 | for edit in edits: 207 | # Name the edit elements for clarity 208 | start = edit[0] 209 | end = edit[1] 210 | cat = edit[2] 211 | cor = edit[3] 212 | coder = edit[4] 213 | # Add the coder to the coder_dict if necessary 214 | if coder not in coder_dict: 215 | coder_dict[coder] = {} 216 | 217 | # Optionally apply filters based on args 218 | # 1. UNK type edits are only useful for detection, not correction. 219 | if not args.dt and not args.ds and cat == "UNK": 220 | continue 221 | # 2. Only evaluate single token edits; i.e. 0:1, 1:0 or 1:1 222 | if args.single and (end-start >= 2 or len(cor.split()) >= 2): 223 | continue 224 | # 3. Only evaluate multi token edits; i.e. 2+:n or n:2+ 225 | if args.multi and end-start < 2 and len(cor.split()) < 2: 226 | continue 227 | # 4. If there is a filter, ignore the specified error types 228 | if args.filt and cat in args.filt: 229 | continue 230 | 231 | # Token Based Detection 232 | if args.dt: 233 | # Preserve noop edits. 
234 | if start == -1: 235 | if (start, start) in coder_dict[coder].keys(): 236 | coder_dict[coder][(start, start)].append(cat) 237 | else: 238 | coder_dict[coder][(start, start)] = [cat] 239 | # Insertions defined as affecting the token on the right 240 | elif start == end and start >= 0: 241 | if (start, start+1) in coder_dict[coder].keys(): 242 | coder_dict[coder][(start, start+1)].append(cat) 243 | else: 244 | coder_dict[coder][(start, start+1)] = [cat] 245 | # Edit spans are split for each token in the range. 246 | else: 247 | for tok_id in range(start, end): 248 | if (tok_id, tok_id+1) in coder_dict[coder].keys(): 249 | coder_dict[coder][(tok_id, tok_id+1)].append(cat) 250 | else: 251 | coder_dict[coder][(tok_id, tok_id+1)] = [cat] 252 | 253 | # Span Based Detection 254 | elif args.ds: 255 | if (start, end) in coder_dict[coder].keys(): 256 | coder_dict[coder][(start, end)].append(cat) 257 | else: 258 | coder_dict[coder][(start, end)] = [cat] 259 | 260 | # Span Based Correction 261 | else: 262 | # With error type classification 263 | if args.cse: 264 | if (start, end, cat, cor) in coder_dict[coder].keys(): 265 | coder_dict[coder][(start, end, cat, cor)].append(cat) 266 | else: 267 | coder_dict[coder][(start, end, cat, cor)] = [cat] 268 | # Without error type classification 269 | else: 270 | if (start, end, cor) in coder_dict[coder].keys(): 271 | coder_dict[coder][(start, end, cor)].append(cat) 272 | else: 273 | coder_dict[coder][(start, end, cor)] = [cat] 274 | return coder_dict 275 | 276 | # Input 1-3: True positives, false positives, false negatives 277 | # Input 4: Value of beta in F-score. 278 | # Output 1-3: Precision, Recall and F-score rounded to 4dp. 279 | 280 | def computeFScore(tp, fp, fn, beta): 281 | p = float(tp)/(tp+fp) if fp else 1.0 282 | r = float(tp)/(tp+fn) if fn else 1.0 283 | f = float((1+(beta**2))*p*r)/(((beta**2)*p)+r) if p+r else 0.0 284 | return round(p, 4), round(r, 4), round(f, 4) 285 | # Input 1: A hyp dict; key is coder_id, value is dict of processed hyp edits. 286 | # Input 2: A ref dict; key is coder_id, value is dict of processed ref edits. 287 | # Input 3: A dictionary of the best corpus level TP, FP and FN counts so far. 288 | # Input 4: Sentence ID (for verbose output only) 289 | # Input 5: Command line args 290 | # Output 1: A dict of the best corpus level TP, FP and FN for the input sentence. 291 | # Output 2: The corresponding error type dict for the above dict. 292 | 293 | # Input 1: A dictionary of hypothesis edits for a single system. 294 | # Input 2: A dictionary of reference edits for a single annotator. 295 | # Output 1-3: The TP, FP and FN for the hyp vs the given ref annotator. 296 | # Output 4: A dictionary of the error type counts. 297 | def compareEdits(hyp_edits, ref_edits): 298 | tp = 0 # True Positives 299 | fp = 0 # False Positives 300 | fn = 0 # False Negatives 301 | cat_dict = {} # {cat: [tp, fp, fn], ...} 302 | 303 | for h_edit, h_cats in hyp_edits.items(): 304 | # noop hyp edits cannot be TP or FP 305 | if h_cats[0] == "noop": 306 | continue 307 | # TRUE POSITIVES 308 | if h_edit in ref_edits.keys(): 309 | # On occasion, multiple tokens at same span. 310 | for h_cat in ref_edits[h_edit]: # Use ref dict for TP 311 | tp += 1 312 | # Each dict value [TP, FP, FN] 313 | if h_cat in cat_dict.keys(): 314 | cat_dict[h_cat][0] += 1 315 | else: 316 | cat_dict[h_cat] = [1, 0, 0] 317 | # FALSE POSITIVES 318 | else: 319 | # On occasion, multiple tokens at same span. 
320 | for h_cat in h_cats: 321 | fp += 1 322 | # Each dict value [TP, FP, FN] 323 | if h_cat in cat_dict.keys(): 324 | cat_dict[h_cat][1] += 1 325 | else: 326 | cat_dict[h_cat] = [0, 1, 0] 327 | for r_edit, r_cats in ref_edits.items(): 328 | # noop ref edits cannot be FN 329 | if r_cats[0] == "noop": 330 | continue 331 | # FALSE NEGATIVES 332 | if r_edit not in hyp_edits.keys(): 333 | # On occasion, multiple tokens at same span. 334 | for r_cat in r_cats: 335 | fn += 1 336 | # Each dict value [TP, FP, FN] 337 | if r_cat in cat_dict.keys(): 338 | cat_dict[r_cat][2] += 1 339 | else: 340 | cat_dict[r_cat] = [0, 0, 1] 341 | return tp, fp, fn, cat_dict 342 | 343 | def evaluate_edits(hyp_dict, ref_dict, best, sent_id, original_sentence, args): 344 | # Store the best sentence level scores and hyp+ref combination IDs 345 | # best_f is initialised as -1 cause 0 is a valid result. 346 | best_tp, best_fp, best_fn, best_f, _, _ = 0, 0, 0, -1, 0, 0 347 | best_cat = {} 348 | # Compare each hyp and ref combination 349 | for hyp_id in hyp_dict.keys(): 350 | for ref_id in ref_dict.keys(): 351 | # Get the local counts for the current combination. 352 | tp, fp, fn, cat_dict = compareEdits(hyp_dict[hyp_id], ref_dict[ref_id]) 353 | # Compute the local sentence scores (for verbose output only) 354 | loc_p, loc_r, loc_f = computeFScore(tp, fp, fn, args.beta) 355 | # Compute the global sentence scores 356 | p, r, f = computeFScore( 357 | tp+best["tp"], fp+best["fp"], fn+best["fn"], args.beta) 358 | # Save the scores if they are better in terms of: 359 | # 1. Higher F-score 360 | # 2. Same F-score, higher TP 361 | # 3. Same F-score and TP, lower FP 362 | # 4. Same F-score, TP and FP, lower FN 363 | if (f > best_f) or \ 364 | (f == best_f and tp > best_tp) or \ 365 | (f == best_f and tp == best_tp and fp < best_fp) or \ 366 | (f == best_f and tp == best_tp and fp == best_fp and fn < best_fn): 367 | best_tp, best_fp, best_fn = tp, fp, fn 368 | best_f, _, _ = f, hyp_id, ref_id 369 | best_cat = cat_dict 370 | # Save the best TP, FP and FNs as a dict, and return this and the best_cat dict 371 | best_dict = {"tp": best_tp, "fp": best_fp, "fn": best_fn} 372 | return best_dict, best_cat 373 | 374 | def merge_dict(dict1, dict2): 375 | for cat, stats in dict2.items(): 376 | if cat in dict1.keys(): 377 | dict1[cat] = [x+y for x, y in zip(dict1[cat], stats)] 378 | else: 379 | dict1[cat] = stats 380 | return dict1 381 | args = Namespace(beta=0.5, 382 | dt=False, 383 | ds=False, 384 | cs=False, 385 | cse=False, 386 | single=False, 387 | multi=False, 388 | filt=[], 389 | cat=1) 390 | # Open hypothesis and reference m2 files and split into chunks 391 | with open(fhyp) as fhyp, open(fref) as fref: 392 | hyp_m2 = fhyp.read().strip().split("\n\n") 393 | ref_m2 = fref.read().strip().split("\n\n") 394 | # Make sure they have the same number of sentences 395 | assert len(hyp_m2) == len(ref_m2) 396 | 397 | # Store global corpus level best counts here 398 | best_dict = Counter({"tp": 0, "fp": 0, "fn": 0}) 399 | best_cats = {} 400 | # Process each sentence 401 | sents = zip(hyp_m2, ref_m2) 402 | for sent_id, sent in enumerate(sents): 403 | # Simplify the edits into lists of lists 404 | hyp_edits = simplify_edits(sent[0]) 405 | ref_edits = simplify_edits(sent[1]) 406 | # Process the edits for detection/correction based on args 407 | hyp_dict = process_edits(hyp_edits, args) 408 | ref_dict = process_edits(ref_edits, args) 409 | # original sentence for logging 410 | original_sentence = sent[0][2:].split("\nA")[0] 411 | # Evaluate edits and 
get best TP, FP, FN hyp+ref combo. 412 | count_dict, cat_dict = evaluate_edits( 413 | hyp_dict, ref_dict, best_dict, sent_id, original_sentence, args) 414 | # Merge these dicts with best_dict and best_cats 415 | best_dict += Counter(count_dict) 416 | best_cats = merge_dict(best_cats, cat_dict) 417 | return best_dict 418 | 419 | 420 | class ExactMatchMetric(Metric): 421 | 422 | def __init__( 423 | self, 424 | loss: Optional[float] = None, 425 | preds: Optional[Tuple[torch.Tensor, List, List]] = None, 426 | golds: Optional[Tuple[torch.Tensor, List, List]] = None, 427 | mask: Optional[torch.BoolTensor] = None, 428 | reverse: bool = True, 429 | eps: float = 1e-12 430 | ) -> ExactMatchMetric: 431 | super().__init__(reverse=reverse, eps=eps) 432 | 433 | self.n_tokens = 0. 434 | 435 | self.tp = 0.0 436 | self.total = 0.0 437 | self.total_loss = 0. 438 | 439 | if loss is not None: 440 | self(loss, preds, golds, mask) 441 | 442 | def __repr__(self): 443 | return f"loss: {self.loss:.4f} EM: {self.em:6.2%}" 444 | 445 | def __call__( 446 | self, 447 | loss: float, 448 | preds: Tuple[torch.Tensor, List, List], 449 | golds: Tuple[torch.Tensor, List, List], 450 | mask: torch.BoolTensor 451 | ) -> ExactMatchMetric: 452 | n_tokens = mask.sum().item() 453 | self.n += len(mask) 454 | self.count += 1 455 | self.n_tokens += n_tokens 456 | self.total_loss += float(loss) * n_tokens 457 | 458 | if preds is not None: 459 | self.tp += sum([p[3].equal(g[3]) for p, g in zip(preds, golds)]) 460 | self.total += len(preds) 461 | return self 462 | 463 | def __add__(self, other: ExactMatchMetric) -> ExactMatchMetric: 464 | metric = ExactMatchMetric(eps=self.eps) 465 | metric.n = self.n + other.n 466 | metric.count = self.count + other.count 467 | metric.n_tokens = self.n_tokens + other.n_tokens 468 | metric.total_loss = self.total_loss + other.total_loss 469 | 470 | metric.tp = self.tp + other.tp 471 | metric.total = self.total + other.total 472 | metric.reverse = self.reverse or other.reverse 473 | return metric 474 | 475 | @property 476 | def score(self): 477 | return self.em 478 | 479 | @property 480 | def loss(self): 481 | return self.total_loss / self.n_tokens 482 | 483 | @property 484 | def em(self): 485 | return self.tp / (self.total + self.eps) 486 | 487 | @property 488 | def values(self): 489 | return {'EM': self.em} 490 | -------------------------------------------------------------------------------- /ctc/model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from supar.model import Model 7 | from supar.modules import TokenDropout 8 | from supar.modules.transformer import (TransformerDecoder, 9 | TransformerDecoderLayer) 10 | from supar.config import Config 11 | from supar.utils.common import INF, MIN 12 | from supar.utils.fn import pad 13 | 14 | 15 | class CTCModel(Model): 16 | r""" 17 | The implementation of CTC Parser. 18 | 19 | Args: 20 | n_words (int): 21 | The size of the word vocabulary. 22 | n_tags (int): 23 | The number of POS tags, required if POS tag embeddings are used. Default: ``None``. 24 | n_chars (int): 25 | The number of characters, required if character-level representations are used. Default: ``None``. 26 | n_lemmas (int): 27 | The number of lemmas, required if lemma embeddings are used. Default: ``None``. 28 | encoder (str): 29 | Encoder to use. 30 | ``'lstm'``: BiLSTM encoder. 
31 | ``'bert'``: BERT-like pretrained language model (for finetuning), e.g., ``'bert-base-cased'``. 32 | Default: ``'lstm'``. 33 | n_embed (int): 34 | The size of word embeddings. Default: 100. 35 | n_pretrained (int): 36 | The size of pretrained word embeddings. Default: 125. 37 | n_feat_embed (int): 38 | The size of feature representations. Default: 100. 39 | n_char_embed (int): 40 | The size of character embeddings serving as inputs of CharLSTM, required if using CharLSTM. Default: 50. 41 | n_char_hidden (int): 42 | The size of y states of CharLSTM, required if using CharLSTM. Default: 100. 43 | char_pad_index (int): 44 | The index of the padding token in the character vocabulary, required if using CharLSTM. Default: 0. 45 | elmo (str): 46 | Name of the pretrained ELMo registered in `ELMoEmbedding.OPTION`. Default: ``'original_5b'``. 47 | elmo_bos_eos (tuple[bool]): 48 | A tuple of two boolean values indicating whether to keep start/end boundaries of elmo outputs. 49 | Default: ``(True, False)``. 50 | bert (str): 51 | Specifies which kind of language model to use, e.g., ``'bert-base-cased'``. 52 | This is required if ``encoder='bert'`` or using BERT features. The full list can be found in `transformers`_. 53 | Default: ``None``. 54 | n_bert_layers (int): 55 | Specifies how many last layers to use, required if ``encoder='bert'`` or using BERT features. 56 | The final outputs would be weighted sum of the y states of these layers. 57 | Default: 4. 58 | mix_dropout (float): 59 | The dropout ratio of BERT layers, required if ``encoder='bert'`` or using BERT features. Default: .0. 60 | bert_pooling (str): 61 | Pooling way to get token embeddings. 62 | ``first``: take the first subtoken. ``last``: take the last subtoken. ``mean``: take a mean over all. 63 | Default: ``mean``. 64 | bert_pad_index (int): 65 | The index of the padding token in BERT vocabulary, required if ``encoder='bert'`` or using BERT features. 66 | Default: 0. 67 | freeze (bool): 68 | If ``True``, freezes BERT parameters, required if using BERT features. Default: ``True``. 69 | embed_dropout (float): 70 | The dropout ratio of input embeddings. Default: .2. 71 | n_encoder_hidden (int): 72 | The size of LSTM y states. Default: 600. 73 | n_encoder_layers (int): 74 | The number of LSTM layers. Default: 3. 75 | encoder_dropout (float): 76 | The dropout ratio of encoder layer. Default: .33. 77 | mlp_dropout (float): 78 | The dropout ratio of unary edge factor MLP layers. Default: .33. 79 | pad_index (int): 80 | The index of the padding token in the word vocabulary. Default: 0. 81 | unk_index (int): 82 | The index of the unknown token in the word vocabulary. Default: 1. 83 | 84 | .. 
_transformers: 85 | https://github.com/huggingface/transformers 86 | """ 87 | 88 | def __init__(self, 89 | n_words, 90 | n_tags=None, 91 | n_chars=None, 92 | n_lemmas=None, 93 | encoder='lstm', 94 | n_embed=100, 95 | n_pretrained=100, 96 | n_feat_embed=100, 97 | n_char_embed=50, 98 | n_char_hidden=100, 99 | char_pad_index=0, 100 | char_dropout=0, 101 | elmo='original_5b', 102 | elmo_bos_eos=(True, False), 103 | bert=None, 104 | n_bert_layers=4, 105 | mix_dropout=.0, 106 | bert_pooling='mean', 107 | bert_pad_index=0, 108 | freeze=True, 109 | embed_dropout=.33, 110 | n_encoder_hidden=512, 111 | n_encoder_layers=3, 112 | encoder_dropout=.1, 113 | dropout=.1, 114 | pad_index=0, 115 | unk_index=1, 116 | **kwargs): 117 | super().__init__(**Config().update(locals())) 118 | 119 | from transformers import AutoModel 120 | self.encoder = AutoModel.from_pretrained(self.args.bert, 121 | add_pooling_layer=False, 122 | attention_probs_dropout_prob=self.args.dropout, 123 | hidden_dropout_prob=self.args.dropout) 124 | if self.args.vocab: 125 | self.encoder.resize_token_embeddings(self.args.n_words) 126 | self.token_dropout = TokenDropout(self.args.get('token_dropout', 0)) 127 | self.proj = nn.Linear(self.args.n_encoder_hidden, self.args.upsampling * self.args.n_encoder_hidden) 128 | self.decoder = TransformerDecoder(layer=TransformerDecoderLayer(n_model=self.args.n_encoder_hidden, 129 | dropout=self.args.dropout), 130 | n_layers=self.args.n_decoder_layers) 131 | self.classifier = nn.Linear(self.args.n_encoder_hidden, self.args.n_words) 132 | 133 | def forward(self, words): 134 | r""" 135 | Args: 136 | words (~torch.LongTensor): ``[batch_size, seq_len]``. 137 | Word indices. 138 | 139 | Returns: 140 | ~torch.Tensor: 141 | Representations for the src sentences of the shape ``[batch_size, seq_len, n_model]``. 
142 | """ 143 | x = self.encoder(inputs_embeds=self.token_dropout(self.encoder.embeddings.word_embeddings(words)), 144 | attention_mask=words.ne(self.args.pad_index))[0] 145 | return self.encoder_dropout(x) 146 | 147 | def resize(self, x): 148 | batch_size, seq_len, *_, upsampling = x.shape 149 | resized = x.new_zeros(batch_size, seq_len * upsampling, *_) 150 | for i, j in enumerate(x.unbind(-1)): 151 | resized[:, i::upsampling] = j 152 | return resized 153 | 154 | def loss(self, x, src, tgt, src_mask, tgt_mask, ratio=0): 155 | x_tgt, glat_mask = self.resize(self.proj(x).view(*x.shape, self.args.upsampling)), None 156 | if ratio > 0: 157 | with torch.no_grad(), torch.cuda.amp.autocast(enabled=False): 158 | mask = self.resize(src_mask.unsqueeze(-1).repeat(1, 1, self.args.upsampling)) 159 | preds, s_x = self.decode(x, src, src_mask, True) 160 | align = self.align(s_x.log_softmax(2).transpose(0, 1), src, tgt, src_mask, tgt_mask) 161 | probs = ((align.ne(preds) & mask).sum(-1) / mask.sum(-1) * ratio).clamp_(0, 1) 162 | glat_mask = (src.new_zeros(mask.shape) + probs.unsqueeze(-1)).bernoulli().bool() 163 | e_tgt = self.encoder.embeddings(torch.where(align.ge(self.args.n_words-2), self.args.mask_index, align)) 164 | x_tgt = torch.where(glat_mask.unsqueeze(-1), e_tgt, x_tgt) 165 | x = self.decoder(x_tgt=x_tgt, 166 | x_src=x, 167 | tgt_mask=self.resize(src_mask.unsqueeze(-1).repeat(1, 1, self.args.upsampling)), 168 | src_mask=src_mask) 169 | # [tgt_len, batch_size, n_words] 170 | s_x = self.classifier(x).log_softmax(2).transpose(0, 1) 171 | return self.ctc(s_x, src, tgt, src_mask, tgt_mask) 172 | 173 | def ctc(self, s_x, src, tgt, src_mask, tgt_mask, glat_mask=None): 174 | src = self.resize(src.unsqueeze(-1).repeat(1, 1, self.args.upsampling)) 175 | # [tgt_len, batch_size] 176 | s_k, s_b = s_x[..., self.args.keep_index], s_x[..., self.args.nul_index] 177 | # [tgt_len, seq_len, batch_size] 178 | s_x = s_x.gather(-1, tgt.repeat(s_b.shape[0], 1, 1)).transpose(1, 2) 179 | s_x = torch.where(src.unsqueeze(-1).eq(tgt.unsqueeze(1)).movedim(0, -1), s_k.unsqueeze(1), s_x) 180 | if glat_mask is not None: 181 | glat_mask = glat_mask.t() 182 | s_b = s_b.masked_fill(glat_mask, 0) 183 | s_x = s_x.masked_fill(glat_mask.unsqueeze(1), 0) 184 | src_lens, tgt_lens = src_mask.sum(-1) * self.args.upsampling, tgt_mask.sum(-1) 185 | tgt_len, seq_len, batch_size = s_x.shape 186 | # [tgt_len, 2, seq_len + 1, batch_size] 187 | s = s_x.new_full((tgt_len, 2, seq_len + 1, batch_size), MIN) 188 | s[0, 0, 0], s[0, 1, 0] = s_b[0], s_x[0, 0] 189 | for t in range(1, tgt_len): 190 | s0 = torch.cat((torch.full_like(s[0, 0, :1], MIN), s[t-1, 1, :-1])) 191 | s1 = s[t-1, 0] 192 | s2 = s[t-1, 1] 193 | s[t, 0] = torch.stack((s0, s1)).logsumexp(0) + s_b[t] 194 | s[t, 1, :-1] = torch.stack((s0[:-1], s1[:-1], s2[:-1])).logsumexp(0) + s_x[t] 195 | s = s[src_lens - 1, 0, tgt_lens, range(batch_size)].logaddexp(s[src_lens - 1, 1, tgt_lens - 1, range(batch_size)]) 196 | return -s.sum() / tgt_lens.sum() 197 | 198 | def decode(self, x, src, src_mask, score=False): 199 | batch_size, *_ = x.shape 200 | beam_size, n_words = self.args.beam_size, self.args.n_words 201 | keep_index, nul_index, pad_index = self.args.keep_index, self.args.nul_index, self.args.pad_index 202 | indices = src.new_tensor(range(batch_size)).unsqueeze(1).repeat(1, beam_size).view(-1) 203 | x = self.decoder(x_tgt=self.resize(self.proj(x).view(*x.shape, self.args.upsampling)), 204 | x_src=x, 205 | tgt_mask=self.resize(src_mask.unsqueeze(-1).repeat(1, 1, self.args.upsampling)), 206 | 
src_mask=src_mask) 207 | src = self.resize(src.unsqueeze(-1).repeat(1, 1, self.args.upsampling)) 208 | src_mask = self.resize(src_mask.unsqueeze(-1).repeat(1, 1, self.args.upsampling)) 209 | 210 | if not self.args.prefix: 211 | s_x = self.classifier(x) 212 | # [batch_size, tgt_len, topk] 213 | tgt = s_x.topk(self.args.topk, -1)[1] 214 | tgt = torch.where(tgt.eq(keep_index), src.unsqueeze(-1), tgt) 215 | # [batch_size, topk, tgt_len] 216 | tgt = tgt.masked_fill_(~src_mask.unsqueeze(2), self.args.pad_index).transpose(1, 2) 217 | if score: 218 | return tgt[:, 0], s_x 219 | # [batch_size, topk, tgt_len] 220 | tgt = [[j.unique_consecutive() for j in i.unbind(0)] for i in tgt.unbind(0)] 221 | tgt = pad([pad([j[j.ne(nul_index)] for j in i], pad_index) for i in tgt], pad_index) 222 | return tgt 223 | 224 | # [batch_size * beam_size, tgt_len, ...] 225 | x, src, src_mask = x[indices], src[indices], src_mask[indices] 226 | # [batch_size * beam_size, max_len] 227 | tgt = x.new_full((batch_size * beam_size, x.shape[1]), nul_index, dtype=torch.long) 228 | lens = tgt.new_full((tgt.shape[0],), 0) 229 | # [batch_size] 230 | batches = tgt.new_tensor(range(batch_size)) * beam_size 231 | # accumulated scores 232 | # [2, batch_size * beam_size] 233 | s = torch.stack((x.new_full((batch_size, beam_size), -INF).index_fill_(-1, tgt.new_tensor(0), 0).view(-1), 234 | x.new_full((batch_size * beam_size,), -INF))) 235 | 236 | def merge(s_b, s_n, tgt, lens, ends): 237 | # merge the prefixes that have grown in the new step 238 | s_n = s_n.view(batch_size, beam_size, -1) 239 | tgt, lens, ends = tgt.view(batch_size, beam_size, -1), lens.view(batch_size, -1), ends.view(batch_size, -1) 240 | # [batch_size, beam_size, beam_size] 241 | mask = tgt.scatter(-1, (lens.clamp(1) - 1).unsqueeze(-1), nul_index).unsqueeze(2).eq(tgt.unsqueeze(1)).all(-1) 242 | mask = mask & lens.gt(0).unsqueeze(2) 243 | s_g = s_n.gather(-1, ends.unsqueeze(2)) 244 | s_n[..., nul_index] = s_n[..., nul_index].logaddexp(s_g.transpose(1, 2).masked_fill(~mask, -INF).logsumexp(2)) 245 | s_n.scatter_(-1, ends.unsqueeze(2), torch.where(mask.any(2, True), -INF, s_g)) 246 | s_n = s_n.view(batch_size * beam_size, -1) 247 | return s_b, s_n 248 | 249 | for t in range(x.shape[1]): 250 | # [batch_size * beam_size] 251 | mask = src_mask[:, t] 252 | # the past prefixes 253 | ends = tgt[range(tgt.shape[0]), lens - 1] 254 | # [batch_size * beam_size, n_words] 255 | s_t = self.classifier(x[:, t]).log_softmax(1) 256 | s_k = s_t.gather(-1, src[:, t].unsqueeze(-1)).logaddexp(s_t[:, keep_index].unsqueeze(-1)) 257 | s_t = s_t.scatter_(-1, src[:, t].unsqueeze(-1), s_k) 258 | s_t[:, keep_index] = -INF 259 | s_e = s_t.gather(1, ends.unsqueeze(1)) 260 | s_p = s.logsumexp(0).unsqueeze(-1) 261 | # [batch_size * beam_size] 262 | # the position for blanks are used for storing prefixes kept unchanged 263 | # *a - -> *a 264 | s_b = s_p + s_t.masked_fill(tgt.new_tensor(range(n_words)).ne(nul_index).unsqueeze(0), -INF) 265 | # *a b -> *ab 266 | s_n = s_p + s_t 267 | # *a- a -> *aa 268 | s_n = s_n.scatter_(1, ends.unsqueeze(1), s[0].unsqueeze(1) + s_e) 269 | # *a a -> *a 270 | s_n[:, nul_index] = s[1] + s_e.squeeze(1) 271 | # [2, batch_size * beam_size, n_words] 272 | s = torch.stack((merge(s_b, s_n, tgt, lens, ends))) 273 | # [batch_size, beam_size] 274 | cands = s.logsumexp(0).view(batch_size, -1).topk(beam_size, -1)[1] 275 | # [2, batch_size * beam_size] 276 | s = s.view(2, batch_size, -1).gather(-1, cands.repeat(2, 1, 1)).view(2, -1) 277 | # beams, tokens = cands // n_words, cands % 
n_words 278 | beams, tokens = cands.div(n_words, rounding_mode='floor'), (cands % n_words).view(-1, 1) 279 | indices = (batches.unsqueeze(-1) + beams).view(-1) 280 | lens[mask] = lens[indices[mask]] 281 | # [batch_size * beam_size, max_len] 282 | tgt[mask] = tgt[indices[mask]].scatter_(1, lens[mask].unsqueeze(1), tokens[mask]) 283 | lens += tokens.ne(nul_index).squeeze(1) & mask 284 | cands = s.logsumexp(0).view(batch_size, -1).topk(self.args.topk, -1)[1] 285 | tgt = tgt[(batches.unsqueeze(-1) + cands).view(-1)].view(batch_size, self.args.topk, -1) 286 | tgt = pad([pad([j[j.ne(nul_index)] for j in i], pad_index) for i in tgt], pad_index) 287 | return tgt 288 | 289 | def align(self, s_x, src, tgt, src_mask, tgt_mask): 290 | src = self.resize(src.unsqueeze(-1).repeat(1, 1, self.args.upsampling)) 291 | # [tgt_len, batch_size] 292 | s_k, s_b = s_x[..., self.args.keep_index], s_x[..., self.args.nul_index] 293 | # [tgt_len, seq_len, batch_size] 294 | s_x = s_x.gather(-1, tgt.repeat(s_b.shape[0], 1, 1)).transpose(1, 2) 295 | s_x = torch.where(src.unsqueeze(-1).eq(tgt.unsqueeze(1)).movedim(0, -1), s_k.unsqueeze(1), s_x) 296 | src_lens, tgt_lens = src_mask.sum(-1) * self.args.upsampling, tgt_mask.sum(-1) 297 | tgt_len, seq_len, batch_size = s_x.shape 298 | # [tgt_len, 2, seq_len + 1, batch_size] 299 | s = s_x.new_full((tgt_len, 2, seq_len + 1, batch_size), -INF) 300 | p = tgt.new_full((tgt_len, 2, seq_len + 1, batch_size), -1) 301 | s[0, 0, 0], s[0, 1, 0] = s_b[0], s_x[0, 0] 302 | 303 | for t in range(1, tgt_len): 304 | s0 = torch.cat((torch.full_like(s[0, 0, :1], -INF), s[t-1, 1, :-1])) 305 | s1 = s[t-1, 0] 306 | s2 = s[t-1, 1] 307 | s_t, p[t, 0] = torch.stack((s0, s1)).max(0) 308 | s[t, 0] = s_t + s_b[t] 309 | s_t, p[t, 1, :-1] = torch.stack((s0[:-1], s1[:-1], s2[:-1])).max(0) 310 | s[t, 1, :-1] = s_t + s_x[t] 311 | _, p_t = torch.stack((s[src_lens - 1, 0, tgt_lens, range(batch_size)], 312 | s[src_lens - 1, 1, tgt_lens - 1, range(batch_size)])).max(0) 313 | 314 | def backtrack(p, tgt, notnul): 315 | j, pred = [len(p[0][0])-1, len(p[0][0])-2], [] 316 | for i in reversed(range(len(p))): 317 | prev = p[i][notnul][j[notnul]] 318 | pred.append(tgt[j[notnul]] if bool(notnul) else self.args.nul_index) 319 | if notnul == 0: 320 | if prev == 0: 321 | notnul = 1 322 | j[notnul] = j[1-notnul] - 1 323 | elif notnul == 1: 324 | if prev == 0: 325 | j[notnul] -= 1 326 | if prev == 1: 327 | notnul = 0 328 | j[notnul] = j[1-notnul] 329 | return tuple(reversed(pred)) 330 | p_t, tgt, preds = p_t.tolist(), tgt.tolist(), torch.full_like(src, self.args.pad_index) 331 | for i, (src_len, tgt_len) in enumerate(zip(src_lens.tolist(), tgt_lens.tolist())): 332 | preds[i, :src_len] = src.new_tensor(backtrack(p[:src_len, :, :tgt_len+1, i].tolist(), tgt[i][:tgt_len], p_t[i])) 333 | return preds 334 | -------------------------------------------------------------------------------- /ctc/parser.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import tempfile 5 | 6 | import errant 7 | import torch 8 | import torch.distributed as dist 9 | from torch.optim import AdamW, Optimizer 10 | 11 | from supar.config import Config 12 | from supar.parser import Parser 13 | from supar.utils import Dataset 14 | from supar.utils.field import Field 15 | from supar.utils.logging import get_logger 16 | from supar.utils.parallel import gather, is_dist, is_master 17 | from supar.utils.tokenizer import TransformerTokenizer 18 | from supar.utils.transform import Batch 
19 | 20 | from .metric import PerplexityMetric 21 | from .model import CTCModel 22 | from .transform import Text 23 | 24 | logger = get_logger(__name__) 25 | 26 | 27 | class CTCParser(Parser): 28 | 29 | NAME = 'ctc' 30 | MODEL = CTCModel 31 | 32 | def __init__(self, *args, **kwargs): 33 | super().__init__(*args, **kwargs) 34 | 35 | self.SRC = self.transform.SRC 36 | self.TGT = self.transform.TGT 37 | self.annotator = errant.load("en") 38 | 39 | def init_optimizer(self) -> Optimizer: 40 | return AdamW(params=[{'params': p, 'lr': self.args.lr * (1 if n.startswith('encoder') else self.args.lr_rate)} 41 | for n, p in self.model.named_parameters()], 42 | lr=self.args.lr, 43 | betas=(self.args.get('mu', 0.9), self.args.get('nu', 0.999)), 44 | eps=self.args.get('eps', 1e-8), 45 | weight_decay=self.args.get('weight_decay', 0)) 46 | 47 | def train_step(self, batch: Batch) -> torch.Tensor: 48 | src, tgt = batch 49 | src_mask, tgt_mask = batch.mask, tgt.ne(self.args.pad_index) 50 | mask = tgt_mask.sum(-1).lt(src_mask.sum(-1) * self.args.upsampling) 51 | src, tgt, src_mask, tgt_mask = src[mask], tgt[mask], src_mask[mask], tgt_mask[mask] 52 | x = self.model(src) 53 | loss = self.model.loss(x, src, tgt, src_mask, tgt_mask, self.args.glat) 54 | return loss 55 | 56 | @torch.no_grad() 57 | def eval_step(self, batch: Batch) -> PerplexityMetric: 58 | src, tgt = batch 59 | src_mask, tgt_mask = batch.mask, tgt.ne(self.args.pad_index) 60 | mask = tgt_mask.sum(-1).lt(src_mask.sum(-1) * self.args.upsampling) 61 | src, tgt, src_mask, tgt_mask = src[mask], tgt[mask], src_mask[mask], tgt_mask[mask] 62 | x = self.model(src) 63 | loss = self.model.loss(x, src, tgt, src_mask, tgt_mask) 64 | preds = golds = None 65 | if self.args.eval_tgt: 66 | golds = [(s.values[0], s.values[1], s.fields['src'].tolist(), t.tolist()) 67 | for s, t in zip(batch.sentences, tgt[tgt_mask].split(tgt_mask.sum(-1).tolist()))] 68 | preds = self.model.decode(x, src, batch.mask)[:, 0] 69 | pred_mask = preds.ne(self.args.pad_index) 70 | preds = [i.tolist() for i in preds[pred_mask].split(pred_mask.sum(-1).tolist())] 71 | preds = [(s.values[0], self.TGT.tokenize.decode(i), s.fields['src'].tolist(), i) 72 | for s, i in zip(batch.sentences, preds)] 73 | return PerplexityMetric(loss, 74 | preds, 75 | golds, 76 | tgt_mask, 77 | (None if self.args.lev else self.annotator), 78 | not self.args.eval_tgt) 79 | 80 | @torch.no_grad() 81 | def pred_step(self, batch: Batch) -> Batch: 82 | src, = batch 83 | mask = batch.mask 84 | for _ in range(self.args.iteration): 85 | x = self.model(src) 86 | tgt = self.model.decode(x, src, mask) 87 | src = tgt[:, 0] 88 | mask = src.ne(self.args.pad_index) 89 | batch.tgt = [[self.TGT.tokenize.decode(cand).strip() for cand in i] for i in tgt.tolist()] 90 | return batch 91 | 92 | @classmethod 93 | def build(cls, path, min_freq=2, fix_len=20, **kwargs): 94 | r""" 95 | Build a brand-new Parser, including initialization of all data fields and model parameters. 96 | 97 | Args: 98 | path (str): 99 | The path of the model to be saved. 100 | min_freq (str): 101 | The minimum frequency needed to include a token in the vocabulary. Default: 2. 102 | fix_len (int): 103 | The max length of all subword pieces. The excess part of each piece will be truncated. 104 | Required if using CharLSTM/BERT. 105 | Default: 20. 106 | kwargs (dict): 107 | A dict holding the unconsumed arguments. 
108 | """ 109 | 110 | args = Config(**locals()) 111 | os.makedirs(os.path.dirname(path) or './', exist_ok=True) 112 | if os.path.exists(path) and not args.build: 113 | return cls.load(**args) 114 | 115 | logger.info("Building the fields") 116 | t = TransformerTokenizer(args.bert) 117 | SRC = Field('src', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, tokenize=t) 118 | TGT = Field('tgt', pad=t.pad, unk=t.unk, bos=t.bos, eos=t.eos, tokenize=t) 119 | transform = Text(SRC=SRC, TGT=TGT) 120 | if args.vocab: 121 | if is_master(): 122 | t.extend(Dataset(transform, args.train, **args).src) 123 | if is_dist(): 124 | with tempfile.TemporaryDirectory(dir='.') as td: 125 | td = gather(td)[0] 126 | if is_master(): 127 | torch.save(t, f'{td}/t') 128 | dist.barrier() 129 | t = torch.load(f'{td}/t') 130 | SRC.vocab = TGT.vocab = t.vocab 131 | 132 | args.update({'n_words': len(SRC.vocab) + 2, 133 | 'pad_index': SRC.pad_index, 134 | 'unk_index': SRC.unk_index, 135 | 'bos_index': SRC.bos_index, 136 | 'eos_index': SRC.eos_index, 137 | 'mask_index': t.mask_token_id, 138 | 'keep_index': len(SRC.vocab), 139 | 'nul_index': len(SRC.vocab) + 1}) 140 | logger.info(f"{transform}") 141 | logger.info("Building the model") 142 | model = cls.MODEL(**args) 143 | logger.info(f"{model}\n") 144 | 145 | parser = cls(args, model, transform) 146 | parser.model.to(parser.device) 147 | return parser 148 | -------------------------------------------------------------------------------- /ctc/struct.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | from typing import List, Optional 6 | 7 | import torch 8 | from torch.distributions.utils import lazy_property 9 | 10 | from supar.structs.dist import StructuredDistribution 11 | from supar.structs.semiring import LogSemiring, Semiring 12 | 13 | 14 | class Levenshtein(StructuredDistribution): 15 | 16 | def __init__( 17 | self, 18 | scores: torch.Tensor, 19 | lens: Optional[torch.LongTensor] = None 20 | ) -> Levenshtein: 21 | super().__init__(scores) 22 | 23 | batch_size, _, seq_len, src_len = scores.shape[:4] 24 | if lens is not None: 25 | self.lens = lens 26 | else: 27 | self.lens = (scores.new_zeros(batch_size, 2) + scores.new_tensor(src_len, seq_len)).long() 28 | self.src_lens, self.tgt_lens = lens.unbind(-1) 29 | self.src_mask = self.src_lens.unsqueeze(-1).gt(self.lens.new_tensor(range(src_len))) 30 | self.tgt_mask = self.tgt_lens.unsqueeze(-1).gt(self.lens.new_tensor(range(seq_len))) 31 | 32 | def __add__(self, other): 33 | return Levenshtein(torch.stack((self.scores, other.scores)), self.lens) 34 | 35 | @lazy_property 36 | def argmax(self): 37 | margs = self.backward(self.max.sum()) 38 | margs, edits = margs.argmax(1).transpose(1, 2), [torch.where(i) for i in margs.sum(1).transpose(1, 2).unbind()] 39 | return [torch.stack((e[0], e[1], m[e])).t().tolist() for e, m in zip(edits, margs)] 40 | 41 | def score(self, value: List) -> torch.Tensor: 42 | lens = self.lens.new_tensor([len(i) for i in value]) 43 | edit_mask = lens.unsqueeze(-1).gt(lens.new_tensor(range(max(lens)))) 44 | edits = list(self.lens.new_tensor([(i,) + span for i, spans in enumerate(value) for span in spans]).unbind(-1)) 45 | s_edit = self.scores[edits[0], edits[3], edits[2], edits[1]] 46 | s = s_edit.new_full(edit_mask.shape, LogSemiring.one).masked_scatter_(edit_mask, s_edit) 47 | return LogSemiring.prod(s) 48 | 49 | def forward(self, semiring: Semiring) -> torch.Tensor: 50 | # [4, seq_len, src_len, 
batch_size, ...] 51 | s_edit = semiring.convert(self.scores.movedim(0, 3)) 52 | 53 | _, seq_len, src_len, batch_size = s_edit.shape[:4] 54 | tgt_lens, src_lens, src_mask = self.tgt_lens, self.src_lens, self.src_mask.t() 55 | # [seq_len, src_len, batch_size] 56 | alpha = semiring.zeros_like(s_edit[0]) 57 | trans = semiring.cumprod(torch.cat((semiring.ones_like(s_edit[0, :, :1]), s_edit[0, :, 1:]), 1), 1) 58 | # [seq_len, src_len, src_len, batch_size] 59 | trans = trans.unsqueeze(2) - trans.unsqueeze(1) 60 | trans_mask = src_mask.unsqueeze(0) & torch.ones_like(src_mask).unsqueeze(1) 61 | # [src_len, src_len, batch_size] 62 | trans_mask = trans_mask & src_mask.new_ones(src_len, src_len).tril(-1).unsqueeze(-1) 63 | 64 | for t in range(seq_len): 65 | s_a = alpha[t - 1] if t > 0 else semiring.ones_like(trans[0, 0]) 66 | # INSERT 67 | s_i = semiring.mul(s_a, s_edit[1, t]) 68 | # KEEP 69 | s_k = torch.cat((semiring.zeros_like(s_a[:1]), semiring.mul(s_a[:-1], s_edit[2, t, 1:])), 0) 70 | # REPLACE 71 | s_r = torch.cat((semiring.zeros_like(s_a[:1]), semiring.mul(s_a[:-1], s_edit[3, t, 1:])), 0) 72 | # SWAP 73 | s_s = torch.cat((semiring.zeros_like(s_a[:1]), semiring.mul(s_a[:-1], s_edit[4, t, 1:])), 0) 74 | # [src_len, batch_size] 75 | s_a = semiring.sum(torch.stack((s_i, s_k, s_r, s_s)), 0) 76 | # DELETE 77 | s_d = semiring.sum(semiring.zero_mask_(semiring.mul(trans[t], s_a.unsqueeze(0)), ~trans_mask), 1) 78 | # [src_len, batch_size] 79 | alpha[t] = semiring.add(s_d, s_a) 80 | # the full input is consumed when the final output symbol is generated 81 | return semiring.unconvert(alpha[tgt_lens - 1, src_lens - 1, range(batch_size)]) 82 | -------------------------------------------------------------------------------- /ctc/transform.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import annotations 4 | 5 | import os 6 | import tempfile 7 | from contextlib import contextmanager 8 | from io import StringIO 9 | from typing import Iterable, List, Optional, Union 10 | 11 | import pathos.multiprocessing as mp 12 | import spacy 13 | import spacy.parts_of_speech as POS 14 | import torch.distributed as dist 15 | from rapidfuzz.distance import Indel 16 | from spacy.tokens import Doc 17 | 18 | from supar.utils import Field 19 | from supar.utils.fn import binarize, debinarize 20 | from supar.utils.logging import progress_bar 21 | from supar.utils.parallel import gather, is_dist, is_master 22 | from supar.utils.tokenizer import Tokenizer 23 | from supar.utils.transform import Sentence, Transform 24 | 25 | 26 | class Alignment: 27 | # Protected class resource 28 | _open_pos = {POS.ADJ, POS.ADV, POS.NOUN, POS.VERB} 29 | 30 | # Input 1: An original text string parsed by spacy 31 | # Input 2: A corrected text string parsed by spacy 32 | # Input 3: A flag for standard Levenshtein alignment 33 | def __init__(self, orig, cor, lev=False, nlp=None): 34 | # Set orig and cor 35 | self.nlp = nlp 36 | self.orig_toks, self.cor_toks = orig, cor 37 | self.orig = self.parse(orig) 38 | self.cor = self.parse(cor) 39 | # Align orig and cor and get the cost and op matrices 40 | self.cost_matrix, self.op_matrix = self.align(lev) 41 | # Get the cheapest align sequence from the op matrix 42 | self.align_seq = self.get_cheapest_align_seq() 43 | 44 | # Input: A flag for standard Levenshtein alignment 45 | # Output: The cost matrix and the operation matrix of the alignment 46 | def align(self, lev): 47 | # Sentence lengths 48 | o_len = len(self.orig) 49 | 
c_len = len(self.cor) 50 | # Lower case token IDs (for transpositions) 51 | # Create the cost_matrix and the op_matrix 52 | cost_matrix = [[0.0 for j in range(c_len+1)] for i in range(o_len+1)] 53 | op_matrix = [["O" for j in range(c_len+1)] for i in range(o_len+1)] 54 | # Fill in the edges 55 | for i in range(1, o_len+1): 56 | cost_matrix[i][0] = cost_matrix[i-1][0] + 1 57 | op_matrix[i][0] = "D" 58 | for j in range(1, c_len+1): 59 | cost_matrix[0][j] = cost_matrix[0][j-1] + 1 60 | op_matrix[0][j] = "I" 61 | 62 | # Loop through the cost_matrix 63 | for i in range(o_len): 64 | for j in range(c_len): 65 | # Matches 66 | if self.orig[i].orth == self.cor[j].orth and self.orig_toks[i] == self.cor_toks[j]: 67 | cost_matrix[i+1][j+1] = cost_matrix[i][j] 68 | op_matrix[i+1][j+1] = "M" 69 | # Non-matches 70 | else: 71 | del_cost = cost_matrix[i][j+1] + 1 72 | ins_cost = cost_matrix[i+1][j] + 1 73 | trans_cost = float("inf") # currently ignore swap/transpose 74 | k = 0 75 | # Standard Levenshtein (S = 1) 76 | if lev: 77 | sub_cost = cost_matrix[i][j] + 1 78 | # Linguistic Damerau-Levenshtein 79 | else: 80 | # Custom substitution 81 | sub_cost = cost_matrix[i][j] + self.get_sub_cost(self.orig[i], self.cor[j]) 82 | # Costs 83 | costs = [trans_cost, sub_cost, ins_cost, del_cost] 84 | # Get the index of the cheapest (first cheapest if tied) 85 | l = costs.index(min(costs)) 86 | # Save the cost and the op in the matrices 87 | cost_matrix[i+1][j+1] = costs[l] 88 | if l == 0: 89 | op_matrix[i+1][j+1] = "T"+str(k+1) 90 | elif l == 1: 91 | op_matrix[i+1][j+1] = "S" 92 | elif l == 2: 93 | op_matrix[i+1][j+1] = "I" 94 | else: 95 | op_matrix[i+1][j+1] = "D" 96 | # Return the matrices 97 | return cost_matrix, op_matrix 98 | 99 | # Input 1: A spacy orig Token 100 | # Input 2: A spacy cor Token 101 | # Output: A linguistic cost between 0 < x < 2 102 | def get_sub_cost(self, o, c): 103 | # Short circuit if the only difference is case 104 | if o.lower == c.lower: 105 | return 0 106 | # Lemma cost 107 | if o.lemma == c.lemma: 108 | lemma_cost = 0 109 | else: 110 | lemma_cost = 0.499 111 | # POS cost 112 | if o.pos == c.pos: 113 | pos_cost = 0 114 | elif o.pos in self._open_pos and c.pos in self._open_pos: 115 | pos_cost = 0.25 116 | else: 117 | pos_cost = 0.5 118 | # Char cost 119 | char_cost = Indel.normalized_distance(o.text, c.text) 120 | # Combine the costs 121 | return lemma_cost + pos_cost + char_cost 122 | 123 | # Get the cheapest alignment sequence and indices from the op matrix 124 | def get_cheapest_align_seq(self): 125 | i = len(self.op_matrix)-1 126 | j = len(self.op_matrix[0])-1 127 | op_set = {'D': 0, 'I': 1, 'M': 2, 'S': 3} 128 | align_seq = [(i, j, op_set['M'])] 129 | # Work backwards from bottom right until we hit top left 130 | while i + j != 0: 131 | # Get the edit operation in the current cell 132 | op = self.op_matrix[i][j] 133 | # Matches and substitutions 134 | if op in {"M", "S"}: 135 | i -= 1 136 | j -= 1 137 | # Deletions 138 | elif op == "D": 139 | i -= 1 140 | # Insertions 141 | elif op == "I": 142 | j -= 1 143 | align_seq.append((i, j, op_set[op])) 144 | # Reverse the list to go from left to right and return 145 | align_seq.reverse() 146 | return align_seq 147 | 148 | # Alignment object string representation 149 | def __str__(self): 150 | orig = " ".join(["Orig:"]+[tok.text for tok in self.orig]) 151 | cor = " ".join(["Cor:"]+[tok.text for tok in self.cor]) 152 | cost_matrix = "\n".join(["Cost Matrix:"]+[str(row) for row in self.cost_matrix]) 153 | op_matrix = "\n".join(["Operation 
Matrix:"]+[str(row) for row in self.op_matrix]) 154 | seq = "Best alignment: "+str(self.align_seq) 155 | return "\n".join([orig, cor, cost_matrix, op_matrix, seq]) 156 | 157 | def parse(self, text): 158 | if isinstance(text, str): 159 | new_text = [] 160 | for tok in text.split(): # remove bpe delimeter 161 | new_text.append(tok if tok[-4:] != "" else tok[:-4]) 162 | text = Doc(self.nlp.vocab, new_text) 163 | else: 164 | new_text = [] 165 | for tok in text: 166 | new_text.append(tok if tok[-4:] != "" else tok[:-4]) 167 | text = Doc(self.nlp.vocab, new_text) 168 | self.nlp.tagger(text) 169 | self.nlp.parser(text) 170 | return text 171 | 172 | 173 | class Text(Transform): 174 | 175 | fields = ['SRC', 'TGT'] 176 | 177 | def __init__( 178 | self, 179 | SRC: Optional[Union[Field, Iterable[Field]]] = None, 180 | TGT: Optional[Union[Field, Iterable[Field]]] = None, 181 | ) -> Text: 182 | super().__init__() 183 | 184 | self.SRC = SRC 185 | self.TGT = TGT 186 | 187 | @property 188 | def src(self): 189 | return self.SRC, 190 | 191 | @property 192 | def tgt(self): 193 | return self.TGT, 194 | 195 | def load( 196 | self, 197 | data: Union[str, Iterable], 198 | lang: Optional[str] = None, 199 | **kwargs 200 | ) -> Iterable[TextSentence]: 201 | r""" 202 | Loads the data in Text-X format. 203 | Also supports for loading data from Text-U file with comments and non-integer IDs. 204 | 205 | Args: 206 | data (str or Iterable): 207 | A filename or a list of instances. 208 | lang (str): 209 | Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize. 210 | ``None`` if tokenization is not required. 211 | Default: ``None``. 212 | 213 | Returns: 214 | A list of :class:`TextSentence` instances. 215 | """ 216 | 217 | if lang is not None: 218 | tokenizer = Tokenizer(lang) 219 | if isinstance(data, str) and os.path.exists(data): 220 | f = open(data) 221 | if data.endswith('.txt'): 222 | lines = (i 223 | for s in f 224 | if len(s) > 1 225 | for i in StringIO((s.split() if lang is None else tokenizer(s)) + '\n')) 226 | else: 227 | lines = f 228 | else: 229 | if lang is not None: 230 | data = [tokenizer(s) for s in ([data] if isinstance(data, str) else data)] 231 | else: 232 | data = [data] if isinstance(data[0], str) else data 233 | lines = (i for s in data for i in StringIO(s + '\n')) 234 | 235 | index, sentence, nlp = 0, [], spacy.load("en", disable=["ner"]) 236 | for line in lines: 237 | line = line.strip() 238 | if len(line) == 0: 239 | yield TextSentence(self, sentence, index, nlp) 240 | index += 1 241 | sentence = [] 242 | else: 243 | sentence.append(line) 244 | 245 | 246 | class TextSentence(Sentence): 247 | 248 | def __init__(self, transform: Text, lines: List[str], index: Optional[int] = None, nlp=None) -> TextSentence: 249 | super().__init__(transform, index) 250 | 251 | self.cands = [(line+'\t').split('\t')[1] for line in lines[1:]] 252 | src, tgt = lines[0].split('\t')[1], self.cands[0] 253 | self.values = [src, tgt] 254 | 255 | def __repr__(self): 256 | self.cands = self.values[1] if isinstance(self.values[1], list) else [self.values[1]] 257 | lines = ['S\t' + self.values[0]] 258 | lines.extend(['T\t' + i for i in self.cands]) 259 | return '\n'.join(lines) + '\n' 260 | 261 | @classmethod 262 | def align(cls, src, tgt, nlp): 263 | return Alignment(src, tgt, nlp=nlp).align_seq 264 | -------------------------------------------------------------------------------- /data/clang8.toy: -------------------------------------------------------------------------------- 1 | S About 
winter 2 | T About winter 3 | 4 | S This is my second post . 5 | T This is my second post . 6 | 7 | S I will appreciate it if you correct my sentences . 8 | T I would appreciate it if you corrected my sentences . 9 | 10 | S It 's been getting colder these days here in Japan . 11 | T It 's been getting colder these days here in Japan . 12 | 13 | S The summer weather in Japan is not agreeable to me with its high humidity and temperature . 14 | T The summer weather in Japan is not agreeable to me with its high humidity and temperature . 15 | 16 | S So , as the winter is coming , I 'm getting to feel better . 17 | T So , as the winter is coming , I 'm getting to feel better . 18 | 19 | S Coldness is my energy . 20 | T Coldness is my energy . 21 | 22 | S And also , around the new year 's holidays , we will have a lot of enjoyable events 23 | T And also , around the new year 's holidays , we will have a lot of enjoyable events . 24 | 25 | S mostly with delicious foods , drinks , and good conversations . 26 | T Mostly with delicious food , drinks , and good conversation . 27 | 28 | S In addition , it is the time for skiing and snow boarding :) 29 | T In addition , it is the time for skiing and snowboarding :) 30 | 31 | S It is the very exciting season . 32 | T It is a very exciting season . 33 | 34 | S But , before enjoying those kind of happy time , I have to do a kind of boring , 35 | T But , before enjoying those kinds of happy times , I have to do some kind of boring , 36 | 37 | S customary practice . 38 | T customary practice . 39 | 40 | S Writing new year 's greeting cards is somehow a pain in the neck . 41 | T Writing new year 's greeting cards is somehow a pain in the neck . 42 | 43 | S Actually , I do n't have enough time to come up with an idea of the card 's design . 44 | T Actually , I did n't have enough time to come up with an idea for the card 's design . 45 | 46 | S I wish i could come across an good one in my mind . 47 | T I wish I could come across a good one in my mind . 48 | 49 | S Thank you for reading & thanks for your time . 50 | T Thank you for reading & thanks for your time . 51 | -------------------------------------------------------------------------------- /pred.sh: -------------------------------------------------------------------------------- 1 | args=$@ 2 | for arg in $args; do 3 | eval "$arg" 4 | done 5 | 6 | echo "config: ${config:=configs/roberta.yaml}" 7 | echo "path: ${path:=exp/ctc.roberta/model}" 8 | echo "data: ${data:=data/conll14.test}" 9 | echo "pred: ${pred:=$path.conll14.test.pred}" 10 | echo "input: ${input:=data/conll14.test.input}" 11 | echo "errant: ${errant:=data/conll14.test.errant.m2}" 12 | echo "devices: ${devices:=0}" 13 | echo "batch: ${batch:=10000}" 14 | echo "beam: ${beam:=12}" 15 | echo "iteration: ${iteration:=2}" 16 | 17 | (set -x; python -u run.py predict -d $devices -c $config -p $path --data $data --pred $pred --batch-size=$batch --beam-size=$beam --iteration $iteration 18 | CUDA_VISIBLE_DEVICES=$devices python recover.py --hyp $pred -o $pred.out -i $input -p $path -m 62) 19 | 20 | if ! 
conda env list | grep -q "^py27"; then 21 | echo "Creating the py27 environment..."; conda create -n py27 -y python=2.7 22 | fi 23 | 24 | source ~/anaconda3/etc/profile.d/conda.sh 25 | conda activate py27 26 | python tools/m2scorer/scripts/m2scorer.py -v $pred.out data/conll14.test.m2 > $pred.m2scorer.log 27 | tail -n 9 $pred.m2scorer.log 28 | conda deactivate -------------------------------------------------------------------------------- /recover.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import argparse 4 | from ctc.parser import CTCParser 5 | 6 | 7 | def convert(file, fout, fin, fpath, max_len=64): 8 | count, sentence = 0, [] 9 | tokenize_func = CTCParser.load(fpath).SRC.tokenize 10 | with open(file) as f, open(fout, 'w') as fout, open(fin) as fin: 11 | src_lines = [line.rstrip("\n") for line in fin] 12 | tgt_lines = [] 13 | for line in f: 14 | line = line.strip() 15 | if len(line) == 0: 16 | tgt_lines.append((sentence[1]+'\t').split('\t')[1]) 17 | sentence = [] 18 | else: 19 | sentence.append(line) 20 | count = 0 21 | for line in src_lines: 22 | if len(tokenize_func(line)) >= max_len: 23 | fout.write(line + "\n") 24 | else: 25 | fout.write(tgt_lines[count] + "\n") 26 | count += 1 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser(description='Output files in line with m2scorer eval format.') 31 | parser.add_argument('--path', '-p', help='path to the model file') 32 | parser.add_argument('--input', '-i', help='path to the input file') 33 | parser.add_argument('--hyp', help='path to the predicted file') 34 | parser.add_argument('--fout', '-o', help='path to output file') 35 | parser.add_argument('--max_len', '-m', help='max length') 36 | args = parser.parse_args() 37 | convert(args.hyp, args.fout, args.input, args.path, int(args.max_len)) 38 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import argparse 4 | 5 | from ctc import CTCParser 6 | from supar.cmds.run import init 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser(description='Create CTC Parser.') 11 | parser.set_defaults(Parser=CTCParser) 12 | parser.add_argument('--eval-tgt', action='store_true', help='whether to evaluate tgt') 13 | parser.add_argument('--lev', action='store_true', help='whether to evaluate P/R/F using levenshtein') 14 | parser.add_argument('--prefix', action='store_true', help='whether to perform prefix decoding') 15 | parser.add_argument('--glat', type=float, default=0, help='GLAT sampling ratio') 16 | parser.add_argument('--iteration', type=int, default=1, help='times of iterative decoding') 17 | subparsers = parser.add_subparsers(title='Commands', dest='mode') 18 | # train 19 | subparser = subparsers.add_parser('train', help='Train a parser.') 20 | subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first') 21 | subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training') 22 | subparser.add_argument('--encoder', choices=['lstm', 'transformer', 'bert'], default='bert', help='encoder to use') 23 | subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use') 24 | subparser.add_argument('--train', default='data/clang8.train', help='path to train file') 25 | subparser.add_argument('--dev', 
default='data/bea19.dev', help='path to dev file') 26 | subparser.add_argument('--test', default='data/conll14.test', help='path to test file') 27 | subparser.add_argument('--embed', default='glove-6b-100', help='file or embeddings available at `supar.utils.Embedding`') 28 | subparser.add_argument('--vocab', action='store_true', help='extend the vocab from new data') 29 | # evaluate 30 | subparser = subparsers.add_parser('evaluate', help='Evaluate the specified parser and dataset.') 31 | subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') 32 | subparser.add_argument('--data', default='data/conll14.test', help='path to dataset') 33 | # predict 34 | subparser = subparsers.add_parser('predict', help='Use a trained parser to make predictions.') 35 | subparser.add_argument('--buckets', default=8, type=int, help='max num of buckets to use') 36 | subparser.add_argument('--data', default='data/conll14.test', help='path to dataset') 37 | subparser.add_argument('--pred', default='pred.txt', help='path to predicted result') 38 | subparser.add_argument('--prob', action='store_true', help='whether to output probs') 39 | init(parser) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /supar: -------------------------------------------------------------------------------- 1 | 3rdparty/parser/supar -------------------------------------------------------------------------------- /tools/m2scorer/LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. 
And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 
97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. 
For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 
209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. 
SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | 294 | Copyright (C) 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 311 | 312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 | 315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details. 319 | 320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 | 325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 | 329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 
331 | 332 | , 1 April 1989 333 | Ty Coon, President of Vice 334 | 335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. If your program is a subroutine library, you may 337 | consider it more useful to permit linking proprietary applications with the 338 | library. If this is what you want to do, use the GNU Lesser General 339 | Public License instead of this License. 340 | -------------------------------------------------------------------------------- /tools/m2scorer/README: -------------------------------------------------------------------------------- 1 | Release 3.2 2 | Revision: 22 April 2014 3 | 4 | This README file describes the NUS MaxMatch (M^2) scorer. 5 | Copyright (C) 2013 Daniel Dahlmeier, Hwee Tou Ng, Christian Hadiwinoto, 6 | and Raymond Hendy Susanto 7 | 8 | This program is free software: you can redistribute it and/or modify 9 | it under the terms of the GNU General Public License as published by 10 | the Free Software Foundation, either version 3 of the License, or (at 11 | your option) any later version. 12 | 13 | This program is distributed in the hope that it will be useful, but 14 | WITHOUT ANY WARRANTY; without even the implied warranty of 15 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 16 | General Public License for more details. 17 | 18 | You should have received a copy of the GNU General Public License 19 | along with this program. If not, see . 20 | 21 | If you are using the NUS M^2 scorer in your work, please include a 22 | citation of the following paper: 23 | 24 | Daniel Dahlmeier and Hwee Tou Ng. 2012. Better Evaluation for 25 | Grammatical Error Correction. In Proceedings of the 2012 Conference of 26 | the North American Chapter of the Association for Computational 27 | Linguistics: Human Language Technologies (NAACL 2012). 28 | 29 | Any questions regarding the NUS M^2 scorer should be directed to 30 | Hwee Tou Ng (nght@comp.nus.edu.sg). 31 | 32 | 33 | Contents 34 | ======== 35 | 0. Quickstart 36 | 1. Pre-requisites 37 | 2. Using the scorer 38 | 2.1 System output format 39 | 2.2 Scorer's gold standard format 40 | 3. Converting the CoNLL-2014 data format 41 | 4. Revisions 42 | 4.1 Alternative edits 43 | 4.2 F-beta measure 44 | 4.3 Handling of insertion edits 45 | 4.4 Bug fix for scoring against multiple sets of gold edits, and 46 | dealing with sequences of insertion/deletion edits 47 | 48 | 49 | 0. Quickstart 50 | ============= 51 | ./m2scorer [-v] SYSTEM SOURCE_GOLD 52 | 53 | SYSTEM = the system output in sentence-per-line plain text. 54 | SOURCE_GOLD = the source sentences with gold edits. 55 | 56 | 57 | 1. Pre-requisites 58 | ================= 59 | The following dependencies have to be installed to use the M^2 scorer. 60 | 61 | + Python (>= 2.6.4, < 3.0, older versions might work but are not tested) 62 | + nltk (http://www.nltk.org, needed for sentence splitting) 63 | 64 | 65 | 2. Using the scorer 66 | =================== 67 | Usage: m2scorer [OPTIONS] SYSTEM SOURCE_GOLD 68 | where 69 | SYSTEM - system output, one sentence per line 70 | SOURCE_GOLD - source sentences with gold token edits 71 | 72 | OPTIONS 73 | -v --verbose - print verbose output 74 | --very_verbose - print lots of verbose output 75 | --max_unchanged_words N - Maximum unchanged words when extracting edits. Default = 2. 76 | --ignore_whitespace_casing - Ignore edits that only affect whitespace and casing. Default no. 77 | --beta - Set the ratio of recall importance against precision. Default = 0.5. 
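To make the --beta option concrete: the scorer reports the weighted F-measure
F_beta, where beta < 1 weights precision more heavily than recall. A minimal
Python sketch of the quantity being reported (the helper name f_beta is
illustrative only; the bundled scripts compute the same formula internally):

def f_beta(correct, proposed, gold, beta=0.5):
    p = correct / proposed if proposed else 1.0   # precision
    r = correct / gold if gold else 1.0           # recall
    try:
        return (1 + beta * beta) * p * r / (beta * beta * p + r)
    except ZeroDivisionError:
        return 0.0

# Worked example from Section 2.2 below: 4 correct edits, 5 proposed, 5 gold.
print(f_beta(4.0, 5.0, 5.0))   # 0.8

With beta = 0.5 (the default), precision is weighted twice as heavily as recall.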
78 | 79 | 80 | 2.1 System output format 81 | ======================== 82 | SYSTEM = File that contains the output of the error correction 83 | system. The sentences should be in tokenized plain text, sentence-per-line 84 | format. 85 | 86 | Format: 87 | 88 | 89 | ... 90 | 91 | Examples of tokenization: 92 | ------------------------- 93 | Original : He said, "We shouldn't go to the place. It'll kill one of us." 94 | Tokenized : He said , " We should n't go to the place . It 'll kill one of us . " 95 | 96 | Note: Tokenization in the CoNLL-2014 shared task uses NLTK word tokenizer. 97 | 98 | Sample output: 99 | -------------- 100 | ===> system <=== 101 | A cat sat on the mat . 102 | The Dog . 103 | 104 | 105 | 2.2 Scorer's gold standard format 106 | ================================= 107 | SOURCE_GOLD = source sentences (i.e. input to the error correction 108 | system) and the gold annotation in TOKEN offsets (starting from zero). 109 | 110 | Format: 111 | S 112 | A |||||||||||||| 113 | A |||||||||||||| 114 | 115 | S 116 | A |||||||||||||| 117 | 118 | 119 | Notes: 120 | ------ 121 | - Each source sentence should appear on a single line starting with "S " 122 | - Each source sentence is followed by zero or more annotations. 123 | - Each annotation is on a separate line starting with "A ". 124 | - Sentences are separated by one or more empty lines. 125 | - The source sentences need to be tokenized in the same way as the system output. 126 | - Start and end offset for annotations are in token offsets (starting from zero). 127 | - The gold edits can include one or more possible correction strings. Multiple corrections should be separate by '||'. 128 | - The error type, required field, and comment are not used for scoring at the moment. You can put dummy values there. 129 | - The annotator ID is used to identify a distinct annotation set by which system edits will be evaluated. 130 | - Each distinct annotation set, identified by an annotator ID, is an alternative 131 | - If one sentence has multiple annotator IDs, score will be computed for each annotator. 132 | - If one of the multiple annotation alternatives is no edit at all, an edit with type 'noop' or with offsets '-1 -1' must be specified. 133 | - The final score for the sentence will use the set of edits by an annotation set maximizing the score. 134 | 135 | 136 | Example: 137 | -------- 138 | ===> source_gold <=== 139 | S The cat sat at mat . 140 | A 3 4|||Prep|||on|||REQUIRED|||-NONE-|||0 141 | A 4 4|||ArtOrDet|||the||a|||REQUIRED|||-NONE-|||0 142 | 143 | S The dog . 144 | A 1 2|||NN|||dogs|||REQUIRED|||-NONE-|||0 145 | A -1 -1|||noop|||-NONE-|||-NONE-|||-NONE-|||1 146 | 147 | S Giant otters is an apex predator . 148 | A 2 3|||SVA|||are|||REQUIRED|||-NONE-|||0 149 | A 3 4|||ArtOrDet|||-NONE-|||REQUIRED|||-NONE-|||0 150 | A 5 6|||NN|||predators|||REQUIRED|||-NONE-|||0 151 | A 1 2|||NN|||otter|||REQUIRED|||-NONE-|||1 152 | 153 | 154 | 155 | ===> system <=== 156 | A cat sat on the mat . 157 | The dog . 158 | Giant otters are apex predator . 159 | 160 | ./m2scorer system source_gold 161 | Precision : 0.8 162 | Recall : 0.8 163 | F_0.5 : 0.8 164 | 165 | For sentence #1, the system makes two valid edits {(at-> on), 166 | (\epsilon -> the)} and one unnecessary edit (The -> A). 167 | 168 | For sentence #2, despite missing one gold edit (dog -> dogs) according 169 | to annotation set 0, the system misses nothing according to set 1. 
170 | 171 | For sentence #3, according to annotation set 0, the system makes two 172 | valid edits {(is -> are), (an -> \epsilon)} and misses one edit 173 | (predator -> predators); however according to set 1, the system makes 174 | two unnecessary edits {(is -> are), (an -> \epsilon)} and misses one 175 | edit (otters -> otter). 176 | 177 | By the case above, there are four valid edits, one unnecessary edit, 178 | and one missing edit. Therefore precision is 4/5 = 0.8. Similarly for 179 | recall. In the above example, the beta value for the F-measure is 0.5 180 | (the default value). 181 | 182 | 183 | 3. Converting the CoNLL-2014 data format 184 | ======================================== 185 | The data format used in the M^2 scorer differs from the format used in 186 | the CoNLL-2014 shared task (http://www.comp.nus.edu.sg/~nlp/conll14st.html) 187 | in two aspects: 188 | - sentence-level edits 189 | - token edit offsets 190 | 191 | To convert source files and gold edits from the CoNLL-2014 format into 192 | the M^2 format, run the preprocessing script bundled with the CoNLL-2014 193 | training data. 194 | 195 | 196 | 4. Revisions 197 | ============ 198 | 199 | 4.1 Alternative edits 200 | 201 | In this release, there is a major modification which enables scoring 202 | with multiple sets of gold edits. For every sentence, the system 203 | output will be scored against every available set of gold edits for 204 | the sentence, and the set of gold edits that maximizes the F score of 205 | the sentence is chosen. 206 | 207 | This modification was carried out by Christian Hadiwinoto, 2013. 208 | 209 | 210 | 4.2 F-beta measure 211 | 212 | While the previous release always uses the F1 measure, i.e. beta = 213 | 1.0, this release supports any value for beta. The default value for 214 | beta for this version is 0.5. 215 | 216 | This modification was carried out by Raymond Hendy Susanto, 2013. 217 | 218 | 219 | 4.3 Handling of insertion edits 220 | 221 | Multiple insertion edits (starting and ending at the same offset) that 222 | match a gold edit were counted repeatedly, leading to erroneous and 223 | inflated scores. A fix has been made to handle this. The order of 224 | insertion edits, which was not handled in the previous version, is now 225 | enforced. 226 | 227 | This modification was jointly carried out by Raymond Hendy Susanto and 228 | Christian Hadiwinoto, 2014. 229 | 230 | 231 | 4.4 Bug fix for scoring against multiple sets of gold edits, and 232 | dealing with sequences of insertion/deletion edits 233 | 234 | Fixed a bug in the M2 scorer arising from scoring against gold edits 235 | from multiple annotators. Specifically, the bug sometimes caused 236 | incorrect scores to be reported when scoring against the gold edits of 237 | subsequent annotators (other than the first annotator). 238 | 239 | Fixed a bug in the M2 scorer that caused erroneous scores to be 240 | reported when dealing with insertion edits followed by deletion edits 241 | (or vice versa). 242 | 243 | The above modifications were carried out by Christian Hadiwinoto, 244 | 2014. 
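As a recap of the gold-standard format described in Section 2.2, the sketch
below shows how a single "A " annotation line decomposes into its
"|||"-separated fields (the helper name parse_m2_annotation is illustrative
only; the bundled script's load_annotation function performs the same split):

def parse_m2_annotation(line):
    assert line.startswith("A ")
    fields = line[2:].split("|||")
    start, end = (int(i) for i in fields[0].split())   # token offsets, from zero
    error_type = fields[1]
    corrections = fields[2].split("||")                # alternative corrections
    annotator = int(fields[5])                         # id of the annotation set
    return start, end, error_type, corrections, annotator

# e.g. the second annotation of the first example sentence in Section 2.2:
print(parse_m2_annotation("A 4 4|||ArtOrDet|||the||a|||REQUIRED|||-NONE-|||0"))
# -> (4, 4, 'ArtOrDet', ['the', 'a'], 0)

Multiple corrections separated by "||" become a list of alternatives, and the
trailing integer selects which annotator's set the edit belongs to.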
245 | -------------------------------------------------------------------------------- /tools/m2scorer/example/README: -------------------------------------------------------------------------------- 1 | (execute these examples from the m2scorer top-level directory) 2 | 3 | 4 | ./m2scorer example/system_output.txt example/source_gold 5 | 6 | 7 | -------------------------------------------------------------------------------- /tools/m2scorer/example/source_gold: -------------------------------------------------------------------------------- 1 | S The cat sat at mat . 2 | A 3 4|||Prep|||on|||REQUIRED|||-NONE-|||0 3 | A 4 4|||ArtOrDet|||the||a|||REQUIRED|||-NONE-|||0 4 | 5 | S The dog . 6 | A 1 2|||NN|||dogs|||REQUIRED|||-NONE-|||0 7 | A -1 -1|||noop|||-NONE-|||-NONE-|||-NONE-|||1 8 | 9 | S Giant otters is an apex predator . 10 | A 2 3|||SVA|||are|||REQUIRED|||-NONE-|||0 11 | A 3 4|||ArtOrDet|||-NONE-|||REQUIRED|||-NONE-|||0 12 | A 5 6|||NN|||predators|||REQUIRED|||-NONE-|||0 13 | A 1 2|||NN|||otter|||REQUIRED|||-NONE-|||1 14 | 15 | -------------------------------------------------------------------------------- /tools/m2scorer/example/system: -------------------------------------------------------------------------------- 1 | A cat sat on the mat . 2 | The dog . 3 | Giant otters are apex predator . 4 | -------------------------------------------------------------------------------- /tools/m2scorer/example/system2: -------------------------------------------------------------------------------- 1 | A cat sat on mat . 2 | The dog . 3 | Giant otters are apex predator . 4 | -------------------------------------------------------------------------------- /tools/m2scorer/m2scorer: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # This file is part of the NUS M2 scorer. 4 | # The NUS M2 scorer is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | 9 | # The NUS M2 scorer is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | 17 | # file: m2scorer.py 18 | # 19 | # score a system's output against a gold reference 20 | # 21 | # Usage: m2scorer.py [OPTIONS] proposed_sentences source_gold 22 | # where 23 | # proposed_sentences - system output, sentence per line 24 | # source_gold - source sentences with gold token edits 25 | # OPTIONS 26 | # -v --verbose - print verbose output 27 | # --very_verbose - print lots of verbose output 28 | # --max_unchanged_words N - Maximum unchanged words when extracting edits. Default 2." 29 | # --beta B - Beta value for F-measure. Default 0.5." 30 | # --ignore_whitespace_casing - Ignore edits that only affect whitespace and caseing. Default no." 
31 | # 32 | 33 | import sys 34 | import levenshtein 35 | from getopt import getopt 36 | from util import paragraphs 37 | from util import smart_open 38 | 39 | 40 | 41 | def load_annotation(gold_file): 42 | source_sentences = [] 43 | gold_edits = [] 44 | fgold = smart_open(gold_file, 'r') 45 | puffer = fgold.read() 46 | fgold.close() 47 | puffer = puffer.decode('utf8') 48 | for item in paragraphs(puffer.splitlines(True)): 49 | item = item.splitlines(False) 50 | sentence = [line[2:].strip() for line in item if line.startswith('S ')] 51 | assert sentence != [] 52 | annotations = {} 53 | for line in item[1:]: 54 | if line.startswith('I ') or line.startswith('S '): 55 | continue 56 | assert line.startswith('A ') 57 | line = line[2:] 58 | fields = line.split('|||') 59 | start_offset = int(fields[0].split()[0]) 60 | end_offset = int(fields[0].split()[1]) 61 | etype = fields[1] 62 | if etype == 'noop': 63 | start_offset = -1 64 | end_offset = -1 65 | corrections = [c.strip() if c != '-NONE-' else '' for c in fields[2].split('||')] 66 | # NOTE: start and end are *token* offsets 67 | original = ' '.join(' '.join(sentence).split()[start_offset:end_offset]) 68 | annotator = int(fields[5]) 69 | if annotator not in annotations.keys(): 70 | annotations[annotator] = [] 71 | annotations[annotator].append((start_offset, end_offset, original, corrections)) 72 | tok_offset = 0 73 | for this_sentence in sentence: 74 | tok_offset += len(this_sentence.split()) 75 | source_sentences.append(this_sentence) 76 | this_edits = {} 77 | for annotator, annotation in annotations.iteritems(): 78 | this_edits[annotator] = [edit for edit in annotation if edit[0] <= tok_offset and edit[1] <= tok_offset and edit[0] >= 0 and edit[1] >= 0] 79 | if len(this_edits) == 0: 80 | this_edits[0] = [] 81 | gold_edits.append(this_edits) 82 | return (source_sentences, gold_edits) 83 | 84 | 85 | def print_usage(): 86 | print >> sys.stderr, "Usage: m2scorer.py [OPTIONS] proposed_sentences gold_source" 87 | print >> sys.stderr, "where" 88 | print >> sys.stderr, " proposed_sentences - system output, sentence per line" 89 | print >> sys.stderr, " source_gold - source sentences with gold token edits" 90 | print >> sys.stderr, "OPTIONS" 91 | print >> sys.stderr, " -v --verbose - print verbose output" 92 | print >> sys.stderr, " --very_verbose - print lots of verbose output" 93 | print >> sys.stderr, " --max_unchanged_words N - Maximum unchanged words when extraction edit. Default 2." 94 | print >> sys.stderr, " --beta B - Beta value for F-measure. Default 0.5." 95 | print >> sys.stderr, " --ignore_whitespace_casing - Ignore edits that only affect whitespace and caseing. Default no." 
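# ---------------------------------------------------------------------------
# Note (illustrative comment, not part of the original scorer): load_annotation
# above returns a pair (source_sentences, gold_edits), where source_sentences
# holds one source string per "S " line and gold_edits is a parallel list of
# dicts mapping annotator id -> list of (start_tok, end_tok, original,
# [corrections]) tuples. For the first sentence of the bundled example
# source_gold file this is roughly
#   {0: [(3, 4, u'at', [u'on']), (4, 4, u'', [u'the', u'a'])]}
# batch_multi_pre_rec_f1 below scores the system edits against each annotator's
# set and keeps the set that maximizes the F score, as described in the README.
# ---------------------------------------------------------------------------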
96 | 97 | 98 | 99 | max_unchanged_words=2 100 | beta = 0.5 101 | ignore_whitespace_casing= False 102 | verbose = False 103 | very_verbose = False 104 | opts, args = getopt(sys.argv[1:], "v", ["max_unchanged_words=", "beta=", "verbose", "ignore_whitespace_casing", "very_verbose"]) 105 | for o, v in opts: 106 | if o in ('-v', '--verbose'): 107 | verbose = True 108 | elif o == '--very_verbose': 109 | very_verbose = True 110 | elif o == '--max_unchanged_words': 111 | max_unchanged_words = int(v) 112 | elif o == '--beta': 113 | beta = float(v) 114 | elif o == '--ignore_whitespace_casing': 115 | ignore_whitespace_casing = True 116 | else: 117 | print >> sys.stderr, "Unknown option :", o 118 | print_usage() 119 | sys.exit(-1) 120 | 121 | # starting point 122 | if len(args) != 2: 123 | print_usage() 124 | sys.exit(-1) 125 | 126 | system_file = args[0] 127 | gold_file = args[1] 128 | 129 | # load source sentences and gold edits 130 | source_sentences, gold_edits = load_annotation(gold_file) 131 | 132 | # load system hypotheses 133 | fin = smart_open(system_file, 'r') 134 | system_sentences = [line.decode("utf8").strip() for line in fin.readlines()] 135 | fin.close() 136 | 137 | p, r, f1 = levenshtein.batch_multi_pre_rec_f1(system_sentences, source_sentences, gold_edits, max_unchanged_words, beta, ignore_whitespace_casing, verbose, very_verbose) 138 | 139 | print "Precision : %.4f" % p 140 | print "Recall : %.4f" % r 141 | print "F_%.1f : %.4f" % (beta, f1) 142 | 143 | -------------------------------------------------------------------------------- /tools/m2scorer/scripts/Tokenizer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: iso-8859-15 -*- 3 | 4 | # This file is part of the NUS M2 scorer. 5 | # The NUS M2 scorer is free software: you can redistribute it and/or modify 6 | # it under the terms of the GNU General Public License as published by 7 | # the Free Software Foundation, either version 3 of the License, or 8 | # (at your option) any later version. 9 | 10 | # The NUS M2 scorer is distributed in the hope that it will be useful, 11 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 12 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 | # GNU General Public License for more details. 14 | 15 | # You should have received a copy of the GNU General Public License 16 | # along with this program. If not, see . 17 | 18 | # file: Tokenizer.py 19 | # 20 | # A Penn Treebank tokenizer reimplemented based on the MOSES implementation. 
21 | # 22 | # usage : %prog < input > output 23 | 24 | 25 | import re 26 | import sys 27 | 28 | 29 | class DummyTokenizer(object): 30 | 31 | def tokenize(self, text): 32 | return text.split() 33 | 34 | 35 | 36 | class PTBTokenizer(object): 37 | 38 | def __init__(self, language="en"): 39 | self.language = language 40 | self.nonbreaking_prefixes = {} 41 | self.nonbreaking_prefixes_numeric = {} 42 | self.nonbreaking_prefixes["en"] = ''' A B C D E F G H I J K L M N O P Q R S T U V W X Y Z 43 | Adj Adm Adv Asst Bart Bldg Brig Bros Capt Cmdr Col Comdr Con Corp Cpl DR Dr Drs Ens 44 | Gen Gov Hon Hr Hosp Insp Lt MM MR MRS MS Maj Messrs Mlle Mme Mr Mrs Ms Msgr Op Ord 45 | Pfc Ph Prof Pvt Rep Reps Res Rev Rt Sen Sens Sfc Sgt Sr St Supt Surg 46 | v vs i.e rev e.g Nos Nr'''.split() 47 | self.nonbreaking_prefixes_numeric["en"] = '''No Art pp'''.split() 48 | self.special_chars = re.compile(r"([^\w\s\.\'\`\,\-\"\|\/])", flags=re.UNICODE) 49 | 50 | def tokenize(self, text, ptb=False): 51 | text = text.strip() 52 | text = " " + text + " " 53 | 54 | # Separate all "other" punctuation 55 | 56 | text = re.sub(self.special_chars, r' \1 ', text) 57 | text = re.sub(r";", r' ; ', text) 58 | text = re.sub(r":", r' : ', text) 59 | 60 | # replace the pipe character 61 | text = re.sub(r"\|", r' -PIPE- ', text) 62 | 63 | # split internal slash, keep others 64 | text = re.sub(r"(\S)/(\S)", r'\1 / \2', text) 65 | 66 | # PTB tokenization 67 | if ptb: 68 | text = re.sub(r"\(", r' -LRB- ', text) 69 | text = re.sub(r"\)", r' -RRB- ', text) 70 | text = re.sub(r"\[", r' -LSB- ', text) 71 | text = re.sub(r"\]", r' -RSB- ', text) 72 | text = re.sub(r"\{", r' -LCB- ', text) 73 | text = re.sub(r"\}", r' -RCB- ', text) 74 | 75 | text = re.sub(r"\"\s*$", r" '' ", text) 76 | text = re.sub(r"^\s*\"", r' `` ', text) 77 | text = re.sub(r"(\S)\"\s", r"\1 '' ", text) 78 | text = re.sub(r"\s\"(\S)", r" `` \1", text) 79 | text = re.sub(r"(\S)\"", r"\1 '' ", text) 80 | text = re.sub(r"\"(\S)", r" `` \1", text) 81 | text = re.sub(r"'\s*$", r" ' ", text) 82 | text = re.sub(r"^\s*'", r" ` ", text) 83 | text = re.sub(r"(\S)'\s", r"\1 ' ", text) 84 | text = re.sub(r"\s'(\S)", r" ` \1", text) 85 | 86 | text = re.sub(r"'ll", r" -CONTRACT-ll", text) 87 | text = re.sub(r"'re", r" -CONTRACT-re", text) 88 | text = re.sub(r"'ve", r" -CONTRACT-ve", text) 89 | text = re.sub(r"n't", r" n-CONTRACT-t", text) 90 | text = re.sub(r"'LL", r" -CONTRACT-LL", text) 91 | text = re.sub(r"'RE", r" -CONTRACT-RE", text) 92 | text = re.sub(r"'VE", r" -CONTRACT-VE", text) 93 | text = re.sub(r"N'T", r" N-CONTRACT-T", text) 94 | text = re.sub(r"cannot", r"can not", text) 95 | text = re.sub(r"Cannot", r"Can not", text) 96 | 97 | # multidots stay together 98 | text = re.sub(r"\.([\.]+)", r" DOTMULTI\1", text) 99 | while re.search("DOTMULTI\.", text): 100 | text = re.sub(r"DOTMULTI\.([^\.])", r"DOTDOTMULTI \1", text) 101 | text = re.sub(r"DOTMULTI\.", r"DOTDOTMULTI", text) 102 | 103 | # multidashes stay together 104 | text = re.sub(r"\-([\-]+)", r" DASHMULTI\1", text) 105 | while re.search("DASHMULTI\-", text): 106 | text = re.sub(r"DASHMULTI\-([^\-])", r"DASHDASHMULTI \1", text) 107 | text = re.sub(r"DASHMULTI\-", r"DASHDASHMULTI", text) 108 | 109 | # Separate ',' except if within number. 110 | text = re.sub(r"(\D),(\D)", r'\1 , \2', text) 111 | # Separate ',' pre and post number. 
112 | text = re.sub(r"(\d),(\D)", r'\1 , \2', text) 113 | text = re.sub(r"(\D),(\d)", r'\1 , \2', text) 114 | 115 | if self.language == "en": 116 | text = re.sub(r"([^a-zA-Z])'([^a-zA-Z])", r"\1 ' \2", text) 117 | text = re.sub(r"(\W)'([a-zA-Z])", r"\1 ' \2", text) 118 | text = re.sub(r"([a-zA-Z])'([^a-zA-Z])", r"\1 ' \2", text) 119 | text = re.sub(r"([a-zA-Z])'([a-zA-Z])", r"\1 '\2", text) 120 | text = re.sub(r"(\d)'(s)", r"\1 '\2", text) 121 | text = re.sub(r" '\s+s ", r" 's ", text) 122 | text = re.sub(r" '\s+s ", r" 's ", text) 123 | elif self.language == "fr": 124 | text = re.sub(r"([^a-zA-Z])'([^a-zA-Z])", r"\1 ' \2", text) 125 | text = re.sub(r"([^a-zA-Z])'([a-zA-Z])", r"\1 ' \2", text) 126 | text = re.sub(r"([a-zA-Z])'([^a-zA-Z])", r"\1 ' \2", text) 127 | text = re.sub(r"([a-zA-Z])'([a-zA-Z])", r"\1' \2", text) 128 | else: 129 | text = re.sub(r"'", r" ' ") 130 | 131 | # re-combine single quotes 132 | text = re.sub(r"' '", r"''", text) 133 | 134 | words = text.split() 135 | text = '' 136 | for i, word in enumerate(words): 137 | m = re.match("^(\S+)\.$", word) 138 | if m: 139 | pre = m.group(1) 140 | if ((re.search("\.", pre) and re.search("[a-zA-Z]", pre)) or \ 141 | (pre in self.nonbreaking_prefixes[self.language]) or \ 142 | ((i < len(words)-1) and re.match("^\d+", words[i+1]))): 143 | pass # do nothing 144 | elif ((pre in self.nonbreaking_prefixes_numeric[self.language] ) and \ 145 | (i < len(words)-1) and re.match("\d+", words[i+1])): 146 | pass # do nothing 147 | else: 148 | word = pre + " ." 149 | 150 | text += word + " " 151 | text = re.sub(r"'\s+'", r"''", text) 152 | 153 | # restore multidots 154 | while re.search("DOTDOTMULTI", text): 155 | text = re.sub(r"DOTDOTMULTI", r"DOTMULTI.", text) 156 | text = re.sub(r"DOTMULTI", r".", text) 157 | 158 | # restore multidashes 159 | while re.search("DASHDASHMULTI", text): 160 | text = re.sub(r"DASHDASHMULTI", r"DASHMULTI-", text) 161 | text = re.sub(r"DASHMULTI", r"-", text) 162 | text = re.sub(r"-CONTRACT-", r"'", text) 163 | 164 | return text.split() 165 | 166 | 167 | def tokenize_all(self,sentences, ptb=False): 168 | return [self.tokenize(t, ptb) for t in sentences] 169 | 170 | # starting point 171 | if __name__ == "__main__": 172 | tokenizer = PTBTokenizer() 173 | for line in sys.stdin: 174 | line = line.decode("utf8") 175 | tokens = tokenizer.tokenize(line.strip()) 176 | out = ' '.join(tokens) 177 | print out.encode("utf8") 178 | -------------------------------------------------------------------------------- /tools/m2scorer/scripts/combiner.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # This file is part of the NUS M2 scorer. 4 | # The NUS M2 scorer is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | 9 | # The NUS M2 scorer is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 
16 | 17 | # file: m2scorer.py 18 | # 19 | # score a system's output against a gold reference 20 | # 21 | # Usage: m2scorer.py [OPTIONS] proposed_sentences source_gold 22 | # where 23 | # proposed_sentences - system output, sentence per line 24 | # source_gold - source sentences with gold token edits 25 | # OPTIONS 26 | # -v --verbose - print verbose output 27 | # --very_verbose - print lots of verbose output 28 | # --max_unchanged_words N - Maximum unchanged words when extracting edits. Default 2." 29 | # --ignore_whitespace_casing - Ignore edits that only affect whitespace and caseing. Default no." 30 | # 31 | 32 | import sys 33 | import levenshtein 34 | from getopt import getopt 35 | from util import paragraphs 36 | from util import smart_open 37 | 38 | 39 | 40 | def load_annotation(gold_file): 41 | source_sentences = [] 42 | gold_edits = [] 43 | fgold = smart_open(gold_file, 'r') 44 | puffer = fgold.read() 45 | fgold.close() 46 | puffer = puffer.decode('utf8') 47 | for item in paragraphs(puffer.splitlines(True)): 48 | item = item.splitlines(False) 49 | sentence = [line[2:].strip() for line in item if line.startswith('S ')] 50 | assert sentence != [] 51 | annotations = {} 52 | for line in item[1:]: 53 | if line.startswith('I ') or line.startswith('S '): 54 | continue 55 | assert line.startswith('A ') 56 | line = line[2:] 57 | fields = line.split('|||') 58 | start_offset = int(fields[0].split()[0]) 59 | end_offset = int(fields[0].split()[1]) 60 | etype = fields[1] 61 | if etype == 'noop': 62 | start_offset = -1 63 | end_offset = -1 64 | corrections = [c.strip() if c != '-NONE-' else '' for c in fields[2].split('||')] 65 | # NOTE: start and end are *token* offsets 66 | original = ' '.join(' '.join(sentence).split()[start_offset:end_offset]) 67 | annotator = int(fields[5]) 68 | if annotator not in annotations.keys(): 69 | annotations[annotator] = [] 70 | annotations[annotator].append((start_offset, end_offset, original, corrections)) 71 | tok_offset = 0 72 | for this_sentence in sentence: 73 | tok_offset += len(this_sentence.split()) 74 | source_sentences.append(this_sentence) 75 | this_edits = {} 76 | for annotator, annotation in annotations.iteritems(): 77 | this_edits[annotator] = [edit for edit in annotation if edit[0] <= tok_offset and edit[1] <= tok_offset and edit[0] >= 0 and edit[1] >= 0] 78 | if len(this_edits) == 0: 79 | this_edits[0] = [] 80 | gold_edits.append(this_edits) 81 | return (source_sentences, gold_edits) 82 | 83 | 84 | def print_usage(): 85 | print >> sys.stderr, "Usage: m2scorer.py [OPTIONS] proposed_sentences gold_source" 86 | print >> sys.stderr, "where" 87 | print >> sys.stderr, " proposed_sentences - system output, sentence per line" 88 | print >> sys.stderr, " source_gold - source sentences with gold token edits" 89 | print >> sys.stderr, "OPTIONS" 90 | print >> sys.stderr, " -v --verbose - print verbose output" 91 | print >> sys.stderr, " --very_verbose - print lots of verbose output" 92 | print >> sys.stderr, " --max_unchanged_words N - Maximum unchanged words when extraction edit. Default 2." 93 | print >> sys.stderr, " --ignore_whitespace_casing - Ignore edits that only affect whitespace and caseing. Default no." 
94 | 95 | 96 | 97 | max_unchanged_words=2 98 | ignore_whitespace_casing= False 99 | verbose = False 100 | very_verbose = False 101 | opts, args = getopt(sys.argv[1:], "v", ["max_unchanged_words=", "verbose", "ignore_whitespace_casing", "very_verbose"]) 102 | for o, v in opts: 103 | if o in ('-v', '--verbose'): 104 | verbose = True 105 | elif o == '--very_verbose': 106 | very_verbose = True 107 | elif o == '--max_unchanged_words': 108 | max_unchanged_words = int(v) 109 | elif o == '--ignore_whitespace_casing': 110 | ignore_whitespace_casing = True 111 | else: 112 | print >> sys.stderr, "Unknown option :", o 113 | print_usage() 114 | sys.exit(-1) 115 | 116 | 117 | -------------------------------------------------------------------------------- /tools/m2scorer/scripts/levenshtein.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # This file is part of the NUS M2 scorer. 4 | # The NUS M2 scorer is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | 9 | # The NUS M2 scorer is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | 17 | # file: levenshtein.py 18 | 19 | from optparse import OptionParser 20 | from itertools import izip 21 | from util import uniq 22 | import re 23 | import sys 24 | from copy import deepcopy 25 | 26 | # batch evaluation of a list of sentences 27 | def batch_precision(candidates, sources, gold_edits, max_unchanged_words=2, beta=0.5, ignore_whitespace_casing=False, verbose=False): 28 | return batch_pre_rec_f1(candidates, sources, gold_edits, max_unchanged_words, beta, ignore_whitespace_casing, verbose)[0] 29 | 30 | def batch_recall(candidates, sources, gold_edits, max_unchanged_words=2, beta=0.5, ignore_whitespace_casing=False, verbose=False): 31 | return batch_pre_rec_f1(candidates, sources, gold_edits, max_unchanged_words, beta, ignore_whitespace_casing, verbose)[1] 32 | 33 | def batch_f1(candidates, sources, gold_edits, max_unchanged_words=2, beta=0.5, ignore_whitespace_casing=False, verbose=False): 34 | return batch_pre_rec_f1(candidates, sources, gold_edits, max_unchanged_words, beta, ignore_whitespace_casing, verbose)[2] 35 | 36 | def comp_p(a, b): 37 | try: 38 | p = a / b 39 | except ZeroDivisionError: 40 | p = 1.0 41 | return p 42 | 43 | def comp_r(c, g): 44 | try: 45 | r = c / g 46 | except ZeroDivisionError: 47 | r = 1.0 48 | return r 49 | 50 | def comp_f1(c, e, g, b): 51 | try: 52 | f = (1+b*b) * c / (b*b*g+e) 53 | #f = 2 * c / (g+e) 54 | except ZeroDivisionError: 55 | if c == 0.0: 56 | f = 1.0 57 | else: 58 | f = 0.0 59 | return f 60 | 61 | def f1_suffstats(candidate, source, gold_edits, max_unchanged_words=2, ignore_whitespace_casing= False, verbose=False, very_verbose=False): 62 | stat_correct = 0.0 63 | stat_proposed = 0.0 64 | stat_gold = 0.0 65 | 66 | candidate_tok = candidate.split() 67 | source_tok = source.split() 68 | lmatrix, backpointers = levenshtein_matrix(source_tok, candidate_tok) 69 | V, E, dist, edits = edit_graph(lmatrix, backpointers) 70 | if very_verbose: 71 | print "edit 
matrix:", lmatrix 72 | print "backpointers:", backpointers 73 | print "edits (w/o transitive arcs):", edits 74 | V, E, dist, edits = transitive_arcs(V, E, dist, edits, max_unchanged_words, very_verbose) 75 | dist = set_weights(E, dist, edits, gold_edits, very_verbose) 76 | editSeq = best_edit_seq_bf(V, E, dist, edits, very_verbose) 77 | if very_verbose: 78 | print "Graph(V,E) = " 79 | print "V =", V 80 | print "E =", E 81 | print "edits (with transitive arcs):", edits 82 | print "dist() =", dist 83 | print "viterbi path =", editSeq 84 | if ignore_whitespace_casing: 85 | editSeq = filter(lambda x : not equals_ignore_whitespace_casing(x[2], x[3]), editSeq) 86 | correct = matchSeq(editSeq, gold_edits, ignore_whitespace_casing) 87 | stat_correct = len(correct) 88 | stat_proposed = len(editSeq) 89 | stat_gold = len(gold_edits) 90 | if verbose: 91 | print "SOURCE :", source.encode("utf8") 92 | print "HYPOTHESIS :", candidate.encode("utf8") 93 | print "EDIT SEQ :", list(reversed(editSeq)) 94 | print "GOLD EDITS :", gold_edits 95 | print "CORRECT EDITS :", correct 96 | print "# correct :", int(stat_correct) 97 | print "# proposed :", int(stat_proposed) 98 | print "# gold :", int(stat_gold) 99 | print "-------------------------------------------" 100 | return (stat_correct, stat_proposed, stat_gold) 101 | 102 | def batch_multi_pre_rec_f1(candidates, sources, gold_edits, max_unchanged_words=2, beta=0.5, ignore_whitespace_casing= False, verbose=False, very_verbose=False): 103 | print len(candidates) 104 | print len(sources) 105 | print len(gold_edits) 106 | assert len(candidates) == len(sources) == len(gold_edits) 107 | stat_correct = 0.0 108 | stat_proposed = 0.0 109 | stat_gold = 0.0 110 | i = 0 111 | for candidate, source, golds_set in zip(candidates, sources, gold_edits): 112 | i = i + 1 113 | # Candidate system edit extraction 114 | candidate_tok = candidate.split() 115 | source_tok = source.split() 116 | #lmatrix, backpointers = levenshtein_matrix(source_tok, candidate_tok) 117 | lmatrix1, backpointers1 = levenshtein_matrix(source_tok, candidate_tok, 1, 1, 1) 118 | lmatrix2, backpointers2 = levenshtein_matrix(source_tok, candidate_tok, 1, 1, 2) 119 | 120 | #V, E, dist, edits = edit_graph(lmatrix, backpointers) 121 | V1, E1, dist1, edits1 = edit_graph(lmatrix1, backpointers1) 122 | V2, E2, dist2, edits2 = edit_graph(lmatrix2, backpointers2) 123 | 124 | V, E, dist, edits = merge_graph(V1, V2, E1, E2, dist1, dist2, edits1, edits2) 125 | if very_verbose: 126 | print "edit matrix 1:", lmatrix1 127 | print "edit matrix 2:", lmatrix2 128 | print "backpointers 1:", backpointers1 129 | print "backpointers 2:", backpointers2 130 | print "edits (w/o transitive arcs):", edits 131 | V, E, dist, edits = transitive_arcs(V, E, dist, edits, max_unchanged_words, very_verbose) 132 | 133 | # Find measures maximizing current cumulative F1; local: curent annotator only 134 | sqbeta = beta * beta 135 | chosen_ann = -1 136 | f1_max = -1.0 137 | 138 | argmax_correct = 0.0 139 | argmax_proposed = 0.0 140 | argmax_gold = 0.0 141 | max_stat_correct = -1.0 142 | min_stat_proposed = float("inf") 143 | min_stat_gold = float("inf") 144 | for annotator, gold in golds_set.iteritems(): 145 | localdist = set_weights(E, dist, edits, gold, verbose, very_verbose) 146 | editSeq = best_edit_seq_bf(V, E, localdist, edits, very_verbose) 147 | if verbose: 148 | print ">> Annotator:", annotator 149 | if very_verbose: 150 | print "Graph(V,E) = " 151 | print "V =", V 152 | print "E =", E 153 | print "edits (with transitive arcs):", edits 
154 | print "dist() =", localdist 155 | print "viterbi path =", editSeq 156 | if ignore_whitespace_casing: 157 | editSeq = filter(lambda x : not equals_ignore_whitespace_casing(x[2], x[3]), editSeq) 158 | correct = matchSeq(editSeq, gold, ignore_whitespace_casing, verbose) 159 | 160 | # local cumulative counts, P, R and F1 161 | stat_correct_local = stat_correct + len(correct) 162 | stat_proposed_local = stat_proposed + len(editSeq) 163 | stat_gold_local = stat_gold + len(gold) 164 | p_local = comp_p(stat_correct_local, stat_proposed_local) 165 | r_local = comp_r(stat_correct_local, stat_gold_local) 166 | f1_local = comp_f1(stat_correct_local, stat_proposed_local, stat_gold_local, beta) 167 | 168 | if f1_max < f1_local or \ 169 | (f1_max == f1_local and max_stat_correct < stat_correct_local) or \ 170 | (f1_max == f1_local and max_stat_correct == stat_correct_local and min_stat_proposed + sqbeta * min_stat_gold > stat_proposed_local + sqbeta * stat_gold_local): 171 | chosen_ann = annotator 172 | f1_max = f1_local 173 | max_stat_correct = stat_correct_local 174 | min_stat_proposed = stat_proposed_local 175 | min_stat_gold = stat_gold_local 176 | argmax_correct = len(correct) 177 | argmax_proposed = len(editSeq) 178 | argmax_gold = len(gold) 179 | 180 | if verbose: 181 | print "SOURCE :", source.encode("utf8") 182 | print "HYPOTHESIS :", candidate.encode("utf8") 183 | print "EDIT SEQ :", [shrinkEdit(ed) for ed in list(reversed(editSeq))] 184 | print "GOLD EDITS :", gold 185 | print "CORRECT EDITS :", correct 186 | print "# correct :", int(stat_correct_local) 187 | print "# proposed :", int(stat_proposed_local) 188 | print "# gold :", int(stat_gold_local) 189 | print "precision :", p_local 190 | print "recall :", r_local 191 | print "f_%.1f :" % beta, f1_local 192 | print "-------------------------------------------" 193 | if verbose: 194 | print ">> Chosen Annotator for line", i, ":", chosen_ann 195 | print "" 196 | stat_correct += argmax_correct 197 | stat_proposed += argmax_proposed 198 | stat_gold += argmax_gold 199 | 200 | try: 201 | p = stat_correct / stat_proposed 202 | except ZeroDivisionError: 203 | p = 1.0 204 | 205 | try: 206 | r = stat_correct / stat_gold 207 | except ZeroDivisionError: 208 | r = 1.0 209 | try: 210 | f1 = (1.0+beta*beta) * p * r / (beta*beta*p+r) 211 | except ZeroDivisionError: 212 | f1 = 0.0 213 | if verbose: 214 | print "CORRECT EDITS :", int(stat_correct) 215 | print "PROPOSED EDITS :", int(stat_proposed) 216 | print "GOLD EDITS :", int(stat_gold) 217 | print "P =", p 218 | print "R =", r 219 | print "F_%.1f =" % beta, f1 220 | return (p, r, f1) 221 | 222 | 223 | def batch_pre_rec_f1(candidates, sources, gold_edits, max_unchanged_words=2, beta=0.5, ignore_whitespace_casing= False, verbose=False, very_verbose=False): 224 | assert len(candidates) == len(sources) == len(gold_edits) 225 | stat_correct = 0.0 226 | stat_proposed = 0.0 227 | stat_gold = 0.0 228 | for candidate, source, gold in zip(candidates, sources, gold_edits): 229 | candidate_tok = candidate.split() 230 | source_tok = source.split() 231 | lmatrix, backpointers = levenshtein_matrix(source_tok, candidate_tok) 232 | V, E, dist, edits = edit_graph(lmatrix, backpointers) 233 | if very_verbose: 234 | print "edit matrix:", lmatrix 235 | print "backpointers:", backpointers 236 | print "edits (w/o transitive arcs):", edits 237 | V, E, dist, edits = transitive_arcs(V, E, dist, edits, max_unchanged_words, very_verbose) 238 | dist = set_weights(E, dist, edits, gold, verbose, very_verbose) 239 | editSeq = 
best_edit_seq_bf(V, E, dist, edits, very_verbose) 240 | if very_verbose: 241 | print "Graph(V,E) = " 242 | print "V =", V 243 | print "E =", E 244 | print "edits (with transitive arcs):", edits 245 | print "dist() =", dist 246 | print "viterbi path =", editSeq 247 | if ignore_whitespace_casing: 248 | editSeq = filter(lambda x : not equals_ignore_whitespace_casing(x[2], x[3]), editSeq) 249 | correct = matchSeq(editSeq, gold, ignore_whitespace_casing) 250 | stat_correct += len(correct) 251 | stat_proposed += len(editSeq) 252 | stat_gold += len(gold) 253 | if verbose: 254 | print "SOURCE :", source.encode("utf8") 255 | print "HYPOTHESIS :", candidate.encode("utf8") 256 | print "EDIT SEQ :", list(reversed(editSeq)) 257 | print "GOLD EDITS :", gold 258 | print "CORRECT EDITS :", correct 259 | print "# correct :", stat_correct 260 | print "# proposed :", stat_proposed 261 | print "# gold :", stat_gold 262 | print "precision :", comp_p(stat_correct, stat_proposed) 263 | print "recall :", comp_r(stat_correct, stat_gold) 264 | print "f_%.1f :" % beta, comp_f1(stat_correct, stat_proposed, stat_gold, beta) 265 | print "-------------------------------------------" 266 | 267 | try: 268 | p = stat_correct / stat_proposed 269 | except ZeroDivisionError: 270 | p = 1.0 271 | 272 | try: 273 | r = stat_correct / stat_gold 274 | except ZeroDivisionError: 275 | r = 1.0 276 | try: 277 | f1 = (1.0+beta*beta) * p * r / (beta*beta*p+r) 278 | #f1 = 2.0 * p * r / (p+r) 279 | except ZeroDivisionError: 280 | f1 = 0.0 281 | if verbose: 282 | print "CORRECT EDITS :", stat_correct 283 | print "PROPOSED EDITS :", stat_proposed 284 | print "GOLD EDITS :", stat_gold 285 | print "P =", p 286 | print "R =", r 287 | print "F_%.1f =" % beta, f1 288 | return (p, r, f1) 289 | 290 | # precision, recall, F1 291 | def precision(candidate, source, gold_edits, max_unchanged_words=2, beta=0.5, verbose=False): 292 | return pre_rec_f1(candidate, source, gold_edits, max_unchanged_words, beta, verbose)[0] 293 | 294 | def recall(candidate, source, gold_edits, max_unchanged_words=2, beta=0.5, verbose=False): 295 | return pre_rec_f1(candidate, source, gold_edits, max_unchanged_words, beta, verbose)[1] 296 | 297 | def f1(candidate, source, gold_edits, max_unchanged_words=2, beta=0.5, verbose=False): 298 | return pre_rec_f1(candidate, source, gold_edits, max_unchanged_words, beta, verbose)[2] 299 | 300 | def shrinkEdit(edit): 301 | shrunkEdit = deepcopy(edit) 302 | origtok = edit[2].split() 303 | corrtok = edit[3].split() 304 | i = 0 305 | cstart = 0 306 | cend = len(corrtok) 307 | found = False 308 | while i < min(len(origtok), len(corrtok)) and not found: 309 | if origtok[i] != corrtok[i]: 310 | found = True 311 | else: 312 | cstart += 1 313 | i += 1 314 | j = 1 315 | found = False 316 | while j <= min(len(origtok), len(corrtok)) - cstart and not found: 317 | if origtok[len(origtok) - j] != corrtok[len(corrtok) - j]: 318 | found = True 319 | else: 320 | cend -= 1 321 | j += 1 322 | shrunkEdit = (edit[0] + i, edit[1] - (j-1), ' '.join(origtok[i : len(origtok)-(j-1)]), ' '.join(corrtok[i : len(corrtok)-(j-1)])) 323 | return shrunkEdit 324 | 325 | def matchSeq(editSeq, gold_edits, ignore_whitespace_casing= False, verbose=False): 326 | m = [] 327 | goldSeq = deepcopy(gold_edits) 328 | last_index = 0 329 | CInsCDel = False 330 | CInsWDel = False 331 | CDelWIns = False 332 | for e in reversed(editSeq): 333 | for i in range(last_index, len(goldSeq)): 334 | g = goldSeq[i] 335 | if matchEdit(e,g, ignore_whitespace_casing): 336 | m.append(e) 337 | 
last_index = i+1 338 | if verbose: 339 | nextEditList = [shrinkEdit(edit) for edit in editSeq if e[1] == edit[0]] 340 | prevEditList = [shrinkEdit(edit) for edit in editSeq if e[0] == edit[1]] 341 | 342 | if e[0] != e[1]: 343 | nextEditList = [edit for edit in nextEditList if edit[0] == edit[1]] 344 | prevEditList = [edit for edit in prevEditList if edit[0] == edit[1]] 345 | else: 346 | nextEditList = [edit for edit in nextEditList if edit[0] < edit[1] and edit[3] == ''] 347 | prevEditList = [edit for edit in prevEditList if edit[0] < edit[1] and edit[3] == ''] 348 | 349 | matchAdj = any(any(matchEdit(edit, gold, ignore_whitespace_casing) for gold in goldSeq) for edit in nextEditList) or \ 350 | any(any(matchEdit(edit, gold, ignore_whitespace_casing) for gold in goldSeq) for edit in prevEditList) 351 | if e[0] < e[1] and len(e[3].strip()) == 0 and \ 352 | (len(nextEditList) > 0 or len(prevEditList) > 0): 353 | if matchAdj: 354 | print "!", e 355 | else: 356 | print "&", e 357 | elif e[0] == e[1] and \ 358 | (len(nextEditList) > 0 or len(prevEditList) > 0): 359 | if matchAdj: 360 | print "!", e 361 | else: 362 | print "*", e 363 | return m 364 | 365 | def matchEdit(e, g, ignore_whitespace_casing= False): 366 | # start offset 367 | if e[0] != g[0]: 368 | return False 369 | # end offset 370 | if e[1] != g[1]: 371 | return False 372 | # original string 373 | if e[2] != g[2]: 374 | return False 375 | # correction string 376 | if not e[3] in g[3]: 377 | return False 378 | # all matches 379 | return True 380 | 381 | def equals_ignore_whitespace_casing(a,b): 382 | return a.replace(" ", "").lower() == b.replace(" ", "").lower() 383 | 384 | 385 | def get_edits(candidate, source, gold_edits, max_unchanged_words=2, ignore_whitespace_casing= False, verbose=False, very_verbose=False): 386 | candidate_tok = candidate.split() 387 | source_tok = source.split() 388 | lmatrix, backpointers = levenshtein_matrix(source_tok, candidate_tok) 389 | V, E, dist, edits = edit_graph(lmatrix, backpointers) 390 | V, E, dist, edits = transitive_arcs(V, E, dist, edits, max_unchanged_words, very_verbose) 391 | dist = set_weights(E, dist, edits, gold_edits, verbose, very_verbose) 392 | editSeq = best_edit_seq_bf(V, E, dist, edits) 393 | if ignore_whitespace_casing: 394 | editSeq = filter(lambda x : not equals_ignore_whitespace_casing(x[2], x[3]), editSeq) 395 | correct = matchSeq(editSeq, gold_edits) 396 | return (correct, editSeq, gold_edits) 397 | 398 | def pre_rec_f1(candidate, source, gold_edits, max_unchanged_words=2, beta=0.5, ignore_whitespace_casing= False, verbose=False, very_verbose=False): 399 | candidate_tok = candidate.split() 400 | source_tok = source.split() 401 | lmatrix, backpointers = levenshtein_matrix(source_tok, candidate_tok) 402 | V, E, dist, edits = edit_graph(lmatrix, backpointers) 403 | V, E, dist, edits = transitive_arcs(V, E, dist, edits, max_unchanged_words, very_verbose) 404 | dist = set_weights(E, dist, edits, gold_edits, verbose, very_verbose) 405 | editSeq = best_edit_seq_bf(V, E, dist, edits) 406 | if ignore_whitespace_casing: 407 | editSeq = filter(lambda x : not equals_ignore_whitespace_casing(x[2], x[3]), editSeq) 408 | correct = matchSeq(editSeq, gold_edits) 409 | try: 410 | p = float(len(correct)) / len(editSeq) 411 | except ZeroDivisionError: 412 | p = 1.0 413 | try: 414 | r = float(len(correct)) / len(gold_edits) 415 | except ZeroDivisionError: 416 | r = 1.0 417 | try: 418 | f1 = (1.0+beta*beta) * p * r / (beta*beta*p+r) 419 | #f1 = 2.0 * p * r / (p+r) 420 | except ZeroDivisionError: 
421 | f1 = 0.0 422 | if verbose: 423 | print "Source:", source.encode("utf8") 424 | print "Hypothesis:", candidate.encode("utf8") 425 | print "edit seq", editSeq 426 | print "gold edits", gold_edits 427 | print "correct edits", correct 428 | print "p =", p 429 | print "r =", r 430 | print "f_%.1f =" % beta, f1 431 | return (p, r, f1) 432 | 433 | # distance function 434 | def get_distance(dist, v1, v2): 435 | try: 436 | return dist[(v1, v2)] 437 | except KeyError: 438 | return float('inf') 439 | 440 | 441 | # find maximally matching edit squence through the graph using bellman-ford 442 | def best_edit_seq_bf(V, E, dist, edits, verby_verbose=False): 443 | thisdist = {} 444 | path = {} 445 | for v in V: 446 | thisdist[v] = float('inf') 447 | thisdist[(0,0)] = 0 448 | for i in range(len(V)-1): 449 | for edge in E: 450 | v = edge[0] 451 | w = edge[1] 452 | if thisdist[v] + dist[edge] < thisdist[w]: 453 | thisdist[w] = thisdist[v] + dist[edge] 454 | path[w] = v 455 | # backtrack 456 | v = sorted(V)[-1] 457 | editSeq = [] 458 | while True: 459 | try: 460 | w = path[v] 461 | except KeyError: 462 | break 463 | edit = edits[(w,v)] 464 | if edit[0] != 'noop': 465 | editSeq.append((edit[1], edit[2], edit[3], edit[4])) 466 | v = w 467 | return editSeq 468 | 469 | 470 | # # find maximally matching edit squence through the graph 471 | # def best_edit_seq(V, E, dist, edits, verby_verbose=False): 472 | # thisdist = {} 473 | # path = {} 474 | # for v in V: 475 | # thisdist[v] = float('inf') 476 | # thisdist[(0,0)] = 0 477 | # queue = [(0,0)] 478 | # while len(queue) > 0: 479 | # v = queue[0] 480 | # queue = queue[1:] 481 | # for edge in E: 482 | # if edge[0] != v: 483 | # continue 484 | # w = edge[1] 485 | # if thisdist[v] + dist[edge] < thisdist[w]: 486 | # thisdist[w] = thisdist[v] + dist[edge] 487 | # path[w] = v 488 | # if not w in queue: 489 | # queue.append(w) 490 | # # backtrack 491 | # v = sorted(V)[-1] 492 | # editSeq = [] 493 | # while True: 494 | # try: 495 | # w = path[v] 496 | # except KeyError: 497 | # break 498 | # edit = edits[(w,v)] 499 | # if edit[0] != 'noop': 500 | # editSeq.append((edit[1], edit[2], edit[3], edit[4])) 501 | # v = w 502 | # return editSeq 503 | 504 | def prev_identical_edge(cur, E, edits): 505 | for e in E: 506 | if e[1] == cur[0] and edits[e] == edits[cur]: 507 | return e 508 | return None 509 | 510 | def next_identical_edge(cur, E, edits): 511 | for e in E: 512 | if e[0] == cur[1] and edits[e] == edits[cur]: 513 | return e 514 | return None 515 | 516 | def get_prev_edges(cur, E): 517 | prev = [] 518 | for e in E: 519 | if e[0] == cur[1]: 520 | prev.append(e) 521 | return prev 522 | 523 | def get_next_edges(cur, E): 524 | next = [] 525 | for e in E: 526 | if e[0] == cur[1]: 527 | next.append(e) 528 | return next 529 | 530 | 531 | # set weights on the graph, gold edits edges get negative weight 532 | # other edges get an epsilon weight added 533 | # gold_edits = (start, end, original, correction) 534 | def set_weights(E, dist, edits, gold_edits, verbose=False, very_verbose=False): 535 | EPSILON = 0.001 536 | if very_verbose: 537 | print "set weights of edges()", 538 | print "gold edits :", gold_edits 539 | 540 | gold_set = deepcopy(gold_edits) 541 | retdist = deepcopy(dist) 542 | 543 | M = {} 544 | G = {} 545 | for edge in E: 546 | tE = edits[edge] 547 | s, e = tE[1], tE[2] 548 | if (s, e) not in M: 549 | M[(s,e)] = [] 550 | M[(s,e)].append(edge) 551 | if (s, e) not in G: 552 | G[(s,e)] = [] 553 | 554 | for gold in gold_set: 555 | s, e = gold[0], gold[1] 556 | if (s, e) 
not in G: 557 | G[(s,e)] = [] 558 | G[(s,e)].append(gold) 559 | 560 | for k in sorted(M.keys()): 561 | M[k] = sorted(M[k]) 562 | 563 | if k[0] == k[1]: # insertion case 564 | lptr = 0 565 | rptr = len(M[k])-1 566 | cur = lptr 567 | 568 | g_lptr = 0 569 | g_rptr = len(G[k])-1 570 | 571 | while lptr <= rptr: 572 | hasGoldMatch = False 573 | edge = M[k][cur] 574 | thisEdit = edits[edge] 575 | # only check start offset, end offset, original string, corrections 576 | if very_verbose: 577 | print "set weights of edge", edge 578 | print "edit =", thisEdit 579 | 580 | cur_gold = [] 581 | if cur == lptr: 582 | cur_gold = range(g_lptr, g_rptr+1) 583 | else: 584 | cur_gold = reversed(range(g_lptr, g_rptr+1)) 585 | 586 | for i in cur_gold: 587 | gold = G[k][i] 588 | if thisEdit[1] == gold[0] and \ 589 | thisEdit[2] == gold[1] and \ 590 | thisEdit[3] == gold[2] and \ 591 | thisEdit[4] in gold[3]: 592 | hasGoldMatch = True 593 | retdist[edge] = - len(E) 594 | if very_verbose: 595 | print "matched gold edit :", gold 596 | print "set weight to :", retdist[edge] 597 | if cur == lptr: 598 | #g_lptr += 1 # why? 599 | g_lptr = i + 1 600 | else: 601 | #g_rptr -= 1 # why? 602 | g_rptr = i - 1 603 | break 604 | 605 | if not hasGoldMatch and thisEdit[0] != 'noop': 606 | retdist[edge] += EPSILON 607 | if hasGoldMatch: 608 | if cur == lptr: 609 | lptr += 1 610 | while lptr < len(M[k]) and M[k][lptr][0] != M[k][cur][1]: 611 | if edits[M[k][lptr]] != 'noop': 612 | retdist[M[k][lptr]] += EPSILON 613 | lptr += 1 614 | cur = lptr 615 | else: 616 | rptr -= 1 617 | while rptr >= 0 and M[k][rptr][1] != M[k][cur][0]: 618 | if edits[M[k][rptr]] != 'noop': 619 | retdist[M[k][rptr]] += EPSILON 620 | rptr -= 1 621 | cur = rptr 622 | else: 623 | if cur == lptr: 624 | lptr += 1 625 | cur = rptr 626 | else: 627 | rptr -= 1 628 | cur = lptr 629 | else: #deletion or substitution, don't care about order, no harm if setting parallel edges weight < 0 630 | for edge in M[k]: 631 | hasGoldMatch = False 632 | thisEdit = edits[edge] 633 | if very_verbose: 634 | print "set weights of edge", edge 635 | print "edit =", thisEdit 636 | for gold in G[k]: 637 | if thisEdit[1] == gold[0] and \ 638 | thisEdit[2] == gold[1] and \ 639 | thisEdit[3] == gold[2] and \ 640 | thisEdit[4] in gold[3]: 641 | hasGoldMatch = True 642 | retdist[edge] = - len(E) 643 | if very_verbose: 644 | print "matched gold edit :", gold 645 | print "set weight to :", retdist[edge] 646 | break 647 | if not hasGoldMatch and thisEdit[0] != 'noop': 648 | retdist[edge] += EPSILON 649 | return retdist 650 | 651 | # add transitive arcs 652 | def transitive_arcs(V, E, dist, edits, max_unchanged_words=2, very_verbose=False): 653 | if very_verbose: 654 | print "-- Add transitive arcs --" 655 | for k in range(len(V)): 656 | vk = V[k] 657 | if very_verbose: 658 | print "v _k :", vk 659 | 660 | for i in range(len(V)): 661 | vi = V[i] 662 | if very_verbose: 663 | print "v _i :", vi 664 | try: 665 | eik = edits[(vi, vk)] 666 | except KeyError: 667 | continue 668 | for j in range(len(V)): 669 | vj = V[j] 670 | if very_verbose: 671 | print "v _j :", vj 672 | try: 673 | ekj = edits[(vk, vj)] 674 | except KeyError: 675 | continue 676 | dik = get_distance(dist, vi, vk) 677 | dkj = get_distance(dist, vk, vj) 678 | if dik + dkj < get_distance(dist, vi, vj): 679 | eij = merge_edits(eik, ekj) 680 | if eij[-1] <= max_unchanged_words: 681 | if very_verbose: 682 | print " add new arcs v_i -> v_j:", eij 683 | E.append((vi, vj)) 684 | dist[(vi, vj)] = dik + dkj 685 | edits[(vi, vj)] = eij 686 | # remove 
noop transitive arcs 687 | if very_verbose: 688 | print "-- Remove transitive noop arcs --" 689 | for edge in E: 690 | e = edits[edge] 691 | if e[0] == 'noop' and dist[edge] > 1: 692 | if very_verbose: 693 | print " remove noop arc v_i -> vj:", edge 694 | E.remove(edge) 695 | dist[edge] = float('inf') 696 | del edits[edge] 697 | return(V, E, dist, edits) 698 | 699 | 700 | # combine two edits into one 701 | # edit = (type, start, end, orig, correction, #unchanged_words) 702 | def merge_edits(e1, e2, joiner = ' '): 703 | if e1[0] == 'ins': 704 | if e2[0] == 'ins': 705 | e = ('ins', e1[1], e2[2], '', e1[4] + joiner + e2[4], e1[5] + e2[5]) 706 | elif e2[0] == 'del': 707 | e = ('sub', e1[1], e2[2], e2[3], e1[4], e1[5] + e2[5]) 708 | elif e2[0] == 'sub': 709 | e = ('sub', e1[1], e2[2], e2[3], e1[4] + joiner + e2[4], e1[5] + e2[5]) 710 | elif e2[0] == 'noop': 711 | e = ('sub', e1[1], e2[2], e2[3], e1[4] + joiner + e2[4], e1[5] + e2[5]) 712 | elif e1[0] == 'del': 713 | if e2[0] == 'ins': 714 | e = ('sub', e1[1], e2[2], e1[3], e2[4], e1[5] + e2[5]) 715 | elif e2[0] == 'del': 716 | e = ('del', e1[1], e2[2], e1[3] + joiner + e2[3], '', e1[5] + e2[5]) 717 | elif e2[0] == 'sub': 718 | e = ('sub', e1[1], e2[2], e1[3] + joiner + e2[3], e2[4], e1[5] + e2[5]) 719 | elif e2[0] == 'noop': 720 | e = ('sub', e1[1], e2[2], e1[3] + joiner + e2[3], e2[4], e1[5] + e2[5]) 721 | elif e1[0] == 'sub': 722 | if e2[0] == 'ins': 723 | e = ('sub', e1[1], e2[2], e1[3], e1[4] + joiner + e2[4], e1[5] + e2[5]) 724 | elif e2[0] == 'del': 725 | e = ('sub', e1[1], e2[2], e1[3] + joiner + e2[3], e1[4], e1[5] + e2[5]) 726 | elif e2[0] == 'sub': 727 | e = ('sub', e1[1], e2[2], e1[3] + joiner + e2[3], e1[4] + joiner + e2[4], e1[5] + e2[5]) 728 | elif e2[0] == 'noop': 729 | e = ('sub', e1[1], e2[2], e1[3] + joiner + e2[3], e1[4] + joiner + e2[4], e1[5] + e2[5]) 730 | elif e1[0] == 'noop': 731 | if e2[0] == 'ins': 732 | e = ('sub', e1[1], e2[2], e1[3], e1[4] + joiner + e2[4], e1[5] + e2[5]) 733 | elif e2[0] == 'del': 734 | e = ('sub', e1[1], e2[2], e1[3] + joiner + e2[3], e1[4], e1[5] + e2[5]) 735 | elif e2[0] == 'sub': 736 | e = ('sub', e1[1], e2[2], e1[3] + joiner + e2[3], e1[4] + joiner + e2[4], e1[5] + e2[5]) 737 | elif e2[0] == 'noop': 738 | e = ('noop', e1[1], e2[2], e1[3] + joiner + e2[3], e1[4] + joiner + e2[4], e1[5] + e2[5]) 739 | else: 740 | assert False 741 | return e 742 | 743 | # build edit graph 744 | def edit_graph(levi_matrix, backpointers): 745 | V = [] 746 | E = [] 747 | dist = {} 748 | edits = {} 749 | # breath-first search through the matrix 750 | v_start = (len(levi_matrix)-1, len(levi_matrix[0])-1) 751 | queue = [v_start] 752 | while len(queue) > 0: 753 | v = queue[0] 754 | queue = queue[1:] 755 | if v in V: 756 | continue 757 | V.append(v) 758 | try: 759 | for vnext_edits in backpointers[v]: 760 | vnext = vnext_edits[0] 761 | edit_next = vnext_edits[1] 762 | E.append((vnext, v)) 763 | dist[(vnext, v)] = 1 764 | edits[(vnext, v)] = edit_next 765 | if not vnext in queue: 766 | queue.append(vnext) 767 | except KeyError: 768 | pass 769 | return (V, E, dist, edits) 770 | 771 | # merge two lattices, vertices, edges, and distance and edit table 772 | def merge_graph(V1, V2, E1, E2, dist1, dist2, edits1, edits2): 773 | # vertices 774 | V = deepcopy(V1) 775 | for v in V2: 776 | if v not in V: 777 | V.append(v) 778 | V = sorted(V) 779 | 780 | # edges 781 | E = E1 782 | for e in E2: 783 | if e not in V: 784 | E.append(e) 785 | E = sorted(E) 786 | 787 | # distances 788 | dist = deepcopy(dist1) 789 | for k in dist2.keys(): 
790 | if k not in dist.keys(): 791 | dist[k] = dist2[k] 792 | else: 793 | if dist[k] != dist2[k]: 794 | print >> sys.stderr, "WARNING: merge_graph: distance does not match!" 795 | dist[k] = min(dist[k], dist2[k]) 796 | 797 | # edit contents 798 | edits = deepcopy(edits1) 799 | for e in edits2.keys(): 800 | if e not in edits.keys(): 801 | edits[e] = edits2[e] 802 | else: 803 | if edits[e] != edits2[e]: 804 | print >> sys.stderr, "WARNING: merge_graph: edit does not match!" 805 | return (V, E, dist, edits) 806 | 807 | # convenience method for levenshtein distance 808 | def levenshtein_distance(first, second): 809 | lmatrix, backpointers = levenshtein_matrix(first, second) 810 | return lmatrix[-1][-1] 811 | 812 | 813 | # levenshtein matrix 814 | def levenshtein_matrix(first, second, cost_ins=1, cost_del=1, cost_sub=2): 815 | #if len(second) == 0 or len(second) == 0: 816 | # return len(first) + len(second) 817 | first_length = len(first) + 1 818 | second_length = len(second) + 1 819 | 820 | # init 821 | distance_matrix = [[None] * second_length for x in range(first_length)] 822 | backpointers = {} 823 | distance_matrix[0][0] = 0 824 | for i in range(1, first_length): 825 | distance_matrix[i][0] = i 826 | edit = ("del", i-1, i, first[i-1], '', 0) 827 | backpointers[(i, 0)] = [((i-1,0), edit)] 828 | for j in range(1, second_length): 829 | distance_matrix[0][j]=j 830 | edit = ("ins", j-1, j-1, '', second[j-1], 0) 831 | backpointers[(0, j)] = [((0,j-1), edit)] 832 | 833 | # fill the matrix 834 | for i in xrange(1, first_length): 835 | for j in range(1, second_length): 836 | deletion = distance_matrix[i-1][j] + cost_del 837 | insertion = distance_matrix[i][j-1] + cost_ins 838 | if first[i-1] == second[j-1]: 839 | substitution = distance_matrix[i-1][j-1] 840 | else: 841 | substitution = distance_matrix[i-1][j-1] + cost_sub 842 | if substitution == min(substitution, deletion, insertion): 843 | distance_matrix[i][j] = substitution 844 | if first[i-1] != second[j-1]: 845 | edit = ("sub", i-1, i, first[i-1], second[j-1], 0) 846 | else: 847 | edit = ("noop", i-1, i, first[i-1], second[j-1], 1) 848 | try: 849 | backpointers[(i, j)].append(((i-1,j-1), edit)) 850 | except KeyError: 851 | backpointers[(i, j)] = [((i-1,j-1), edit)] 852 | if deletion == min(substitution, deletion, insertion): 853 | distance_matrix[i][j] = deletion 854 | edit = ("del", i-1, i, first[i-1], '', 0) 855 | try: 856 | backpointers[(i, j)].append(((i-1,j), edit)) 857 | except KeyError: 858 | backpointers[(i, j)] = [((i-1,j), edit)] 859 | if insertion == min(substitution, deletion, insertion): 860 | distance_matrix[i][j] = insertion 861 | edit = ("ins", i, i, '', second[j-1], 0) 862 | try: 863 | backpointers[(i, j)].append(((i,j-1), edit)) 864 | except KeyError: 865 | backpointers[(i, j)] = [((i,j-1), edit)] 866 | return (distance_matrix, backpointers) 867 | 868 | -------------------------------------------------------------------------------- /tools/m2scorer/scripts/m2scorer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | 3 | # This file is part of the NUS M2 scorer. 4 | # The NUS M2 scorer is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 
8 | 9 | # The NUS M2 scorer is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | 17 | # file: m2scorer.py 18 | # 19 | # score a system's output against a gold reference 20 | # 21 | # Usage: m2scorer.py [OPTIONS] proposed_sentences source_gold 22 | # where 23 | # proposed_sentences - system output, sentence per line 24 | # source_gold - source sentences with gold token edits 25 | # OPTIONS 26 | # -v --verbose - print verbose output 27 | # --very_verbose - print lots of verbose output 28 | # --max_unchanged_words N - Maximum unchanged words when extracting edits. Default 2." 29 | # --beta B - Beta value for F-measure. Default 0.5." 30 | # --ignore_whitespace_casing - Ignore edits that only affect whitespace and caseing. Default no." 31 | # 32 | 33 | import sys 34 | import levenshtein 35 | from getopt import getopt 36 | from util import paragraphs 37 | from util import smart_open 38 | 39 | 40 | 41 | def load_annotation(gold_file): 42 | source_sentences = [] 43 | gold_edits = [] 44 | fgold = smart_open(gold_file, 'r') 45 | puffer = fgold.read() 46 | fgold.close() 47 | puffer = puffer.decode('utf8') 48 | for item in paragraphs(puffer.splitlines(True)): 49 | item = item.splitlines(False) 50 | sentence = [line[2:].strip() for line in item if line.startswith('S ')] 51 | assert sentence != [] 52 | annotations = {} 53 | for line in item[1:]: 54 | if line.startswith('I ') or line.startswith('S '): 55 | continue 56 | assert line.startswith('A ') 57 | line = line[2:] 58 | fields = line.split('|||') 59 | start_offset = int(fields[0].split()[0]) 60 | end_offset = int(fields[0].split()[1]) 61 | etype = fields[1] 62 | if etype == 'noop': 63 | start_offset = -1 64 | end_offset = -1 65 | corrections = [c.strip() if c != '-NONE-' else '' for c in fields[2].split('||')] 66 | # NOTE: start and end are *token* offsets 67 | original = ' '.join(' '.join(sentence).split()[start_offset:end_offset]) 68 | annotator = int(fields[5]) 69 | if annotator not in annotations.keys(): 70 | annotations[annotator] = [] 71 | annotations[annotator].append((start_offset, end_offset, original, corrections)) 72 | tok_offset = 0 73 | for this_sentence in sentence: 74 | tok_offset += len(this_sentence.split()) 75 | source_sentences.append(this_sentence) 76 | this_edits = {} 77 | for annotator, annotation in annotations.iteritems(): 78 | this_edits[annotator] = [edit for edit in annotation if edit[0] <= tok_offset and edit[1] <= tok_offset and edit[0] >= 0 and edit[1] >= 0] 79 | if len(this_edits) == 0: 80 | this_edits[0] = [] 81 | gold_edits.append(this_edits) 82 | return (source_sentences, gold_edits) 83 | 84 | 85 | def print_usage(): 86 | print >> sys.stderr, "Usage: m2scorer.py [OPTIONS] proposed_sentences gold_source" 87 | print >> sys.stderr, "where" 88 | print >> sys.stderr, " proposed_sentences - system output, sentence per line" 89 | print >> sys.stderr, " source_gold - source sentences with gold token edits" 90 | print >> sys.stderr, "OPTIONS" 91 | print >> sys.stderr, " -v --verbose - print verbose output" 92 | print >> sys.stderr, " --very_verbose - print lots of verbose output" 93 | print >> sys.stderr, " --max_unchanged_words N - Maximum unchanged words when extraction edit. Default 2." 
94 | print >> sys.stderr, " --beta B - Beta value for F-measure. Default 0.5." 95 | print >> sys.stderr, " --ignore_whitespace_casing - Ignore edits that only affect whitespace and caseing. Default no." 96 | 97 | 98 | 99 | max_unchanged_words=2 100 | beta = 0.5 101 | ignore_whitespace_casing= False 102 | verbose = False 103 | very_verbose = False 104 | opts, args = getopt(sys.argv[1:], "v", ["max_unchanged_words=", "beta=", "verbose", "ignore_whitespace_casing", "very_verbose"]) 105 | for o, v in opts: 106 | if o in ('-v', '--verbose'): 107 | verbose = True 108 | elif o == '--very_verbose': 109 | very_verbose = True 110 | elif o == '--max_unchanged_words': 111 | max_unchanged_words = int(v) 112 | elif o == '--beta': 113 | beta = float(v) 114 | elif o == '--ignore_whitespace_casing': 115 | ignore_whitespace_casing = True 116 | else: 117 | print >> sys.stderr, "Unknown option :", o 118 | print_usage() 119 | sys.exit(-1) 120 | 121 | # starting point 122 | if len(args) != 2: 123 | print_usage() 124 | sys.exit(-1) 125 | 126 | system_file = args[0] 127 | gold_file = args[1] 128 | 129 | # load source sentences and gold edits 130 | source_sentences, gold_edits = load_annotation(gold_file) 131 | 132 | # load system hypotheses 133 | fin = smart_open(system_file, 'r') 134 | system_sentences = [line.decode("utf8").strip() for line in fin.readlines()] 135 | fin.close() 136 | 137 | p, r, f1 = levenshtein.batch_multi_pre_rec_f1(system_sentences, source_sentences, gold_edits, max_unchanged_words, beta, ignore_whitespace_casing, verbose, very_verbose) 138 | 139 | print "Precision : %.4f" % p 140 | print "Recall : %.4f" % r 141 | print "F_%.1f : %.4f" % (beta, f1) 142 | 143 | -------------------------------------------------------------------------------- /tools/m2scorer/scripts/nuclesgmlparser.py: -------------------------------------------------------------------------------- 1 | # nuclesgmlparser.py 2 | # 3 | # Author: Yuanbin Wu 4 | # National University of Singapore (NUS) 5 | # Date: 12 Mar 2013 6 | # Version: 1.0 7 | # 8 | # Contact: wuyb@comp.nus.edu.sg 9 | # 10 | # This script is distributed to support the CoNLL-2013 Shared Task. 11 | # It is free for research and educational purposes. 
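# NOTE: the bundled m2scorer scripts are Python 2 code; this parser depends on sgmllib, and the sibling scripts use izip, dict.iteritems and print statements, none of which exist in Python 3.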
12 | 13 | from sgmllib import SGMLParser 14 | from nucle_doc import nucle_doc 15 | 16 | 17 | class nuclesgmlparser(SGMLParser): 18 | def __init__(self): 19 | SGMLParser.__init__(self) 20 | self.docs = [] 21 | 22 | def reset(self): 23 | self.docs = [] 24 | self.data = [] 25 | SGMLParser.reset(self) 26 | 27 | def unknow_starttag(self, tag, attrs): 28 | pass 29 | 30 | def unknow_endtag(self): 31 | pass 32 | 33 | def start_doc(self, attrs): 34 | self.docs.append(nucle_doc()) 35 | self.docs[-1].docattrs = attrs 36 | 37 | def end_doc(self): 38 | pass 39 | 40 | def start_matric(self, attrs): 41 | pass 42 | 43 | def end_matric(self): 44 | self.docs[-1].matric = ''.join(self.data) 45 | self.data = [] 46 | pass 47 | 48 | def start_email(self, attrs): 49 | pass 50 | 51 | def end_email(self): 52 | self.docs[-1].email = ''.join(self.data) 53 | self.data = [] 54 | pass 55 | 56 | def start_nationality(self, attrs): 57 | pass 58 | 59 | def end_nationality(self): 60 | self.docs[-1].nationality = ''.join(self.data) 61 | self.data = [] 62 | pass 63 | 64 | def start_first_language(self, attrs): 65 | pass 66 | 67 | def end_first_language(self): 68 | self.docs[-1].firstLanguage = ''.join(self.data) 69 | self.data = [] 70 | pass 71 | 72 | def start_school_language(self, attrs): 73 | pass 74 | 75 | def end_school_language(self): 76 | self.docs[-1].schoolLanguage = ''.join(self.data) 77 | self.data = [] 78 | pass 79 | 80 | def start_english_tests(self, attrs): 81 | pass 82 | 83 | def end_english_tests(self): 84 | self.docs[-1].englishTests = ''.join(self.data) 85 | self.data = [] 86 | pass 87 | 88 | 89 | def start_text(self, attrs): 90 | pass 91 | 92 | def end_text(self): 93 | pass 94 | 95 | def start_title(self, attrs): 96 | pass 97 | 98 | def end_title(self): 99 | self.docs[-1].paragraphs.append(''.join(self.data)) 100 | self.data = [] 101 | pass 102 | 103 | 104 | def start_p(self, attrs): 105 | pass 106 | 107 | def end_p(self): 108 | self.docs[-1].paragraphs.append(''.join(self.data)) 109 | self.data = [] 110 | pass 111 | 112 | 113 | def start_annotation(self, attrs): 114 | self.docs[-1].annotation.append(attrs) 115 | 116 | def end_annotation(self): 117 | pass 118 | 119 | def start_mistake(self, attrs): 120 | d = {} 121 | for t in attrs: 122 | d[t[0]] = int(t[1]) 123 | self.docs[-1].mistakes.append(d) 124 | pass 125 | 126 | def end_mistake(self): 127 | pass 128 | 129 | def start_type(self, attrs): 130 | pass 131 | 132 | def end_type(self): 133 | self.docs[-1].mistakes[-1]['type'] = ''.join(self.data) 134 | self.data = [] 135 | 136 | def start_correction(self, attrs): 137 | pass 138 | 139 | def end_correction(self): 140 | self.docs[-1].mistakes[-1]['correction'] = ''.join(self.data) 141 | self.data = [] 142 | 143 | def start_comment(self, attrs): 144 | pass 145 | 146 | def end_comment(self): 147 | self.docs[-1].mistakes[-1]['comment'] = ''.join( self.data) 148 | self.data = [] 149 | 150 | 151 | def handle_charref(self, ref): 152 | self.data.append('&' + ref) 153 | 154 | def handle_entityref(self, ref): 155 | self.data.append('&' + ref) 156 | 157 | def handle_data(self, text): 158 | if text.strip() == '': 159 | self.data.append('') 160 | return 161 | else: 162 | if text.startswith('\n'): 163 | text = text[1:] 164 | if text.endswith('\n'): 165 | text = text[:-1] 166 | self.data.append(text) 167 | 168 | 169 | -------------------------------------------------------------------------------- /tools/m2scorer/scripts/token_offsets.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/python 2 | 3 | # This file is part of the NUS M2 scorer. 4 | # The NUS M2 scorer is free software: you can redistribute it and/or modify 5 | # it under the terms of the GNU General Public License as published by 6 | # the Free Software Foundation, either version 3 of the License, or 7 | # (at your option) any later version. 8 | 9 | # The NUS M2 scorer is distributed in the hope that it will be useful, 10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 | # GNU General Public License for more details. 13 | 14 | # You should have received a copy of the GNU General Public License 15 | # along with this program. If not, see . 16 | 17 | # file: token_offsets.py 18 | # convert character to token offsets, tokenize sentence 19 | # 20 | # usage: %prog < input > output 21 | # 22 | 23 | 24 | import sys 25 | import re 26 | import os 27 | from util import * 28 | from Tokenizer import PTBTokenizer 29 | 30 | 31 | assert len(sys.argv) == 1 32 | 33 | 34 | # main 35 | # loop over sentences cum annotation 36 | tokenizer = PTBTokenizer() 37 | sentence = '' 38 | for line in sys.stdin: 39 | line = line.decode("utf8").strip() 40 | if line.startswith("S "): 41 | sentence = line[2:] 42 | sentence_tok = "S " + ' '.join(tokenizer.tokenize(sentence)) 43 | print sentence_tok.encode("utf8") 44 | elif line.startswith("A "): 45 | fields = line[2:].split('|||') 46 | start_end = fields[0] 47 | char_start, char_end = [int(a) for a in start_end.split()] 48 | # calculate token offsets 49 | prefix = sentence[:char_start] 50 | tok_start = len(tokenizer.tokenize(prefix)) 51 | postfix = sentence[:char_end] 52 | tok_end = len(tokenizer.tokenize(postfix)) 53 | start_end = str(tok_start) + " " + str(tok_end) 54 | fields[0] = start_end 55 | # tokenize corrections, remove trailing whitespace 56 | corrections = [(' '.join(tokenizer.tokenize(c))).strip() for c in fields[2].split('||')] 57 | fields[2] = '||'.join(corrections) 58 | annotation = "A " + '|||'.join(fields) 59 | print annotation.encode("utf8") 60 | else: 61 | print line.encode("utf8") 62 | 63 | -------------------------------------------------------------------------------- /tools/m2scorer/scripts/util.py: -------------------------------------------------------------------------------- 1 | # This file is part of the NUS M2 scorer. 2 | # The NUS M2 scorer is free software: you can redistribute it and/or modify 3 | # it under the terms of the GNU General Public License as published by 4 | # the Free Software Foundation, either version 3 of the License, or 5 | # (at your option) any later version. 6 | 7 | # The NUS M2 scorer is distributed in the hope that it will be useful, 8 | # but WITHOUT ANY WARRANTY; without even the implied warranty of 9 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 10 | # GNU General Public License for more details. 11 | 12 | # You should have received a copy of the GNU General Public License 13 | # along with this program. If not, see . 14 | 15 | # file: util.py 16 | # 17 | 18 | import operator 19 | import random 20 | import math 21 | import re 22 | 23 | def smart_open(fname, mode = 'r'): 24 | if fname.endswith('.gz'): 25 | import gzip 26 | # Using max compression (9) by default seems to be slow. 27 | # Let's try using the fastest. 
28 | return gzip.open(fname, mode, 1) 29 | else: 30 | return open(fname, mode) 31 | 32 | 33 | def randint(b, a=0): 34 | return random.randint(a,b) 35 | 36 | def uniq(seq, idfun=None): 37 | # order preserving 38 | if idfun is None: 39 | def idfun(x): return x 40 | seen = {} 41 | result = [] 42 | for item in seq: 43 | marker = idfun(item) 44 | # in old Python versions: 45 | # if seen.has_key(marker) 46 | # but in new ones: 47 | if marker in seen: continue 48 | seen[marker] = 1 49 | result.append(item) 50 | return result 51 | 52 | 53 | def sort_dict(myDict, byValue=False, reverse=False): 54 | if byValue: 55 | items = myDict.items() 56 | items.sort(key = operator.itemgetter(1), reverse=reverse) 57 | else: 58 | items = sorted(myDict.items()) 59 | return items 60 | 61 | def max_dict(myDict, byValue=False): 62 | if byValue: 63 | skey=lambda x:x[1] 64 | else: 65 | skey=lambda x:x[0] 66 | return max(myDict.items(), key=skey) 67 | 68 | 69 | def min_dict(myDict, byValue=False): 70 | if byValue: 71 | skey=lambda x:x[1] 72 | else: 73 | skey=lambda x:x[0] 74 | return min(myDict.items(), key=skey) 75 | 76 | def paragraphs(lines, is_separator=lambda x : x == '\n', joiner=''.join): 77 | paragraph = [] 78 | for line in lines: 79 | if is_separator(line): 80 | if paragraph: 81 | yield joiner(paragraph) 82 | paragraph = [] 83 | else: 84 | paragraph.append(line) 85 | if paragraph: 86 | yield joiner(paragraph) 87 | 88 | 89 | def isASCII(word): 90 | try: 91 | word = word.decode("ascii") 92 | return True 93 | except UnicodeEncodeError : 94 | return False 95 | except UnicodeDecodeError: 96 | return False 97 | 98 | 99 | def intersect(x, y): 100 | return [z for z in x if z in y] 101 | 102 | 103 | 104 | # Mapping Windows CP1252 Gremlins to Unicode 105 | # from http://effbot.org/zone/unicode-gremlins.htm 106 | cp1252 = { 107 | # from http://www.microsoft.com/typography/unicode/1252.htm 108 | u"\x80": u"\u20AC", # EURO SIGN 109 | u"\x82": u"\u201A", # SINGLE LOW-9 QUOTATION MARK 110 | u"\x83": u"\u0192", # LATIN SMALL LETTER F WITH HOOK 111 | u"\x84": u"\u201E", # DOUBLE LOW-9 QUOTATION MARK 112 | u"\x85": u"\u2026", # HORIZONTAL ELLIPSIS 113 | u"\x86": u"\u2020", # DAGGER 114 | u"\x87": u"\u2021", # DOUBLE DAGGER 115 | u"\x88": u"\u02C6", # MODIFIER LETTER CIRCUMFLEX ACCENT 116 | u"\x89": u"\u2030", # PER MILLE SIGN 117 | u"\x8A": u"\u0160", # LATIN CAPITAL LETTER S WITH CARON 118 | u"\x8B": u"\u2039", # SINGLE LEFT-POINTING ANGLE QUOTATION MARK 119 | u"\x8C": u"\u0152", # LATIN CAPITAL LIGATURE OE 120 | u"\x8E": u"\u017D", # LATIN CAPITAL LETTER Z WITH CARON 121 | u"\x91": u"\u2018", # LEFT SINGLE QUOTATION MARK 122 | u"\x92": u"\u2019", # RIGHT SINGLE QUOTATION MARK 123 | u"\x93": u"\u201C", # LEFT DOUBLE QUOTATION MARK 124 | u"\x94": u"\u201D", # RIGHT DOUBLE QUOTATION MARK 125 | u"\x95": u"\u2022", # BULLET 126 | u"\x96": u"\u2013", # EN DASH 127 | u"\x97": u"\u2014", # EM DASH 128 | u"\x98": u"\u02DC", # SMALL TILDE 129 | u"\x99": u"\u2122", # TRADE MARK SIGN 130 | u"\x9A": u"\u0161", # LATIN SMALL LETTER S WITH CARON 131 | u"\x9B": u"\u203A", # SINGLE RIGHT-POINTING ANGLE QUOTATION MARK 132 | u"\x9C": u"\u0153", # LATIN SMALL LIGATURE OE 133 | u"\x9E": u"\u017E", # LATIN SMALL LETTER Z WITH CARON 134 | u"\x9F": u"\u0178", # LATIN CAPITAL LETTER Y WITH DIAERESIS 135 | } 136 | 137 | def fix_cp1252codes(text): 138 | # map cp1252 gremlins to real unicode characters 139 | if re.search(u"[\x80-\x9f]", text): 140 | def fixup(m): 141 | s = m.group(0) 142 | return cp1252.get(s, s) 143 | if isinstance(text, type("")): 144 
| # make sure we have a unicode string 145 | text = unicode(text, "iso-8859-1") 146 | text = re.sub(u"[\x80-\x9f]", fixup, text) 147 | return text 148 | 149 | def clean_utf8(text): 150 | return filter(lambda x : x > '\x1f' and x < '\x7f', text) 151 | 152 | def pairs(iterable, overlapping=False): 153 | iterator = iterable.__iter__() 154 | token = iterator.next() 155 | i = 0 156 | for lookahead in iterator: 157 | if overlapping or i % 2 == 0: 158 | yield (token, lookahead) 159 | token = lookahead 160 | i += 1 161 | if i % 2 == 0: 162 | yield (token, None) 163 | 164 | def frange(start, end=None, inc=None): 165 | "A range function, that does accept float increments..." 166 | 167 | if end == None: 168 | end = start + 0.0 169 | start = 0.0 170 | 171 | if inc == None: 172 | inc = 1.0 173 | 174 | L = [] 175 | while 1: 176 | next = start + len(L) * inc 177 | if inc > 0 and next >= end: 178 | break 179 | elif inc < 0 and next <= end: 180 | break 181 | L.append(next) 182 | 183 | return L 184 | 185 | def softmax(values): 186 | a = max(values) 187 | Z = 0.0 188 | for v in values: 189 | Z += math.exp(v - a) 190 | sm = [math.exp(v-a) / Z for v in values] 191 | return sm 192 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | # nohup bash train.sh > log 2>&1 & 2 | args=$@ 3 | for arg in $args; do 4 | eval "$arg" 5 | done 6 | 7 | echo "seed: ${seed:=1}" 8 | echo "bert: ${bert:=roberta-large}" 9 | echo "lr1: ${lr1:=5e-5}" 10 | echo "lr2: ${lr2:=5e-6}" 11 | echo "lr3: ${lr3:=1e-6}" 12 | echo "rate1: ${rate1:=10}" 13 | echo "rate2: ${rate2:=10}" 14 | echo "rate3: ${rate3:=10}" 15 | echo "upsampling: ${upsampling:=4}" 16 | echo "batch: ${batch:=100000}" 17 | echo "epochs1: ${epochs1:=64}" 18 | echo "epochs2: ${epochs2:=64}" 19 | echo "epochs3: ${epochs3:=64}" 20 | echo "warmup1: ${warmup1:=1000}" 21 | echo "warmup2: ${warmup2:=0}" 22 | echo "warmup3: ${warmup3:=0}" 23 | echo "glat: ${glat:=1}" 24 | echo "update: ${update:=5}" 25 | echo "devices: ${devices:=0,1,2,3,4,5,6,7}" 26 | echo "config: ${config:=configs/roberta.yaml}" 27 | echo "path: ${path:=exp/ctc.roberta}" 28 | 29 | code=$path.code 30 | mkdir -p $path.code 31 | cp run.py $code/ 32 | cp -r ctc $code/ 33 | cp -r 3rdparty $code/ 34 | printf "Current commits:\n$(git log -1 --oneline)\n3rd parties:\n" 35 | cd 3rdparty/parser/ && printf "parser\n$(git log -1 --oneline)\n" && cd ../.. 
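# Three-stage pipeline: stage 1 trains on data/clang8.train; stages 2 and 3 start from the
# previous stage's checkpoint (cp $prev $current) and continue training on data/error_coded.train
# and data/wi_locness.train, respectively; each stage is evaluated with pred.sh.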
36 | 37 | for stage in 1 2 3; do 38 | mkdir -p $path/stage$stage 39 | var="lr$stage"; lr=${!var} 40 | var="rate$stage"; rate=${!var} 41 | var="warmup$stage"; warmup=${!var} 42 | var="epochs$stage"; epochs=${!var} 43 | current="$path/stage$stage/model.lr$lr.rate$rate.upsampling$upsampling.batch$batch.epochs$epochs.warmup$warmup.glat$glat.seed$seed" 44 | 45 | if [ $stage -eq 1 ]; then 46 | train=data/clang8.train 47 | (set -x 48 | python -u run.py train -b -s $seed -d $devices -c $config -p $current --lr=$lr --lr-rate=$rate --upsampling=$upsampling --batch-size=$batch --epochs=$epochs --warmup-steps=$warmup --glat=$glat --update-steps=$update --encoder=bert --bert=$bert --train $train --eval-tgt --cache --amp 49 | ) 50 | else 51 | if [ $stage -eq 2 ]; then 52 | train=data/error_coded.train 53 | else 54 | train=data/wi_locness.train 55 | fi 56 | (set -x 57 | cp $prev $current 58 | python -u run.py train -s $seed -d $devices -c $config -p $current --lr=$lr --lr-rate=$rate --upsampling=$upsampling --batch-size=$batch --epochs=$epochs --warmup-steps=$warmup --glat=$glat --update-steps=$update --encoder=bert --bert=$bert --train $train --eval-tgt --cache --amp 59 | ) 60 | fi 61 | bash pred.sh path=$current 62 | prev=$current 63 | done 64 | --------------------------------------------------------------------------------
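For reference, the vendored M2 scorer listed above can also be invoked by hand on system output. A minimal sketch, assuming a Python 2 interpreter is available; the prediction and gold file paths below are placeholders:

```sh
# Both paths are placeholders: substitute your own prediction file
# (one sentence per line) and a gold file with token-level edits in M2 format.
python2 tools/m2scorer/scripts/m2scorer.py \
    --beta 0.5 --max_unchanged_words 2 \
    pred.txt gold.m2
# Prints Precision, Recall and F_0.5 as reported by the scorer.
```

The `--beta 0.5` and `--max_unchanged_words 2` values shown are the scorer's defaults.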