├── data └── output_vocabulary │ ├── non_padded_namespaces.txt │ └── d_tags.txt ├── requirements.txt ├── .gitignore ├── README.md ├── utils ├── prepare_clc_fce_data.py ├── helpers.py └── preprocess_data.py ├── predict.py ├── gector ├── datareader.py ├── seq2labels_model.py ├── bert_token_embedder.py ├── gec_model.py ├── wordpiece_indexer.py └── trainer.py ├── LICENSE └── train.py /data/output_vocabulary/non_padded_namespaces.txt: -------------------------------------------------------------------------------- 1 | *tags 2 | *labels 3 | -------------------------------------------------------------------------------- /data/output_vocabulary/d_tags.txt: -------------------------------------------------------------------------------- 1 | CORRECT 2 | INCORRECT 3 | @@UNKNOWN@@ 4 | @@PADDING@@ 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.3.0 2 | allennlp==0.8.4 3 | python-Levenshtein==0.12.0 4 | transformers==2.2.2 5 | scikit-learn==0.20.0 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | .DS_Store 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # PyCharm 133 | .idea 134 | 135 | *.sh -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GECToR – Grammatical Error Correction: Tag, Not Rewrite 2 | 3 | This repository provides code for training and testing state-of-the-art models for grammatical error correction with the official PyTorch implementation of the following paper: 4 | > [GECToR – Grammatical Error Correction: Tag, Not Rewrite](https://arxiv.org/abs/2005.12592)
5 | > [Kostiantyn Omelianchuk](https://github.com/komelianchuk), [Vitaliy Atrasevych](https://github.com/atrasevych), [Artem Chernodub](https://github.com/achernodub), [Oleksandr Skurzhanskyi](https://github.com/skurzhanskyi)
6 | > Grammarly
7 | > [15th Workshop on Innovative Use of NLP for Building Educational Applications (co-located with ACL 2020)](https://sig-edu.org/bea/current)
8 | 9 | It is mainly based on `AllenNLP` and `transformers`. 10 | ## Installation 11 | The following command installs all necessary packages: 12 | ```.bash 13 | pip install -r requirements.txt 14 | ``` 15 | The project was tested using Python 3.7. 16 | 17 | ## Datasets 18 | All the public GEC datasets used in the paper can be downloaded from [here](https://www.cl.cam.ac.uk/research/nl/bea2019st/#data).
19 | Synthetically created datasets can be generated/downloaded [here](https://github.com/awasthiabhijeet/PIE/tree/master/errorify).
20 | To train the model, the data has to be preprocessed and converted to a special format with the command:
21 | ```.bash
22 | python utils/preprocess_data.py -s SOURCE -t TARGET -o OUTPUT_FILE
23 | ```
24 | ## Pretrained models
25 |
| Pretrained encoder | Confidence bias | Min error prob | CoNLL-2014 (test) | BEA-2019 (test) |
|---|---|---|---|---|
| BERT [link] | 0.10 | 0.41 | 63.0 | 67.6 |
| RoBERTa [link] | 0.20 | 0.50 | 64.0 | 71.5 |
| XLNet [link] | 0.35 | 0.66 | 65.3 | 72.4 |
| RoBERTa + XLNet | 0.24 | 0.45 | 66.0 | 73.7 |
| BERT + RoBERTa + XLNet | 0.16 | 0.40 | 66.5 | 73.6 |
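With one of the checkpoints above downloaded, a single-model result can be reproduced with a command along the following lines (the model path is a placeholder; all flags are described in the Model inference section below). Here the confidence bias and minimum error probability are taken from the XLNet row of the table, and `--special_tokens_fix 0` follows the note in `predict.py` for BERT/XLNet models:
```.bash
python predict.py --model_path MODEL_PATH \
    --vocab_path data/output_vocabulary \
    --input_file INPUT_FILE --output_file OUTPUT_FILE \
    --transformer_model xlnet --special_tokens_fix 0 \
    --additional_confidence 0.35 --min_error_probability 0.66
```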
69 |
70 | ## Train model
71 | To train the model, simply run:
72 | ```.bash
73 | python train.py --train_set TRAIN_SET --dev_set DEV_SET \
74 |     --model_dir MODEL_DIR
75 | ```
76 | There are many parameters you can specify; among them:
77 | - `cold_steps_count` - the number of epochs during which only the last linear layer is trained
78 | - `transformer_model {bert,distilbert,gpt2,roberta,transformerxl,xlnet,albert}` - the encoder model
79 | - `tn_prob` - the probability of keeping error-free sentences; helps to balance precision/recall
80 | - `pieces_per_token` - the maximum number of subword pieces per token; helps to avoid CUDA out-of-memory errors
81 |
82 | In our experiments we used a 98/2 train/dev split.
83 | ## Model inference
84 | To run your model on an input file, use the following command:
85 | ```.bash
86 | python predict.py --model_path MODEL_PATH [MODEL_PATH ...] \
87 |     --vocab_path VOCAB_PATH --input_file INPUT_FILE \
88 |     --output_file OUTPUT_FILE
89 | ```
90 | Among the parameters:
91 | - `min_error_probability` - minimum error probability (as in the paper)
92 | - `additional_confidence` - confidence bias (as in the paper)
93 | - `special_tokens_fix` - set this to reproduce some reported results of pretrained models
94 |
95 | For evaluation, use [M^2Scorer](https://github.com/nusnlp/m2scorer) and [ERRANT](https://github.com/chrisjbryant/errant).
96 | ## Citation
97 | If you find this work useful for your research, please cite our paper:
98 | ```
99 | @misc{omelianchuk2020gector,
100 |     title={GECToR -- Grammatical Error Correction: Tag, Not Rewrite},
101 |     author={Kostiantyn Omelianchuk and Vitaliy Atrasevych and Artem Chernodub and Oleksandr Skurzhanskyi},
102 |     year={2020},
103 |     eprint={2005.12592},
104 |     archivePrefix={arXiv},
105 |     primaryClass={cs.CL}
106 | }
107 | ```
108 |
-------------------------------------------------------------------------------- /utils/prepare_clc_fce_data.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Convert CLC-FCE dataset (The Cambridge Learner Corpus) to the parallel sentences format.
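Usage (inferred from the argument parser below): python utils/prepare_clc_fce_data.py FCE_DATASET_PATH --output OUTPUT_DIR. The script writes fce-original.txt and fce-applied.txt into the output folder.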
4 | """ 5 | 6 | import argparse 7 | import glob 8 | import os 9 | import re 10 | from xml.etree import cElementTree 11 | 12 | from nltk.tokenize import sent_tokenize, word_tokenize 13 | from tqdm import tqdm 14 | 15 | 16 | def annotate_fce_doc(xml): 17 | """Takes a FCE xml document and yields sentences with annotated errors.""" 18 | result = [] 19 | doc = cElementTree.fromstring(xml) 20 | paragraphs = doc.findall('head/text/*/coded_answer/p') 21 | for p in paragraphs: 22 | text = _get_formatted_text(p) 23 | result.append(text) 24 | 25 | return '\n'.join(result) 26 | 27 | 28 | def _get_formatted_text(elem, ignore_tags=None): 29 | text = elem.text or '' 30 | ignore_tags = [tag.upper() for tag in (ignore_tags or [])] 31 | correct = None 32 | mistake = None 33 | 34 | for child in elem.getchildren(): 35 | tag = child.tag.upper() 36 | if tag == 'NS': 37 | text += _get_formatted_text(child) 38 | 39 | elif tag == 'UNKNOWN': 40 | text += ' UNKNOWN ' 41 | 42 | elif tag == 'C': 43 | assert correct is None 44 | correct = _get_formatted_text(child) 45 | 46 | elif tag == 'I': 47 | assert mistake is None 48 | mistake = _get_formatted_text(child) 49 | 50 | elif tag in ignore_tags: 51 | pass 52 | 53 | else: 54 | raise ValueError(f"Unknown tag `{child.tag}`", text) 55 | 56 | if correct or mistake: 57 | correct = correct or '' 58 | mistake = mistake or '' 59 | if '=>' not in mistake: 60 | text += f'{{{mistake}=>{correct}}}' 61 | else: 62 | text += mistake 63 | 64 | text += elem.tail or '' 65 | return text 66 | 67 | 68 | def convert_fce(fce_dir): 69 | """Processes the whole FCE directory. Yields annotated documents (strings).""" 70 | 71 | # Ensure we got the valid dataset path 72 | if not os.path.isdir(fce_dir): 73 | raise UserWarning( 74 | f"{fce_dir} is not a valid path") 75 | 76 | dataset_dir = os.path.join(fce_dir, 'dataset') 77 | if not os.path.exists(dataset_dir): 78 | raise UserWarning( 79 | f"{fce_dir} doesn't point to a dataset's root dir") 80 | 81 | # Convert XML docs to the corpora format 82 | filenames = sorted(glob.glob(os.path.join(dataset_dir, '*/*.xml'))) 83 | 84 | docs = [] 85 | for filename in filenames: 86 | with open(filename, encoding='utf-8') as f: 87 | doc = annotate_fce_doc(f.read()) 88 | docs.append(doc) 89 | return docs 90 | 91 | 92 | def main(): 93 | fce = convert_fce(args.fce_dataset_path) 94 | with open(args.output + "/fce-original.txt", 'w', encoding='utf-8') as out_original, \ 95 | open(args.output + "/fce-applied.txt", 'w', encoding='utf-8') as out_applied: 96 | for doc in tqdm(fce, unit='doc'): 97 | sents = re.split(r"\n +\n", doc) 98 | for sent in sents: 99 | tokenized_sents = sent_tokenize(sent) 100 | for i in range(len(tokenized_sents)): 101 | if re.search(r"[{>][.?!]$", tokenized_sents[i]): 102 | tokenized_sents[i + 1] = tokenized_sents[i] + " " + tokenized_sents[i + 1] 103 | tokenized_sents[i] = "" 104 | regexp = r'{([^{}]*?)=>([^{}]*?)}' 105 | original = re.sub(regexp, r"\1", tokenized_sents[i]) 106 | applied = re.sub(regexp, r"\2", tokenized_sents[i]) 107 | # filter out nested alerts 108 | if original != "" and applied != "" and not re.search(r"[{}=]", original) \ 109 | and not re.search(r"[{}=]", applied): 110 | out_original.write(" ".join(word_tokenize(original)) + "\n") 111 | out_applied.write(" ".join(word_tokenize(applied)) + "\n") 112 | 113 | 114 | if __name__ == '__main__': 115 | parser = argparse.ArgumentParser(description=( 116 | "Convert CLC-FCE dataset to the parallel sentences format.")) 117 | parser.add_argument('fce_dataset_path', 118 | help='Path to the 
folder with the FCE dataset') 119 | parser.add_argument('--output', 120 | help='Path to the output folder') 121 | args = parser.parse_args() 122 | 123 | main() 124 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from utils.helpers import read_lines 4 | from gector.gec_model import GecBERTModel 5 | 6 | 7 | def predict_for_file(input_file, output_file, model, batch_size=32): 8 | test_data = read_lines(input_file) 9 | predictions = [] 10 | cnt_corrections = 0 11 | batch = [] 12 | for sent in test_data: 13 | batch.append(sent.split()) 14 | if len(batch) == batch_size: 15 | preds, cnt = model.handle_batch(batch) 16 | predictions.extend(preds) 17 | cnt_corrections += cnt 18 | batch = [] 19 | if batch: 20 | preds, cnt = model.handle_batch(batch) 21 | predictions.extend(preds) 22 | cnt_corrections += cnt 23 | 24 | with open(output_file, 'w') as f: 25 | f.write("\n".join([" ".join(x) for x in predictions]) + '\n') 26 | return cnt_corrections 27 | 28 | 29 | def main(args): 30 | # get all paths 31 | model = GecBERTModel(vocab_path=args.vocab_path, 32 | model_paths=args.model_path, 33 | max_len=args.max_len, min_len=args.min_len, 34 | iterations=args.iteration_count, 35 | min_error_probability=args.min_error_probability, 36 | min_probability=args.min_error_probability, 37 | lowercase_tokens=args.lowercase_tokens, 38 | model_name=args.transformer_model, 39 | special_tokens_fix=args.special_tokens_fix, 40 | log=False, 41 | confidence=args.additional_confidence, 42 | is_ensemble=args.is_ensemble, 43 | weigths=args.weights) 44 | 45 | cnt_corrections = predict_for_file(args.input_file, args.output_file, model, 46 | batch_size=args.batch_size) 47 | # evaluate with m2 or ERRANT 48 | print(f"Produced overall corrections: {cnt_corrections}") 49 | 50 | 51 | if __name__ == '__main__': 52 | # read parameters 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('--model_path', 55 | help='Path to the model file.', nargs='+', 56 | required=True) 57 | parser.add_argument('--vocab_path', 58 | help='Path to the model file.', 59 | default='data/output_vocabulary' # to use pretrained models 60 | ) 61 | parser.add_argument('--input_file', 62 | help='Path to the evalset file', 63 | required=True) 64 | parser.add_argument('--output_file', 65 | help='Path to the output file', 66 | required=True) 67 | parser.add_argument('--max_len', 68 | type=int, 69 | help='The max sentence length' 70 | '(all longer will be truncated)', 71 | default=50) 72 | parser.add_argument('--min_len', 73 | type=int, 74 | help='The minimum sentence length' 75 | '(all longer will be returned w/o changes)', 76 | default=3) 77 | parser.add_argument('--batch_size', 78 | type=int, 79 | help='The size of hidden unit cell.', 80 | default=128) 81 | parser.add_argument('--lowercase_tokens', 82 | type=int, 83 | help='Whether to lowercase tokens.', 84 | default=0) 85 | parser.add_argument('--transformer_model', 86 | choices=['bert', 'gpt2', 'transformerxl', 'xlnet', 'distilbert', 'roberta', 'albert'], 87 | help='Name of the transformer model.', 88 | default='roberta') 89 | parser.add_argument('--iteration_count', 90 | type=int, 91 | help='The number of iterations of the model.', 92 | default=5) 93 | parser.add_argument('--additional_confidence', 94 | type=float, 95 | help='How many probability to add to $KEEP token.', 96 | default=0) 97 | parser.add_argument('--min_probability', 98 | type=float, 
99 | default=0.0) 100 | parser.add_argument('--min_error_probability', 101 | type=float, 102 | default=0.0) 103 | parser.add_argument('--special_tokens_fix', 104 | type=int, 105 | help='Whether to fix problem with [CLS], [SEP] tokens tokenization. ' 106 | 'For reproducing reported results it should be 0 for BERT/XLNet and 1 for RoBERTa.', 107 | default=1) 108 | parser.add_argument('--is_ensemble', 109 | type=int, 110 | help='Whether to do ensembling.', 111 | default=0) 112 | parser.add_argument('--weights', 113 | help='Used to calculate weighted average', nargs='+', 114 | default=None) 115 | args = parser.parse_args() 116 | main(args) 117 | -------------------------------------------------------------------------------- /gector/datareader.py: -------------------------------------------------------------------------------- 1 | """Tweaked AllenNLP dataset reader.""" 2 | import logging 3 | import re 4 | from random import random 5 | from typing import Dict, List 6 | 7 | from allennlp.common.file_utils import cached_path 8 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 9 | from allennlp.data.fields import TextField, SequenceLabelField, MetadataField, Field 10 | from allennlp.data.instance import Instance 11 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer 12 | from allennlp.data.tokenizers import Token 13 | from overrides import overrides 14 | 15 | from utils.helpers import SEQ_DELIMETERS, START_TOKEN 16 | 17 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 18 | 19 | 20 | @DatasetReader.register("seq2labels_datareader") 21 | class Seq2LabelsDatasetReader(DatasetReader): 22 | """ 23 | Reads instances from a pretokenised file where each line is in the following format: 24 | 25 | WORD###TAG [TAB] WORD###TAG [TAB] ..... \n 26 | 27 | and converts it into a ``Dataset`` suitable for sequence tagging. You can also specify 28 | alternative delimiters in the constructor. 29 | 30 | Parameters 31 | ---------- 32 | delimiters: ``dict`` 33 | The dcitionary with all delimeters. 34 | token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``) 35 | We use this to define the input representation for the text. See :class:`TokenIndexer`. 36 | Note that the `output` tags will always correspond to single token IDs based on how they 37 | are pre-tokenised in the data file. 
38 | max_len: if set than will truncate long sentences 39 | """ 40 | # fix broken sentences mostly in Lang8 41 | BROKEN_SENTENCES_REGEXP = re.compile(r'\.[a-zA-RT-Z]') 42 | 43 | def __init__(self, 44 | token_indexers: Dict[str, TokenIndexer] = None, 45 | delimeters: dict = SEQ_DELIMETERS, 46 | skip_correct: bool = False, 47 | skip_complex: int = 0, 48 | lazy: bool = False, 49 | max_len: int = None, 50 | test_mode: bool = False, 51 | tag_strategy: str = "keep_one", 52 | tn_prob: float = 0, 53 | tp_prob: float = 0, 54 | broken_dot_strategy: str = "keep") -> None: 55 | super().__init__(lazy) 56 | self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} 57 | self._delimeters = delimeters 58 | self._max_len = max_len 59 | self._skip_correct = skip_correct 60 | self._skip_complex = skip_complex 61 | self._tag_strategy = tag_strategy 62 | self._broken_dot_strategy = broken_dot_strategy 63 | self._test_mode = test_mode 64 | self._tn_prob = tn_prob 65 | self._tp_prob = tp_prob 66 | 67 | @overrides 68 | def _read(self, file_path): 69 | # if `file_path` is a URL, redirect to the cache 70 | file_path = cached_path(file_path) 71 | with open(file_path, "r") as data_file: 72 | logger.info("Reading instances from lines in file at: %s", file_path) 73 | for line in data_file: 74 | line = line.strip("\n") 75 | # skip blank and broken lines 76 | if not line or (not self._test_mode and self._broken_dot_strategy == 'skip' 77 | and self.BROKEN_SENTENCES_REGEXP.search(line) is not None): 78 | continue 79 | 80 | tokens_and_tags = [pair.rsplit(self._delimeters['labels'], 1) 81 | for pair in line.split(self._delimeters['tokens'])] 82 | try: 83 | tokens = [Token(token) for token, tag in tokens_and_tags] 84 | tags = [tag for token, tag in tokens_and_tags] 85 | except ValueError: 86 | tokens = [Token(token[0]) for token in tokens_and_tags] 87 | tags = None 88 | 89 | if tokens and tokens[0] != Token(START_TOKEN): 90 | tokens = [Token(START_TOKEN)] + tokens 91 | 92 | words = [x.text for x in tokens] 93 | if self._max_len is not None: 94 | tokens = tokens[:self._max_len] 95 | tags = None if tags is None else tags[:self._max_len] 96 | instance = self.text_to_instance(tokens, tags, words) 97 | if instance: 98 | yield instance 99 | 100 | def extract_tags(self, tags: List[str]): 101 | op_del = self._delimeters['operations'] 102 | 103 | labels = [x.split(op_del) for x in tags] 104 | 105 | comlex_flag_dict = {} 106 | # get flags 107 | for i in range(5): 108 | idx = i + 1 109 | comlex_flag_dict[idx] = sum([len(x) > idx for x in labels]) 110 | 111 | if self._tag_strategy == "keep_one": 112 | # get only first candidates for r_tags in right and the last for left 113 | labels = [x[0] for x in labels] 114 | elif self._tag_strategy == "merge_all": 115 | # consider phrases as a words 116 | pass 117 | else: 118 | raise Exception("Incorrect tag strategy") 119 | 120 | detect_tags = ["CORRECT" if label == "$KEEP" else "INCORRECT" for label in labels] 121 | return labels, detect_tags, comlex_flag_dict 122 | 123 | def text_to_instance(self, tokens: List[Token], tags: List[str] = None, 124 | words: List[str] = None) -> Instance: # type: ignore 125 | """ 126 | We take `pre-tokenized` input here, because we don't have a tokenizer in this class. 
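Returns an ``Instance`` with ``tokens`` and ``metadata`` fields (plus ``labels`` and ``d_tags`` when tags are provided), or ``None`` when the sentence is dropped by the ``skip_complex`` check or the ``tn_prob``/``tp_prob`` sampling below.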
127 | """ 128 | # pylint: disable=arguments-differ 129 | fields: Dict[str, Field] = {} 130 | sequence = TextField(tokens, self._token_indexers) 131 | fields["tokens"] = sequence 132 | fields["metadata"] = MetadataField({"words": words}) 133 | if tags is not None: 134 | labels, detect_tags, complex_flag_dict = self.extract_tags(tags) 135 | if self._skip_complex and complex_flag_dict[self._skip_complex] > 0: 136 | return None 137 | rnd = random() 138 | # skip TN 139 | if self._skip_correct and all(x == "CORRECT" for x in detect_tags): 140 | if rnd > self._tn_prob: 141 | return None 142 | # skip TP 143 | else: 144 | if rnd > self._tp_prob: 145 | return None 146 | 147 | fields["labels"] = SequenceLabelField(labels, sequence, 148 | label_namespace="labels") 149 | fields["d_tags"] = SequenceLabelField(detect_tags, sequence, 150 | label_namespace="d_tags") 151 | return Instance(fields) 152 | -------------------------------------------------------------------------------- /utils/helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | 5 | VOCAB_DIR = Path(__file__).resolve().parent.parent / "data" 6 | PAD = "@@PADDING@@" 7 | UNK = "@@UNKNOWN@@" 8 | START_TOKEN = "$START" 9 | SEQ_DELIMETERS = {"tokens": " ", 10 | "labels": "SEPL|||SEPR", 11 | "operations": "SEPL__SEPR"} 12 | 13 | 14 | def get_verb_form_dicts(): 15 | path_to_dict = os.path.join(VOCAB_DIR, "verb-form-vocab.txt") 16 | encode, decode = {}, {} 17 | with open(path_to_dict, encoding="utf-8") as f: 18 | for line in f: 19 | words, tags = line.split(":") 20 | word1, word2 = words.split("_") 21 | tag1, tag2 = tags.split("_") 22 | decode_key = f"{word1}_{tag1}_{tag2.strip()}" 23 | if decode_key not in decode: 24 | encode[words] = tags 25 | decode[decode_key] = word2 26 | return encode, decode 27 | 28 | 29 | ENCODE_VERB_DICT, DECODE_VERB_DICT = get_verb_form_dicts() 30 | 31 | 32 | def get_target_sent_by_edits(source_tokens, edits): 33 | target_tokens = source_tokens[:] 34 | shift_idx = 0 35 | for edit in edits: 36 | start, end, label, _ = edit 37 | target_pos = start + shift_idx 38 | source_token = target_tokens[target_pos] if target_pos >= 0 else '' 39 | if label == "": 40 | del target_tokens[target_pos] 41 | shift_idx -= 1 42 | elif start == end: 43 | word = label.replace("$APPEND_", "") 44 | target_tokens[target_pos: target_pos] = [word] 45 | shift_idx += 1 46 | elif label.startswith("$TRANSFORM_"): 47 | word = apply_reverse_transformation(source_token, label) 48 | if word is None: 49 | word = source_token 50 | target_tokens[target_pos] = word 51 | elif start == end - 1: 52 | word = label.replace("$REPLACE_", "") 53 | target_tokens[target_pos] = word 54 | elif label.startswith("$MERGE_"): 55 | target_tokens[target_pos + 1: target_pos + 1] = [label] 56 | shift_idx += 1 57 | 58 | return replace_merge_transforms(target_tokens) 59 | 60 | 61 | def replace_merge_transforms(tokens): 62 | if all(not x.startswith("$MERGE_") for x in tokens): 63 | return tokens 64 | 65 | target_line = " ".join(tokens) 66 | target_line = target_line.replace(" $MERGE_HYPHEN ", "-") 67 | target_line = target_line.replace(" $MERGE_SPACE ", "") 68 | return target_line.split() 69 | 70 | 71 | def convert_using_case(token, smart_action): 72 | if not smart_action.startswith("$TRANSFORM_CASE_"): 73 | return token 74 | if smart_action.endswith("LOWER"): 75 | return token.lower() 76 | elif smart_action.endswith("UPPER"): 77 | return token.upper() 78 | elif smart_action.endswith("CAPITAL"): 79 | 
return token.capitalize() 80 | elif smart_action.endswith("CAPITAL_1"): 81 | return token[0] + token[1:].capitalize() 82 | elif smart_action.endswith("UPPER_-1"): 83 | return token[:-1].upper() + token[-1] 84 | else: 85 | return token 86 | 87 | 88 | def convert_using_verb(token, smart_action): 89 | key_word = "$TRANSFORM_VERB_" 90 | if not smart_action.startswith(key_word): 91 | raise Exception(f"Unknown action type {smart_action}") 92 | encoding_part = f"{token}_{smart_action[len(key_word):]}" 93 | decoded_target_word = decode_verb_form(encoding_part) 94 | return decoded_target_word 95 | 96 | 97 | def convert_using_split(token, smart_action): 98 | key_word = "$TRANSFORM_SPLIT" 99 | if not smart_action.startswith(key_word): 100 | raise Exception(f"Unknown action type {smart_action}") 101 | target_words = token.split("-") 102 | return " ".join(target_words) 103 | 104 | 105 | def convert_using_plural(token, smart_action): 106 | if smart_action.endswith("PLURAL"): 107 | return token + "s" 108 | elif smart_action.endswith("SINGULAR"): 109 | return token[:-1] 110 | else: 111 | raise Exception(f"Unknown action type {smart_action}") 112 | 113 | 114 | def apply_reverse_transformation(source_token, transform): 115 | if transform.startswith("$TRANSFORM"): 116 | # deal with equal 117 | if transform == "$KEEP": 118 | return source_token 119 | # deal with case 120 | if transform.startswith("$TRANSFORM_CASE"): 121 | return convert_using_case(source_token, transform) 122 | # deal with verb 123 | if transform.startswith("$TRANSFORM_VERB"): 124 | return convert_using_verb(source_token, transform) 125 | # deal with split 126 | if transform.startswith("$TRANSFORM_SPLIT"): 127 | return convert_using_split(source_token, transform) 128 | # deal with single/plural 129 | if transform.startswith("$TRANSFORM_AGREEMENT"): 130 | return convert_using_plural(source_token, transform) 131 | # raise exception if not find correct type 132 | raise Exception(f"Unknown action type {transform}") 133 | else: 134 | return source_token 135 | 136 | 137 | def read_parallel_lines(fn1, fn2): 138 | lines1 = read_lines(fn1, skip_strip=True) 139 | lines2 = read_lines(fn2, skip_strip=True) 140 | assert len(lines1) == len(lines2) 141 | out_lines1, out_lines2 = [], [] 142 | for line1, line2 in zip(lines1, lines2): 143 | if not line1.strip() or not line2.strip(): 144 | continue 145 | else: 146 | out_lines1.append(line1) 147 | out_lines2.append(line2) 148 | return out_lines1, out_lines2 149 | 150 | 151 | def read_lines(fn, skip_strip=False): 152 | if not os.path.exists(fn): 153 | return [] 154 | with open(fn, 'r', encoding='utf-8') as f: 155 | lines = f.readlines() 156 | return [s.strip() for s in lines if s.strip() or skip_strip] 157 | 158 | 159 | def write_lines(fn, lines, mode='w'): 160 | if mode == 'w' and os.path.exists(fn): 161 | os.remove(fn) 162 | with open(fn, encoding='utf-8', mode=mode) as f: 163 | f.writelines(['%s\n' % s for s in lines]) 164 | 165 | 166 | def decode_verb_form(original): 167 | return DECODE_VERB_DICT.get(original) 168 | 169 | 170 | def encode_verb_form(original_word, corrected_word): 171 | decoding_request = original_word + "_" + corrected_word 172 | decoding_response = ENCODE_VERB_DICT.get(decoding_request, "").strip() 173 | if original_word and decoding_response: 174 | answer = decoding_response 175 | else: 176 | answer = None 177 | return answer 178 | 179 | 180 | def get_weights_name(transformer_name, lowercase): 181 | if transformer_name == 'bert' and lowercase: 182 | return 'bert-base-uncased' 183 | if 
transformer_name == 'bert' and not lowercase: 184 | return 'bert-base-cased' 185 | if transformer_name == 'distilbert': 186 | if not lowercase: 187 | print('Warning! This model was trained only on uncased sentences.') 188 | return 'distilbert-base-uncased' 189 | if transformer_name == 'albert': 190 | if not lowercase: 191 | print('Warning! This model was trained only on uncased sentences.') 192 | return 'albert-base-v1' 193 | if lowercase: 194 | print('Warning! This model was trained only on cased sentences.') 195 | if transformer_name == 'roberta': 196 | return 'roberta-base' 197 | if transformer_name == 'gpt2': 198 | return 'gpt2' 199 | if transformer_name == 'transformerxl': 200 | return 'transfo-xl-wt103' 201 | if transformer_name == 'xlnet': 202 | return 'xlnet-base-cased' 203 | -------------------------------------------------------------------------------- /gector/seq2labels_model.py: -------------------------------------------------------------------------------- 1 | """Basic model. Predicts tags for every token""" 2 | from typing import Dict, Optional, List, Any 3 | 4 | import numpy 5 | import torch 6 | import torch.nn.functional as F 7 | from allennlp.data import Vocabulary 8 | from allennlp.models.model import Model 9 | from allennlp.modules import TimeDistributed, TextFieldEmbedder 10 | from allennlp.nn import InitializerApplicator, RegularizerApplicator 11 | from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits 12 | from allennlp.training.metrics import CategoricalAccuracy 13 | from overrides import overrides 14 | from torch.nn.modules.linear import Linear 15 | 16 | 17 | @Model.register("seq2labels") 18 | class Seq2Labels(Model): 19 | """ 20 | This ``Seq2Labels`` simply encodes a sequence of text with a stacked ``Seq2SeqEncoder``, then 21 | predicts a tag (or couple tags) for each token in the sequence. 22 | 23 | Parameters 24 | ---------- 25 | vocab : ``Vocabulary``, required 26 | A Vocabulary, required in order to compute sizes for input/output projections. 27 | text_field_embedder : ``TextFieldEmbedder``, required 28 | Used to embed the ``tokens`` ``TextField`` we get as input to the model. 29 | encoder : ``Seq2SeqEncoder`` 30 | The encoder (with its own internal stacking) that we will use in between embedding tokens 31 | and predicting output tags. 32 | calculate_span_f1 : ``bool``, optional (default=``None``) 33 | Calculate span-level F1 metrics during training. If this is ``True``, then 34 | ``label_encoding`` is required. If ``None`` and 35 | label_encoding is specified, this is set to ``True``. 36 | If ``None`` and label_encoding is not specified, it defaults 37 | to ``False``. 38 | label_encoding : ``str``, optional (default=``None``) 39 | Label encoding to use when calculating span f1. 40 | Valid options are "BIO", "BIOUL", "IOB1", "BMES". 41 | Required if ``calculate_span_f1`` is true. 42 | label_namespace : ``str``, optional (default=``labels``) 43 | This is needed to compute the SpanBasedF1Measure metric, if desired. 44 | Unless you did something unusual, the default value should be what you want. 45 | verbose_metrics : ``bool``, optional (default = False) 46 | If true, metrics will be returned per label class in addition 47 | to the overall statistics. 48 | initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``) 49 | Used to initialize the model parameters. 
50 | regularizer : ``RegularizerApplicator``, optional (default=``None``) 51 | If provided, will be used to calculate the regularization penalty during training. 52 | """ 53 | 54 | def __init__(self, vocab: Vocabulary, 55 | text_field_embedder: TextFieldEmbedder, 56 | predictor_dropout=0.0, 57 | labels_namespace: str = "labels", 58 | detect_namespace: str = "d_tags", 59 | verbose_metrics: bool = False, 60 | label_smoothing: float = 0.0, 61 | confidence: float = 0.0, 62 | initializer: InitializerApplicator = InitializerApplicator(), 63 | regularizer: Optional[RegularizerApplicator] = None) -> None: 64 | super(Seq2Labels, self).__init__(vocab, regularizer) 65 | 66 | self.label_namespaces = [labels_namespace, 67 | detect_namespace] 68 | self.text_field_embedder = text_field_embedder 69 | self.num_labels_classes = self.vocab.get_vocab_size(labels_namespace) 70 | self.num_detect_classes = self.vocab.get_vocab_size(detect_namespace) 71 | self.label_smoothing = label_smoothing 72 | self.confidence = confidence 73 | self.incorr_index = self.vocab.get_token_index("INCORRECT", 74 | namespace=detect_namespace) 75 | 76 | self._verbose_metrics = verbose_metrics 77 | self.predictor_dropout = TimeDistributed(torch.nn.Dropout(predictor_dropout)) 78 | 79 | self.tag_labels_projection_layer = TimeDistributed( 80 | Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_labels_classes)) 81 | 82 | self.tag_detect_projection_layer = TimeDistributed( 83 | Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_detect_classes)) 84 | 85 | self.metrics = {"accuracy": CategoricalAccuracy()} 86 | 87 | initializer(self) 88 | 89 | @overrides 90 | def forward(self, # type: ignore 91 | tokens: Dict[str, torch.LongTensor], 92 | labels: torch.LongTensor = None, 93 | d_tags: torch.LongTensor = None, 94 | metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: 95 | # pylint: disable=arguments-differ 96 | """ 97 | Parameters 98 | ---------- 99 | tokens : Dict[str, torch.LongTensor], required 100 | The output of ``TextField.as_array()``, which should typically be passed directly to a 101 | ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer`` 102 | tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens": 103 | Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used 104 | for the ``TokenIndexers`` when you created the ``TextField`` representing your 105 | sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``, 106 | which knows how to combine different word representations into a single vector per 107 | token in your input. 108 | lables : torch.LongTensor, optional (default = None) 109 | A torch tensor representing the sequence of integer gold class labels of shape 110 | ``(batch_size, num_tokens)``. 111 | d_tags : torch.LongTensor, optional (default = None) 112 | A torch tensor representing the sequence of integer gold class labels of shape 113 | ``(batch_size, num_tokens)``. 114 | metadata : ``List[Dict[str, Any]]``, optional, (default = None) 115 | metadata containing the original words in the sentence to be tagged under a 'words' key. 116 | 117 | Returns 118 | ------- 119 | An output dictionary consisting of: 120 | logits : torch.FloatTensor 121 | A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing 122 | unnormalised log probabilities of the tag classes. 
123 | class_probabilities : torch.FloatTensor 124 | A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing 125 | a distribution of the tag classes per word. 126 | loss : torch.FloatTensor, optional 127 | A scalar loss to be optimised. 128 | 129 | """ 130 | encoded_text = self.text_field_embedder(tokens) 131 | batch_size, sequence_length, _ = encoded_text.size() 132 | mask = get_text_field_mask(tokens) 133 | logits_labels = self.tag_labels_projection_layer(self.predictor_dropout(encoded_text)) 134 | logits_d = self.tag_detect_projection_layer(encoded_text) 135 | 136 | class_probabilities_labels = F.softmax(logits_labels, dim=-1).view( 137 | [batch_size, sequence_length, self.num_labels_classes]) 138 | class_probabilities_d = F.softmax(logits_d, dim=-1).view( 139 | [batch_size, sequence_length, self.num_detect_classes]) 140 | error_probs = class_probabilities_d[:, :, self.incorr_index] * mask 141 | incorr_prob = torch.max(error_probs, dim=-1)[0] 142 | 143 | if self.confidence > 0: 144 | probability_change = [self.confidence] + [0] * (self.num_labels_classes - 1) 145 | class_probabilities_labels += torch.cuda.FloatTensor(probability_change).repeat( 146 | (batch_size, sequence_length, 1)) 147 | 148 | output_dict = {"logits_labels": logits_labels, 149 | "logits_d_tags": logits_d, 150 | "class_probabilities_labels": class_probabilities_labels, 151 | "class_probabilities_d_tags": class_probabilities_d, 152 | "max_error_probability": incorr_prob} 153 | if labels is not None and d_tags is not None: 154 | loss_labels = sequence_cross_entropy_with_logits(logits_labels, labels, mask, 155 | label_smoothing=self.label_smoothing) 156 | loss_d = sequence_cross_entropy_with_logits(logits_d, d_tags, mask) 157 | for metric in self.metrics.values(): 158 | metric(logits_labels, labels, mask.float()) 159 | metric(logits_d, d_tags, mask.float()) 160 | output_dict["loss"] = loss_labels + loss_d 161 | 162 | if metadata is not None: 163 | output_dict["words"] = [x["words"] for x in metadata] 164 | return output_dict 165 | 166 | @overrides 167 | def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 168 | """ 169 | Does a simple position-wise argmax over each token, converts indices to string labels, and 170 | adds a ``"tags"`` key to the dictionary with the result. 
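In this tweaked version the predicted tag sequences are stored under one key per label namespace (``labels`` and ``d_tags``) rather than under a single ``tags`` key.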
171 | """ 172 | for label_namespace in self.label_namespaces: 173 | all_predictions = output_dict[f'class_probabilities_{label_namespace}'] 174 | all_predictions = all_predictions.cpu().data.numpy() 175 | if all_predictions.ndim == 3: 176 | predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])] 177 | else: 178 | predictions_list = [all_predictions] 179 | all_tags = [] 180 | 181 | for predictions in predictions_list: 182 | argmax_indices = numpy.argmax(predictions, axis=-1) 183 | tags = [self.vocab.get_token_from_index(x, namespace=label_namespace) 184 | for x in argmax_indices] 185 | all_tags.append(tags) 186 | output_dict[f'{label_namespace}'] = all_tags 187 | return output_dict 188 | 189 | @overrides 190 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 191 | metrics_to_return = {metric_name: metric.get_metric(reset) for 192 | metric_name, metric in self.metrics.items()} 193 | return metrics_to_return 194 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. -------------------------------------------------------------------------------- /gector/bert_token_embedder.py: -------------------------------------------------------------------------------- 1 | """Tweaked version of corresponding AllenNLP file""" 2 | import logging 3 | from copy import deepcopy 4 | from typing import Dict 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from allennlp.modules.token_embedders.token_embedder import TokenEmbedder 9 | from allennlp.nn import util 10 | from transformers import AutoModel, PreTrainedModel 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class PretrainedBertModel: 16 | """ 17 | In some instances you may want to load the same BERT model twice 18 | (e.g. to use as a token embedder and also as a pooling layer). 19 | This factory provides a cache so that you don't actually have to load the model twice. 20 | """ 21 | 22 | _cache: Dict[str, PreTrainedModel] = {} 23 | 24 | @classmethod 25 | def load(cls, model_name: str, cache_model: bool = True) -> PreTrainedModel: 26 | if model_name in cls._cache: 27 | return PretrainedBertModel._cache[model_name] 28 | 29 | model = AutoModel.from_pretrained(model_name) 30 | if cache_model: 31 | cls._cache[model_name] = model 32 | 33 | return model 34 | 35 | 36 | class BertEmbedder(TokenEmbedder): 37 | """ 38 | A ``TokenEmbedder`` that produces BERT embeddings for your tokens. 39 | Should be paired with a ``BertIndexer``, which produces wordpiece ids. 40 | Most likely you probably want to use ``PretrainedBertEmbedder`` 41 | for one of the named pretrained models, not this base class. 42 | Parameters 43 | ---------- 44 | bert_model: ``BertModel`` 45 | The BERT model being wrapped. 46 | top_layer_only: ``bool``, optional (default = ``False``) 47 | If ``True``, then only return the top layer instead of apply the scalar mix. 
48 | max_pieces : int, optional (default: 512) 49 | The BERT embedder uses positional embeddings and so has a corresponding 50 | maximum length for its input ids. Assuming the inputs are windowed 51 | and padded appropriately by this length, the embedder will split them into a 52 | large batch, feed them into BERT, and recombine the output as if it was a 53 | longer sequence. 54 | num_start_tokens : int, optional (default: 1) 55 | The number of starting special tokens input to BERT (usually 1, i.e., [CLS]) 56 | num_end_tokens : int, optional (default: 1) 57 | The number of ending tokens input to BERT (usually 1, i.e., [SEP]) 58 | scalar_mix_parameters: ``List[float]``, optional, (default = None) 59 | If not ``None``, use these scalar mix parameters to weight the representations 60 | produced by different layers. These mixing weights are not updated during 61 | training. 62 | """ 63 | 64 | def __init__( 65 | self, 66 | bert_model: PreTrainedModel, 67 | top_layer_only: bool = False, 68 | max_pieces: int = 512, 69 | num_start_tokens: int = 1, 70 | num_end_tokens: int = 1 71 | ) -> None: 72 | super().__init__() 73 | # self.bert_model = bert_model 74 | self.bert_model = deepcopy(bert_model) 75 | self.output_dim = bert_model.config.hidden_size 76 | self.max_pieces = max_pieces 77 | self.num_start_tokens = num_start_tokens 78 | self.num_end_tokens = num_end_tokens 79 | self._scalar_mix = None 80 | 81 | def set_weights(self, freeze): 82 | for param in self.bert_model.parameters(): 83 | param.requires_grad = not freeze 84 | return 85 | 86 | def get_output_dim(self) -> int: 87 | return self.output_dim 88 | 89 | def forward( 90 | self, 91 | input_ids: torch.LongTensor, 92 | offsets: torch.LongTensor = None 93 | ) -> torch.Tensor: 94 | """ 95 | Parameters 96 | ---------- 97 | input_ids : ``torch.LongTensor`` 98 | The (batch_size, ..., max_sequence_length) tensor of wordpiece ids. 99 | offsets : ``torch.LongTensor``, optional 100 | The BERT embeddings are one per wordpiece. However it's possible/likely 101 | you might want one per original token. In that case, ``offsets`` 102 | represents the indices of the desired wordpiece for each original token. 103 | Depending on how your token indexer is configured, this could be the 104 | position of the last wordpiece for each token, or it could be the position 105 | of the first wordpiece for each token. 106 | For example, if you had the sentence "Definitely not", and if the corresponding 107 | wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids 108 | would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4]. 109 | If offsets are provided, the returned tensor will contain only the wordpiece 110 | embeddings at those positions, and (in particular) will contain one embedding 111 | per token. If offsets are not provided, the entire tensor of wordpiece embeddings 112 | will be returned. 113 | """ 114 | 115 | batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1) 116 | initial_dims = list(input_ids.shape[:-1]) 117 | 118 | # The embedder may receive an input tensor that has a sequence length longer than can 119 | # be fit. In that case, we should expect the wordpiece indexer to create padded windows 120 | # of length `self.max_pieces` for us, and have them concatenated into one long sequence. 121 | # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..." 
122 | # We can then split the sequence into sub-sequences of that length, and concatenate them 123 | # along the batch dimension so we effectively have one huge batch of partial sentences. 124 | # This can then be fed into BERT without any sentence length issues. Keep in mind 125 | # that the memory consumption can dramatically increase for large batches with extremely 126 | # long sentences. 127 | needs_split = full_seq_len > self.max_pieces 128 | last_window_size = 0 129 | if needs_split: 130 | # Split the flattened list by the window size, `max_pieces` 131 | split_input_ids = list(input_ids.split(self.max_pieces, dim=-1)) 132 | 133 | # We want all sequences to be the same length, so pad the last sequence 134 | last_window_size = split_input_ids[-1].size(-1) 135 | padding_amount = self.max_pieces - last_window_size 136 | split_input_ids[-1] = F.pad(split_input_ids[-1], pad=[0, padding_amount], value=0) 137 | 138 | # Now combine the sequences along the batch dimension 139 | input_ids = torch.cat(split_input_ids, dim=0) 140 | 141 | input_mask = (input_ids != 0).long() 142 | # input_ids may have extra dimensions, so we reshape down to 2-d 143 | # before calling the BERT model and then reshape back at the end. 144 | all_encoder_layers = self.bert_model( 145 | input_ids=util.combine_initial_dims(input_ids), 146 | attention_mask=util.combine_initial_dims(input_mask), 147 | )[0] 148 | if len(all_encoder_layers[0].shape) == 3: 149 | all_encoder_layers = torch.stack(all_encoder_layers) 150 | elif len(all_encoder_layers[0].shape) == 2: 151 | all_encoder_layers = torch.unsqueeze(all_encoder_layers, dim=0) 152 | 153 | if needs_split: 154 | # First, unpack the output embeddings into one long sequence again 155 | unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=1) 156 | unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2) 157 | 158 | # Next, select indices of the sequence such that it will result in embeddings representing the original 159 | # sentence. To capture maximal context, the indices will be the middle part of each embedded window 160 | # sub-sequence (plus any leftover start and final edge windows), e.g., 161 | # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 162 | # "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]" 163 | # with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start 164 | # and final windows with indices [0, 1] and [14, 15] respectively. 
165 | 166 | # Find the stride as half the max pieces, ignoring the special start and end tokens 167 | # Calculate an offset to extract the centermost embeddings of each window 168 | stride = (self.max_pieces - self.num_start_tokens - self.num_end_tokens) // 2 169 | stride_offset = stride // 2 + self.num_start_tokens 170 | 171 | first_window = list(range(stride_offset)) 172 | 173 | max_context_windows = [ 174 | i 175 | for i in range(full_seq_len) 176 | if stride_offset - 1 < i % self.max_pieces < stride_offset + stride 177 | ] 178 | 179 | # Lookback what's left, unless it's the whole self.max_pieces window 180 | if full_seq_len % self.max_pieces == 0: 181 | lookback = self.max_pieces 182 | else: 183 | lookback = full_seq_len % self.max_pieces 184 | 185 | final_window_start = full_seq_len - lookback + stride_offset + stride 186 | final_window = list(range(final_window_start, full_seq_len)) 187 | 188 | select_indices = first_window + max_context_windows + final_window 189 | 190 | initial_dims.append(len(select_indices)) 191 | 192 | recombined_embeddings = unpacked_embeddings[:, :, select_indices] 193 | else: 194 | recombined_embeddings = all_encoder_layers 195 | 196 | # Recombine the outputs of all layers 197 | # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim) 198 | # recombined = torch.cat(combined, dim=2) 199 | input_mask = (recombined_embeddings != 0).long() 200 | 201 | if self._scalar_mix is not None: 202 | mix = self._scalar_mix(recombined_embeddings, input_mask) 203 | else: 204 | mix = recombined_embeddings[-1] 205 | 206 | # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim) 207 | 208 | if offsets is None: 209 | # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim) 210 | dims = initial_dims if needs_split else input_ids.size() 211 | return util.uncombine_initial_dims(mix, dims) 212 | else: 213 | # offsets is (batch_size, d1, ..., dn, orig_sequence_length) 214 | offsets2d = util.combine_initial_dims(offsets) 215 | # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length) 216 | range_vector = util.get_range_vector( 217 | offsets2d.size(0), device=util.get_device_of(mix) 218 | ).unsqueeze(1) 219 | # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length) 220 | selected_embeddings = mix[range_vector, offsets2d] 221 | 222 | return util.uncombine_initial_dims(selected_embeddings, offsets.size()) 223 | 224 | 225 | # @TokenEmbedder.register("bert-pretrained") 226 | class PretrainedBertEmbedder(BertEmbedder): 227 | 228 | """ 229 | Parameters 230 | ---------- 231 | pretrained_model: ``str`` 232 | Either the name of the pretrained model to use (e.g. 'bert-base-uncased'), 233 | or the path to the .tar.gz file with the model weights. 234 | If the name is a key in the list of pretrained models at 235 | https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L41 236 | the corresponding path will be used; otherwise it will be interpreted as a path or URL. 237 | requires_grad : ``bool``, optional (default = False) 238 | If True, compute gradient of BERT parameters for fine tuning. 239 | top_layer_only: ``bool``, optional (default = ``False``) 240 | If ``True``, then only return the top layer instead of apply the scalar mix. 241 | scalar_mix_parameters: ``List[float]``, optional, (default = None) 242 | If not ``None``, use these scalar mix parameters to weight the representations 243 | produced by different layers. 
These mixing weights are not updated during 244 | training. 245 | """ 246 | 247 | def __init__( 248 | self, 249 | pretrained_model: str, 250 | requires_grad: bool = False, 251 | top_layer_only: bool = False, 252 | special_tokens_fix: int = 0, 253 | ) -> None: 254 | model = PretrainedBertModel.load(pretrained_model) 255 | 256 | for param in model.parameters(): 257 | param.requires_grad = requires_grad 258 | 259 | super().__init__( 260 | bert_model=model, 261 | top_layer_only=top_layer_only 262 | ) 263 | 264 | if special_tokens_fix: 265 | try: 266 | vocab_size = self.bert_model.embeddings.word_embeddings.num_embeddings 267 | except AttributeError: 268 | # reserve more space 269 | vocab_size = self.bert_model.word_embedding.num_embeddings + 5 270 | self.bert_model.resize_token_embeddings(vocab_size + 1) 271 | -------------------------------------------------------------------------------- /gector/gec_model.py: -------------------------------------------------------------------------------- 1 | """Wrapper of AllenNLP model. Fixes errors based on model predictions""" 2 | import logging 3 | import os 4 | import sys 5 | from time import time 6 | 7 | import torch 8 | from allennlp.data.dataset import Batch 9 | from allennlp.data.fields import TextField 10 | from allennlp.data.instance import Instance 11 | from allennlp.data.tokenizers import Token 12 | from allennlp.data.vocabulary import Vocabulary 13 | from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder 14 | from allennlp.nn import util 15 | 16 | from gector.bert_token_embedder import PretrainedBertEmbedder 17 | from gector.seq2labels_model import Seq2Labels 18 | from gector.wordpiece_indexer import PretrainedBertIndexer 19 | from utils.helpers import PAD, UNK, get_target_sent_by_edits 20 | 21 | logging.getLogger("werkzeug").setLevel(logging.ERROR) 22 | logger = logging.getLogger(__file__) 23 | 24 | 25 | def get_weights_name(transformer_name, lowercase): 26 | if transformer_name == 'bert' and lowercase: 27 | return 'bert-base-uncased' 28 | if transformer_name == 'bert' and not lowercase: 29 | return 'bert-base-cased' 30 | if transformer_name == 'distilbert': 31 | if not lowercase: 32 | print('Warning! This model was trained only on uncased sentences.') 33 | return 'distilbert-base-uncased' 34 | if transformer_name == 'albert': 35 | if not lowercase: 36 | print('Warning! This model was trained only on uncased sentences.') 37 | return 'albert-base-v1' 38 | if lowercase: 39 | print('Warning! 
This model was trained only on cased sentences.') 40 | if transformer_name == 'roberta': 41 | return 'roberta-base' 42 | if transformer_name == 'gpt2': 43 | return 'gpt2' 44 | if transformer_name == 'transformerxl': 45 | return 'transfo-xl-wt103' 46 | if transformer_name == 'xlnet': 47 | return 'xlnet-base-cased' 48 | 49 | 50 | class GecBERTModel(object): 51 | def __init__(self, vocab_path=None, model_paths=None, 52 | weigths=None, 53 | max_len=50, 54 | min_len=3, 55 | lowercase_tokens=False, 56 | log=False, 57 | iterations=3, 58 | min_probability=0.0, 59 | model_name='roberta', 60 | special_tokens_fix=1, 61 | is_ensemble=True, 62 | min_error_probability=0.0, 63 | confidence=0, 64 | resolve_cycles=False, 65 | ): 66 | self.model_weights = list(map(float, weigths)) if weigths else [1] * len(model_paths) 67 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 68 | self.max_len = max_len 69 | self.min_len = min_len 70 | self.lowercase_tokens = lowercase_tokens 71 | self.min_probability = min_probability 72 | self.min_error_probability = min_error_probability 73 | self.vocab = Vocabulary.from_files(vocab_path) 74 | self.log = log 75 | self.iterations = iterations 76 | self.confidence = confidence 77 | self.resolve_cycles = resolve_cycles 78 | # set training parameters and operations 79 | 80 | self.indexers = [] 81 | self.models = [] 82 | for model_path in model_paths: 83 | if is_ensemble: 84 | model_name, special_tokens_fix = self._get_model_data(model_path) 85 | weights_name = get_weights_name(model_name, lowercase_tokens) 86 | self.indexers.append(self._get_indexer(weights_name, special_tokens_fix)) 87 | model = Seq2Labels(vocab=self.vocab, 88 | text_field_embedder=self._get_embbeder(weights_name, special_tokens_fix), 89 | confidence=self.confidence 90 | ).to(self.device) 91 | if torch.cuda.is_available(): 92 | model.load_state_dict(torch.load(model_path)) 93 | else: 94 | model.load_state_dict(torch.load(model_path, 95 | map_location=torch.device('cpu'))) 96 | model.eval() 97 | self.models.append(model) 98 | 99 | @staticmethod 100 | def _get_model_data(model_path): 101 | model_name = model_path.split('/')[-1] 102 | tr_model, stf = model_name.split('_')[:2] 103 | return tr_model, int(stf) 104 | 105 | def _restore_model(self, input_path): 106 | if os.path.isdir(input_path): 107 | print("Model could not be restored from directory", file=sys.stderr) 108 | filenames = [] 109 | else: 110 | filenames = [input_path] 111 | for model_path in filenames: 112 | try: 113 | if torch.cuda.is_available(): 114 | loaded_model = torch.load(model_path) 115 | else: 116 | loaded_model = torch.load(model_path, 117 | map_location=lambda storage, 118 | loc: storage) 119 | except: 120 | print(f"{model_path} is not valid model", file=sys.stderr) 121 | own_state = self.model.state_dict() 122 | for name, weights in loaded_model.items(): 123 | if name not in own_state: 124 | continue 125 | try: 126 | if len(filenames) == 1: 127 | own_state[name].copy_(weights) 128 | else: 129 | own_state[name] += weights 130 | except RuntimeError: 131 | continue 132 | print("Model is restored", file=sys.stderr) 133 | 134 | def predict(self, batches): 135 | t11 = time() 136 | predictions = [] 137 | for batch, model in zip(batches, self.models): 138 | batch = util.move_to_device(batch.as_tensor_dict(), 0 if torch.cuda.is_available() else -1) 139 | with torch.no_grad(): 140 | prediction = model.forward(**batch) 141 | predictions.append(prediction) 142 | 143 | preds, idx, error_probs = self._convert(predictions) 144 | 
t55 = time() 145 | if self.log: 146 | print(f"Inference time {t55 - t11}") 147 | return preds, idx, error_probs 148 |
149 | def get_token_action(self, token, index, prob, sugg_token): 150 | """Get a suggested action (start_pos, end_pos, replacement, probability) for a token, or None.""" 151 | # cases when we don't need to do anything 152 | if prob < self.min_probability or sugg_token in [UNK, PAD, '$KEEP']: 153 | return None 154 |
155 | if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE': 156 | start_pos = index 157 | end_pos = index + 1 158 | elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"): 159 | start_pos = index + 1 160 | end_pos = index + 1 161 |
162 | if sugg_token == "$DELETE": 163 | sugg_token_clear = "" 164 | elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"): 165 | sugg_token_clear = sugg_token[:] 166 | else: 167 | sugg_token_clear = sugg_token[sugg_token.index('_') + 1:] 168 |
169 | return start_pos - 1, end_pos - 1, sugg_token_clear, prob 170 |
171 | def _get_embbeder(self, weights_name, special_tokens_fix): 172 | embedders = {'bert': PretrainedBertEmbedder( 173 | pretrained_model=weights_name, 174 | requires_grad=False, 175 | top_layer_only=True, 176 | special_tokens_fix=special_tokens_fix) 177 | } 178 | text_field_embedder = BasicTextFieldEmbedder( 179 | token_embedders=embedders, 180 | embedder_to_indexer_map={"bert": ["bert", "bert-offsets"]}, 181 | allow_unmatched_keys=True) 182 | return text_field_embedder 183 |
184 | def _get_indexer(self, weights_name, special_tokens_fix): 185 | bert_token_indexer = PretrainedBertIndexer( 186 | pretrained_model=weights_name, 187 | do_lowercase=self.lowercase_tokens, 188 | max_pieces_per_token=5, 189 | use_starting_offsets=True, 190 | truncate_long_sequences=True, 191 | special_tokens_fix=special_tokens_fix, 192 | is_test=True 193 | ) 194 | return {'bert': bert_token_indexer} 195 |
196 | def preprocess(self, token_batch): 197 | seq_lens = [len(sequence) for sequence in token_batch if sequence] 198 | if not seq_lens: 199 | return [] 200 | max_len = min(max(seq_lens), self.max_len) 201 | batches = [] 202 | for indexer in self.indexers: 203 | batch = [] 204 | for sequence in token_batch: 205 | tokens = sequence[:max_len] 206 | tokens = [Token(token) for token in ['$START'] + tokens] 207 | batch.append(Instance({'tokens': TextField(tokens, indexer)})) 208 | batch = Batch(batch) 209 | batch.index_instances(self.vocab) 210 | batches.append(batch) 211 |
212 | return batches 213 |
214 | def _convert(self, data): 215 | all_class_probs = torch.zeros_like(data[0]['class_probabilities_labels']) 216 | error_probs = torch.zeros_like(data[0]['max_error_probability']) 217 | for output, weight in zip(data, self.model_weights): 218 | all_class_probs += weight * output['class_probabilities_labels'] / sum(self.model_weights) 219 | error_probs += weight * output['max_error_probability'] / sum(self.model_weights) 220 |
221 | max_vals = torch.max(all_class_probs, dim=-1) 222 | probs = max_vals[0].tolist() 223 | idx = max_vals[1].tolist() 224 | return probs, idx, error_probs.tolist() 225 |
226 | def update_final_batch(self, final_batch, pred_ids, pred_batch, 227 | prev_preds_dict): 228 | new_pred_ids = [] 229 | total_updated = 0 230 | for i, orig_id in enumerate(pred_ids): 231 | orig = final_batch[orig_id] 232 | pred = pred_batch[i] 233 | prev_preds = prev_preds_dict[orig_id] 234 | if orig != pred and pred not in prev_preds: 235 | final_batch[orig_id] = pred 236 | new_pred_ids.append(orig_id) 237 |
prev_preds_dict[orig_id].append(pred) 238 | total_updated += 1 239 | elif orig != pred and pred in prev_preds: 240 | # update final batch, but stop iterations 241 | final_batch[orig_id] = pred 242 | total_updated += 1 243 | else: 244 | continue 245 | return final_batch, new_pred_ids, total_updated 246 | 247 | def postprocess_batch(self, batch, all_probabilities, all_idxs, 248 | error_probs, 249 | max_len=50): 250 | all_results = [] 251 | noop_index = self.vocab.get_token_index("$KEEP", "labels") 252 | for tokens, probabilities, idxs, error_prob in zip(batch, 253 | all_probabilities, 254 | all_idxs, 255 | error_probs): 256 | length = min(len(tokens), max_len) 257 | edits = [] 258 | 259 | # skip whole sentences if there no errors 260 | if max(idxs) == 0: 261 | all_results.append(tokens) 262 | continue 263 | 264 | # skip whole sentence if probability of correctness is not high 265 | if error_prob < self.min_error_probability: 266 | all_results.append(tokens) 267 | continue 268 | 269 | for i in range(length): 270 | token = tokens[i - 1] # because of START token 271 | # skip if there is no error 272 | if idxs[i] == noop_index: 273 | continue 274 | 275 | sugg_token = self.vocab.get_token_from_index(idxs[i], 276 | namespace='labels') 277 | action = self.get_token_action(token, i, probabilities[i], 278 | sugg_token) 279 | if not action: 280 | continue 281 | 282 | edits.append(action) 283 | all_results.append(get_target_sent_by_edits(tokens, edits)) 284 | return all_results 285 | 286 | def handle_batch(self, full_batch): 287 | """ 288 | Handle batch of requests. 289 | """ 290 | final_batch = full_batch[:] 291 | batch_size = len(full_batch) 292 | prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))} 293 | short_ids = [i for i in range(len(full_batch)) 294 | if len(full_batch[i]) < self.min_len] 295 | pred_ids = [i for i in range(len(full_batch)) if i not in short_ids] 296 | total_updates = 0 297 | 298 | for n_iter in range(self.iterations): 299 | orig_batch = [final_batch[i] for i in pred_ids] 300 | 301 | sequences = self.preprocess(orig_batch) 302 | 303 | if not sequences: 304 | break 305 | probabilities, idxs, error_probs = self.predict(sequences) 306 | 307 | pred_batch = self.postprocess_batch(orig_batch, probabilities, 308 | idxs, error_probs) 309 | if self.log: 310 | print(f"Iteration {n_iter + 1}. 
Predicted {round(100*len(pred_ids)/batch_size, 1)}% of sentences.") 311 | 312 | final_batch, pred_ids, cnt = \ 313 | self.update_final_batch(final_batch, pred_ids, pred_batch, 314 | prev_preds_dict) 315 | total_updates += cnt 316 | 317 | if not pred_ids: 318 | break 319 | 320 | return final_batch, total_updates 321 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from random import seed 4 | 5 | import torch 6 | from allennlp.data.iterators import BucketIterator 7 | from allennlp.data.vocabulary import DEFAULT_OOV_TOKEN, DEFAULT_PADDING_TOKEN 8 | from allennlp.data.vocabulary import Vocabulary 9 | from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder 10 | 11 | from gector.bert_token_embedder import PretrainedBertEmbedder 12 | from gector.datareader import Seq2LabelsDatasetReader 13 | from gector.seq2labels_model import Seq2Labels 14 | from gector.trainer import Trainer 15 | from gector.wordpiece_indexer import PretrainedBertIndexer 16 | from utils.helpers import get_weights_name 17 | 18 | 19 | def fix_seed(): 20 | torch.manual_seed(1) 21 | torch.backends.cudnn.enabled = False 22 | torch.backends.cudnn.benchmark = False 23 | torch.backends.cudnn.deterministic = True 24 | seed(43) 25 | 26 | 27 | def get_token_indexers(model_name, max_pieces_per_token=5, lowercase_tokens=True, special_tokens_fix=0, is_test=False): 28 | bert_token_indexer = PretrainedBertIndexer( 29 | pretrained_model=model_name, 30 | max_pieces_per_token=max_pieces_per_token, 31 | do_lowercase=lowercase_tokens, 32 | use_starting_offsets=True, 33 | special_tokens_fix=special_tokens_fix, 34 | is_test=is_test 35 | ) 36 | return {'bert': bert_token_indexer} 37 | 38 | 39 | def get_token_embedders(model_name, tune_bert=False, special_tokens_fix=0): 40 | take_grads = True if tune_bert > 0 else False 41 | bert_token_emb = PretrainedBertEmbedder( 42 | pretrained_model=model_name, 43 | top_layer_only=True, requires_grad=take_grads, 44 | special_tokens_fix=special_tokens_fix) 45 | 46 | token_embedders = {'bert': bert_token_emb} 47 | embedder_to_indexer_map = {"bert": ["bert", "bert-offsets"]} 48 | 49 | text_filed_emd = BasicTextFieldEmbedder(token_embedders=token_embedders, 50 | embedder_to_indexer_map=embedder_to_indexer_map, 51 | allow_unmatched_keys=True) 52 | return text_filed_emd 53 | 54 | 55 | def get_data_reader(model_name, max_len, skip_correct=False, skip_complex=0, 56 | test_mode=False, tag_strategy="keep_one", 57 | broken_dot_strategy="keep", lowercase_tokens=True, 58 | max_pieces_per_token=3, tn_prob=0, tp_prob=1, special_tokens_fix=0,): 59 | token_indexers = get_token_indexers(model_name, 60 | max_pieces_per_token=max_pieces_per_token, 61 | lowercase_tokens=lowercase_tokens, 62 | special_tokens_fix=special_tokens_fix, 63 | is_test=test_mode) 64 | reader = Seq2LabelsDatasetReader(token_indexers=token_indexers, 65 | max_len=max_len, 66 | skip_correct=skip_correct, 67 | skip_complex=skip_complex, 68 | test_mode=test_mode, 69 | tag_strategy=tag_strategy, 70 | broken_dot_strategy=broken_dot_strategy, 71 | lazy=True, 72 | tn_prob=tn_prob, 73 | tp_prob=tp_prob) 74 | return reader 75 | 76 | 77 | def get_model(model_name, vocab, tune_bert=False, 78 | predictor_dropout=0, 79 | label_smoothing=0.0, 80 | confidence=0, 81 | special_tokens_fix=0): 82 | token_embs = get_token_embedders(model_name, tune_bert=tune_bert, special_tokens_fix=special_tokens_fix) 83 
| model = Seq2Labels(vocab=vocab, 84 | text_field_embedder=token_embs, 85 | predictor_dropout=predictor_dropout, 86 | label_smoothing=label_smoothing, 87 | confidence=confidence) 88 | return model 89 | 90 | 91 | def main(args): 92 | fix_seed() 93 | if not os.path.exists(args.model_dir): 94 | os.mkdir(args.model_dir) 95 | 96 | weights_name = get_weights_name(args.transformer_model, args.lowercase_tokens) 97 | # read datasets 98 | reader = get_data_reader(weights_name, args.max_len, skip_correct=bool(args.skip_correct), 99 | skip_complex=args.skip_complex, 100 | test_mode=False, 101 | tag_strategy=args.tag_strategy, 102 | lowercase_tokens=args.lowercase_tokens, 103 | max_pieces_per_token=args.pieces_per_token, 104 | tn_prob=args.tn_prob, 105 | tp_prob=args.tp_prob, 106 | special_tokens_fix=args.special_tokens_fix) 107 | train_data = reader.read(args.train_set) 108 | dev_data = reader.read(args.dev_set) 109 | 110 | default_tokens = [DEFAULT_OOV_TOKEN, DEFAULT_PADDING_TOKEN] 111 | namespaces = ['labels', 'd_tags'] 112 | tokens_to_add = {x: default_tokens for x in namespaces} 113 | # build vocab 114 | if args.vocab_path: 115 | vocab = Vocabulary.from_files(args.vocab_path) 116 | else: 117 | vocab = Vocabulary.from_instances(train_data, 118 | max_vocab_size={'tokens': 30000, 119 | 'labels': args.target_vocab_size, 120 | 'd_tags': 2}, 121 | tokens_to_add=tokens_to_add) 122 | vocab.save_to_files(os.path.join(args.model_dir, 'vocabulary')) 123 | 124 | print("Data is loaded") 125 | model = get_model(weights_name, vocab, 126 | tune_bert=args.tune_bert, 127 | predictor_dropout=args.predictor_dropout, 128 | label_smoothing=args.label_smoothing, 129 | special_tokens_fix=args.special_tokens_fix) 130 | 131 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 132 | if torch.cuda.is_available(): 133 | if torch.cuda.device_count() > 1: 134 | cuda_device = list(range(torch.cuda.device_count())) 135 | else: 136 | cuda_device = 0 137 | else: 138 | cuda_device = -1 139 | 140 | if args.pretrain: 141 | model.load_state_dict(torch.load(os.path.join(args.pretrain_folder, args.pretrain + '.th'))) 142 | 143 | model = model.to(device) 144 | 145 | print("Model is set") 146 | 147 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 148 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 149 | optimizer, factor=0.1, patience=10) 150 | instances_per_epoch = None if not args.updates_per_epoch else \ 151 | int(args.updates_per_epoch * args.batch_size * args.accumulation_size) 152 | iterator = BucketIterator(batch_size=args.batch_size, 153 | sorting_keys=[("tokens", "num_tokens")], 154 | biggest_batch_first=True, 155 | max_instances_in_memory=args.batch_size * 20000, 156 | instances_per_epoch=instances_per_epoch, 157 | ) 158 | iterator.index_with(vocab) 159 | trainer = Trainer(model=model, 160 | optimizer=optimizer, 161 | scheduler=scheduler, 162 | iterator=iterator, 163 | train_dataset=train_data, 164 | validation_dataset=dev_data, 165 | serialization_dir=args.model_dir, 166 | patience=args.patience, 167 | num_epochs=args.n_epoch, 168 | cuda_device=cuda_device, 169 | shuffle=False, 170 | accumulated_batch_count=args.accumulation_size, 171 | cold_step_count=args.cold_steps_count, 172 | cold_lr=args.cold_lr, 173 | cuda_verbose_step=int(args.cuda_verbose_steps) 174 | if args.cuda_verbose_steps else None 175 | ) 176 | print("Start training") 177 | trainer.train() 178 | 179 | # Here's how to save the model. 
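# (A hedged aside, not in the original source: a state dict saved this way can later be
# restored with `model.load_state_dict(torch.load(path))`, which is how the args.pretrain
# branch above and GecBERTModel in gector/gec_model.py consume such checkpoint files.)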
180 | out_model = os.path.join(args.model_dir, 'model.th') 181 | with open(out_model, 'wb') as f: 182 | torch.save(model.state_dict(), f) 183 | print("Model is dumped") 184 | 185 |
186 | if __name__ == '__main__': 187 | # read parameters 188 | parser = argparse.ArgumentParser() 189 | parser.add_argument('--train_set', 190 | help='Path to the train data', required=True) 191 | parser.add_argument('--dev_set', 192 | help='Path to the dev data', required=True) 193 | parser.add_argument('--model_dir', 194 | help='Path to the model dir', required=True)
195 | parser.add_argument('--vocab_path', 196 | help='Path to the model vocabulary directory.' 197 | ' If not set then the vocabulary is built from the data.', 198 | default='')
199 | parser.add_argument('--batch_size', 200 | type=int, 201 | help='The size of the batch.', 202 | default=32)
203 | parser.add_argument('--max_len', 204 | type=int, 205 | help='The max sentence length ' 206 | '(all longer will be truncated).', 207 | default=50)
208 | parser.add_argument('--target_vocab_size', 209 | type=int, 210 | help='The size of target vocabularies.', 211 | default=1000)
212 | parser.add_argument('--n_epoch', 213 | type=int, 214 | help='The number of epochs for training the model.', 215 | default=20)
216 | parser.add_argument('--patience', 217 | type=int, 218 | help='The number of epochs without improvement' 219 | ' on the validation set before early stopping.', 220 | default=3)
221 | parser.add_argument('--skip_correct', 222 | type=int, 223 | help='If set then correct sentences will be skipped ' 224 | 'by the data reader.', 225 | default=1)
226 | parser.add_argument('--skip_complex', 227 | type=int, 228 | help='If set then complex corrections will be skipped ' 229 | 'by the data reader.', 230 | choices=[0, 1, 2, 3, 4, 5], 231 | default=0)
232 | parser.add_argument('--tune_bert', 233 | type=int, 234 | help='If greater than 0 then fine-tune BERT.', 235 | default=1)
236 | parser.add_argument('--tag_strategy', 237 | choices=['keep_one', 'merge_all'], 238 | help='The type of the data reader behaviour.', 239 | default='keep_one')
240 | parser.add_argument('--accumulation_size', 241 | type=int, 242 | help='How many batches to accumulate before an optimizer step.', 243 | default=4)
244 | parser.add_argument('--lr', 245 | type=float, 246 | help='Set initial learning rate.', 247 | default=1e-5)
248 | parser.add_argument('--cold_steps_count', 249 | type=int, 250 | help='The number of initial steps during which only the classifier layers are trained.', 251 | default=4)
252 | parser.add_argument('--cold_lr', 253 | type=float, 254 | help='Learning rate during cold_steps.', 255 | default=1e-3)
256 | parser.add_argument('--predictor_dropout', 257 | type=float, 258 | help='The value of dropout for the predictor.', 259 | default=0.0)
260 | parser.add_argument('--lowercase_tokens', 261 | type=int, 262 | help='Whether to lowercase tokens.', 263 | default=0)
264 | parser.add_argument('--pieces_per_token', 265 | type=int, 266 | help='The max number of pieces per token.', 267 | default=5)
268 | parser.add_argument('--cuda_verbose_steps', 269 | help='Number of steps after which CUDA memory information is printed. ' 270 | 'Makes sense for local testing.
Usually about 1000.', 271 | default=None) 272 | parser.add_argument('--label_smoothing', 273 | type=float, 274 | help='The value of parameter alpha for label smoothing.', 275 | default=0.0) 276 | parser.add_argument('--tn_prob', 277 | type=float, 278 | help='The probability to take TN from data.', 279 | default=0) 280 | parser.add_argument('--tp_prob', 281 | type=float, 282 | help='The probability to take TP from data.', 283 | default=1) 284 | parser.add_argument('--updates_per_epoch', 285 | type=int, 286 | help='If set then each epoch will contain the exact amount of updates.', 287 | default=0) 288 | parser.add_argument('--pretrain_folder', 289 | help='The name of the pretrain folder.') 290 | parser.add_argument('--pretrain', 291 | help='The name of the pretrain weights in pretrain_folder param.', 292 | default='') 293 | parser.add_argument('--transformer_model', 294 | choices=['bert', 'distilbert', 'gpt2', 'roberta', 'transformerxl', 'xlnet', 'albert'], 295 | help='Name of the transformer model.', 296 | default='roberta') 297 | parser.add_argument('--special_tokens_fix', 298 | type=int, 299 | help='Whether to fix problem with [CLS], [SEP] tokens tokenization.', 300 | default=1) 301 | 302 | args = parser.parse_args() 303 | main(args) 304 | -------------------------------------------------------------------------------- /utils/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from difflib import SequenceMatcher 4 | 5 | import Levenshtein 6 | import numpy as np 7 | from tqdm import tqdm 8 | 9 | from helpers import write_lines, read_parallel_lines, encode_verb_form, \ 10 | apply_reverse_transformation, SEQ_DELIMETERS, START_TOKEN 11 | 12 | 13 | def perfect_align(t, T, insertions_allowed=0, 14 | cost_function=Levenshtein.distance): 15 | # dp[i, j, k] is a minimal cost of matching first `i` tokens of `t` with 16 | # first `j` tokens of `T`, after making `k` insertions after last match of 17 | # token from `t`. In other words t[:i] aligned with T[:j]. 18 | 19 | # Initialize with INFINITY (unknown) 20 | shape = (len(t) + 1, len(T) + 1, insertions_allowed + 1) 21 | dp = np.ones(shape, dtype=int) * int(1e9) 22 | come_from = np.ones(shape, dtype=int) * int(1e9) 23 | come_from_ins = np.ones(shape, dtype=int) * int(1e9) 24 | 25 | dp[0, 0, 0] = 0 # The only known starting point. Nothing matched to nothing. 26 | for i in range(len(t) + 1): # Go inclusive 27 | for j in range(len(T) + 1): # Go inclusive 28 | for q in range(insertions_allowed + 1): # Go inclusive 29 | if i < len(t): 30 | # Given matched sequence of t[:i] and T[:j], match token 31 | # t[i] with following tokens T[j:k]. 32 | for k in range(j, len(T) + 1): 33 | transform = \ 34 | apply_transformation(t[i], ' '.join(T[j:k])) 35 | if transform: 36 | cost = 0 37 | else: 38 | cost = cost_function(t[i], ' '.join(T[j:k])) 39 | current = dp[i, j, q] + cost 40 | if dp[i + 1, k, 0] > current: 41 | dp[i + 1, k, 0] = current 42 | come_from[i + 1, k, 0] = j 43 | come_from_ins[i + 1, k, 0] = q 44 | if q < insertions_allowed: 45 | # Given matched sequence of t[:i] and T[:j], create 46 | # insertion with following tokens T[j:k]. 47 | for k in range(j, len(T) + 1): 48 | cost = len(' '.join(T[j:k])) 49 | current = dp[i, j, q] + cost 50 | if dp[i, k, q + 1] > current: 51 | dp[i, k, q + 1] = current 52 | come_from[i, k, q + 1] = j 53 | come_from_ins[i, k, q + 1] = q 54 | 55 | # Solution is in the dp[len(t), len(T), *]. Backtracking from there. 
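# Hedged illustration of the backtracking result below (with insertions_allowed=0):
# aligning t = ["a", "cat"] with T = ["the", "cat"] yields entries shaped like
#   [['REPLACE_a', ['the'], (0, 1)], ['REPLACE_cat', ['cat'], (1, 2)]]
# i.e. each source token is paired with the target span it was matched to; turning these
# into $KEEP / $REPLACE_ / transform tags happens later in convert_alignments_into_edits.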
56 | alignment = [] 57 | i = len(t) 58 | j = len(T) 59 | q = dp[i, j, :].argmin() 60 | while i > 0 or q > 0: 61 | is_insert = (come_from_ins[i, j, q] != q) and (q != 0) 62 | j, k, q = come_from[i, j, q], j, come_from_ins[i, j, q] 63 | if not is_insert: 64 | i -= 1 65 | 66 | if is_insert: 67 | alignment.append(['INSERT', T[j:k], (i, i)]) 68 | else: 69 | alignment.append([f'REPLACE_{t[i]}', T[j:k], (i, i + 1)]) 70 | 71 | assert j == 0 72 | 73 | return dp[len(t), len(T)].min(), list(reversed(alignment)) 74 | 75 | 76 | def _split(token): 77 | if not token: 78 | return [] 79 | parts = token.split() 80 | return parts or [token] 81 | 82 | 83 | def apply_merge_transformation(source_tokens, target_words, shift_idx): 84 | edits = [] 85 | if len(source_tokens) > 1 and len(target_words) == 1: 86 | # check merge 87 | transform = check_merge(source_tokens, target_words) 88 | if transform: 89 | for i in range(len(source_tokens) - 1): 90 | edits.append([(shift_idx + i, shift_idx + i + 1), transform]) 91 | return edits 92 | 93 | if len(source_tokens) == len(target_words) == 2: 94 | # check swap 95 | transform = check_swap(source_tokens, target_words) 96 | if transform: 97 | edits.append([(shift_idx, shift_idx + 1), transform]) 98 | return edits 99 | 100 | 101 | def is_sent_ok(sent, delimeters=SEQ_DELIMETERS): 102 | for del_val in delimeters.values(): 103 | if del_val in sent and del_val != " ": 104 | return False 105 | return True 106 | 107 | 108 | def check_casetype(source_token, target_token): 109 | if source_token.lower() != target_token.lower(): 110 | return None 111 | if source_token.lower() == target_token: 112 | return "$TRANSFORM_CASE_LOWER" 113 | elif source_token.capitalize() == target_token: 114 | return "$TRANSFORM_CASE_CAPITAL" 115 | elif source_token.upper() == target_token: 116 | return "$TRANSFORM_CASE_UPPER" 117 | elif source_token[1:].capitalize() == target_token[1:] and source_token[0] == target_token[0]: 118 | return "$TRANSFORM_CASE_CAPITAL_1" 119 | elif source_token[:-1].upper() == target_token[:-1] and source_token[-1] == target_token[-1]: 120 | return "$TRANSFORM_CASE_UPPER_-1" 121 | else: 122 | return None 123 | 124 | 125 | def check_equal(source_token, target_token): 126 | if source_token == target_token: 127 | return "$KEEP" 128 | else: 129 | return None 130 | 131 | 132 | def check_split(source_token, target_tokens): 133 | if source_token.split("-") == target_tokens: 134 | return "$TRANSFORM_SPLIT_HYPHEN" 135 | else: 136 | return None 137 | 138 | 139 | def check_merge(source_tokens, target_tokens): 140 | if "".join(source_tokens) == "".join(target_tokens): 141 | return "$MERGE_SPACE" 142 | elif "-".join(source_tokens) == "-".join(target_tokens): 143 | return "$MERGE_HYPHEN" 144 | else: 145 | return None 146 | 147 | 148 | def check_swap(source_tokens, target_tokens): 149 | if source_tokens == [x for x in reversed(target_tokens)]: 150 | return "$MERGE_SWAP" 151 | else: 152 | return None 153 | 154 | 155 | def check_plural(source_token, target_token): 156 | if source_token.endswith("s") and source_token[:-1] == target_token: 157 | return "$TRANSFORM_AGREEMENT_SINGULAR" 158 | elif target_token.endswith("s") and source_token == target_token[:-1]: 159 | return "$TRANSFORM_AGREEMENT_PLURAL" 160 | else: 161 | return None 162 | 163 | 164 | def check_verb(source_token, target_token): 165 | encoding = encode_verb_form(source_token, target_token) 166 | if encoding: 167 | return f"$TRANSFORM_VERB_{encoding}" 168 | else: 169 | return None 170 | 171 | 172 | def apply_transformation(source_token, 
target_token): 173 | target_tokens = target_token.split() 174 | if len(target_tokens) > 1: 175 | # check split 176 | transform = check_split(source_token, target_tokens) 177 | if transform: 178 | return transform 179 | checks = [check_equal, check_casetype, check_verb, check_plural] 180 | for check in checks: 181 | transform = check(source_token, target_token) 182 | if transform: 183 | return transform 184 | return None 185 | 186 | 187 | def align_sequences(source_sent, target_sent): 188 | # check if sent is OK 189 | if not is_sent_ok(source_sent) or not is_sent_ok(target_sent): 190 | return None 191 | source_tokens = source_sent.split() 192 | target_tokens = target_sent.split() 193 | matcher = SequenceMatcher(None, source_tokens, target_tokens) 194 | diffs = list(matcher.get_opcodes()) 195 | all_edits = [] 196 | for diff in diffs: 197 | tag, i1, i2, j1, j2 = diff 198 | source_part = _split(" ".join(source_tokens[i1:i2])) 199 | target_part = _split(" ".join(target_tokens[j1:j2])) 200 | if tag == 'equal': 201 | continue 202 | elif tag == 'delete': 203 | # delete all words separatly 204 | for j in range(i2 - i1): 205 | edit = [(i1 + j, i1 + j + 1), '$DELETE'] 206 | all_edits.append(edit) 207 | elif tag == 'insert': 208 | # append to the previous word 209 | for target_token in target_part: 210 | edit = ((i1 - 1, i1), f"$APPEND_{target_token}") 211 | all_edits.append(edit) 212 | else: 213 | # check merge first of all 214 | edits = apply_merge_transformation(source_part, target_part, 215 | shift_idx=i1) 216 | if edits: 217 | all_edits.extend(edits) 218 | continue 219 | 220 | # normalize alignments if need (make them singleton) 221 | _, alignments = perfect_align(source_part, target_part, 222 | insertions_allowed=0) 223 | for alignment in alignments: 224 | new_shift = alignment[2][0] 225 | edits = convert_alignments_into_edits(alignment, 226 | shift_idx=i1 + new_shift) 227 | all_edits.extend(edits) 228 | 229 | # get labels 230 | labels = convert_edits_into_labels(source_tokens, all_edits) 231 | # match tags to source tokens 232 | sent_with_tags = add_labels_to_the_tokens(source_tokens, labels) 233 | return sent_with_tags 234 | 235 | 236 | def convert_edits_into_labels(source_tokens, all_edits): 237 | # make sure that edits are flat 238 | flat_edits = [] 239 | for edit in all_edits: 240 | (start, end), edit_operations = edit 241 | if isinstance(edit_operations, list): 242 | for operation in edit_operations: 243 | new_edit = [(start, end), operation] 244 | flat_edits.append(new_edit) 245 | elif isinstance(edit_operations, str): 246 | flat_edits.append(edit) 247 | else: 248 | raise Exception("Unknown operation type") 249 | all_edits = flat_edits[:] 250 | labels = [] 251 | total_labels = len(source_tokens) + 1 252 | if not all_edits: 253 | labels = [["$KEEP"] for x in range(total_labels)] 254 | else: 255 | for i in range(total_labels): 256 | edit_operations = [x[1] for x in all_edits if x[0][0] == i - 1 257 | and x[0][1] == i] 258 | if not edit_operations: 259 | labels.append(["$KEEP"]) 260 | else: 261 | labels.append(edit_operations) 262 | return labels 263 | 264 | 265 | def convert_alignments_into_edits(alignment, shift_idx): 266 | edits = [] 267 | action, target_tokens, new_idx = alignment 268 | source_token = action.replace("REPLACE_", "") 269 | 270 | # check if delete 271 | if not target_tokens: 272 | edit = [(shift_idx, 1 + shift_idx), "$DELETE"] 273 | return [edit] 274 | 275 | # check splits 276 | for i in range(1, len(target_tokens)): 277 | target_token = " ".join(target_tokens[:i + 1]) 278 
| transform = apply_transformation(source_token, target_token) 279 | if transform: 280 | edit = [(shift_idx, shift_idx + 1), transform] 281 | edits.append(edit) 282 | target_tokens = target_tokens[i + 1:] 283 | for target in target_tokens: 284 | edits.append([(shift_idx, shift_idx + 1), f"$APPEND_{target}"]) 285 | return edits 286 | 287 | transform_costs = [] 288 | transforms = [] 289 | for target_token in target_tokens: 290 | transform = apply_transformation(source_token, target_token) 291 | if transform: 292 | cost = 0 293 | transforms.append(transform) 294 | else: 295 | cost = Levenshtein.distance(source_token, target_token) 296 | transforms.append(None) 297 | transform_costs.append(cost) 298 | min_cost_idx = transform_costs.index(min(transform_costs)) 299 | # append to the previous word 300 | for i in range(0, min_cost_idx): 301 | target = target_tokens[i] 302 | edit = [(shift_idx - 1, shift_idx), f"$APPEND_{target}"] 303 | edits.append(edit) 304 | # replace/transform target word 305 | transform = transforms[min_cost_idx] 306 | target = transform if transform is not None \ 307 | else f"$REPLACE_{target_tokens[min_cost_idx]}" 308 | edit = [(shift_idx, 1 + shift_idx), target] 309 | edits.append(edit) 310 | # append to this word 311 | for i in range(min_cost_idx + 1, len(target_tokens)): 312 | target = target_tokens[i] 313 | edit = [(shift_idx, 1 + shift_idx), f"$APPEND_{target}"] 314 | edits.append(edit) 315 | return edits 316 | 317 | 318 | def add_labels_to_the_tokens(source_tokens, labels, delimeters=SEQ_DELIMETERS): 319 | tokens_with_all_tags = [] 320 | source_tokens_with_start = [START_TOKEN] + source_tokens 321 | for token, label_list in zip(source_tokens_with_start, labels): 322 | all_tags = delimeters['operations'].join(label_list) 323 | comb_record = token + delimeters['labels'] + all_tags 324 | tokens_with_all_tags.append(comb_record) 325 | return delimeters['tokens'].join(tokens_with_all_tags) 326 | 327 | 328 | def convert_data_from_raw_files(source_file, target_file, output_file, chunk_size): 329 | tagged = [] 330 | source_data, target_data = read_parallel_lines(source_file, target_file) 331 | print(f"The size of raw dataset is {len(source_data)}") 332 | cnt_total, cnt_all, cnt_tp = 0, 0, 0 333 | for source_sent, target_sent in tqdm(zip(source_data, target_data)): 334 | try: 335 | aligned_sent = align_sequences(source_sent, target_sent) 336 | except Exception: 337 | aligned_sent = align_sequences(source_sent, target_sent) 338 | if source_sent != target_sent: 339 | cnt_tp += 1 340 | alignments = [aligned_sent] 341 | cnt_all += len(alignments) 342 | try: 343 | check_sent = convert_tagged_line(aligned_sent) 344 | except Exception: 345 | # debug mode 346 | aligned_sent = align_sequences(source_sent, target_sent) 347 | check_sent = convert_tagged_line(aligned_sent) 348 | 349 | if "".join(check_sent.split()) != "".join( 350 | target_sent.split()): 351 | # do it again for debugging 352 | aligned_sent = align_sequences(source_sent, target_sent) 353 | check_sent = convert_tagged_line(aligned_sent) 354 | print(f"Incorrect pair: \n{target_sent}\n{check_sent}") 355 | continue 356 | if alignments: 357 | cnt_total += len(alignments) 358 | tagged.extend(alignments) 359 | if len(tagged) > chunk_size: 360 | write_lines(output_file, tagged, mode='a') 361 | tagged = [] 362 | 363 | print(f"Overall extracted {cnt_total}. " 364 | f"Original TP {cnt_tp}." 
365 | f" Original TN {cnt_all - cnt_tp}") 366 | if tagged: 367 | write_lines(output_file, tagged, 'a') 368 | 369 | 370 | def convert_labels_into_edits(labels): 371 | all_edits = [] 372 | for i, label_list in enumerate(labels): 373 | if label_list == ["$KEEP"]: 374 | continue 375 | else: 376 | edit = [(i - 1, i), label_list] 377 | all_edits.append(edit) 378 | return all_edits 379 | 380 | 381 | def get_target_sent_by_levels(source_tokens, labels): 382 | relevant_edits = convert_labels_into_edits(labels) 383 | target_tokens = source_tokens[:] 384 | leveled_target_tokens = {} 385 | if not relevant_edits: 386 | target_sentence = " ".join(target_tokens) 387 | return leveled_target_tokens, target_sentence 388 | max_level = max([len(x[1]) for x in relevant_edits]) 389 | for level in range(max_level): 390 | rest_edits = [] 391 | shift_idx = 0 392 | for edits in relevant_edits: 393 | (start, end), label_list = edits 394 | label = label_list[0] 395 | target_pos = start + shift_idx 396 | source_token = target_tokens[target_pos] if target_pos >= 0 else START_TOKEN 397 | if label == "$DELETE": 398 | del target_tokens[target_pos] 399 | shift_idx -= 1 400 | elif label.startswith("$APPEND_"): 401 | word = label.replace("$APPEND_", "") 402 | target_tokens[target_pos + 1: target_pos + 1] = [word] 403 | shift_idx += 1 404 | elif label.startswith("$REPLACE_"): 405 | word = label.replace("$REPLACE_", "") 406 | target_tokens[target_pos] = word 407 | elif label.startswith("$TRANSFORM"): 408 | word = apply_reverse_transformation(source_token, label) 409 | if word is None: 410 | word = source_token 411 | target_tokens[target_pos] = word 412 | elif label.startswith("$MERGE_"): 413 | # apply merge only on last stage 414 | if level == (max_level - 1): 415 | target_tokens[target_pos + 1: target_pos + 1] = [label] 416 | shift_idx += 1 417 | else: 418 | rest_edit = [(start + shift_idx, end + shift_idx), [label]] 419 | rest_edits.append(rest_edit) 420 | rest_labels = label_list[1:] 421 | if rest_labels: 422 | rest_edit = [(start + shift_idx, end + shift_idx), rest_labels] 423 | rest_edits.append(rest_edit) 424 | 425 | leveled_tokens = target_tokens[:] 426 | # update next step 427 | relevant_edits = rest_edits[:] 428 | if level == (max_level - 1): 429 | leveled_tokens = replace_merge_transforms(leveled_tokens) 430 | leveled_labels = convert_edits_into_labels(leveled_tokens, 431 | relevant_edits) 432 | leveled_target_tokens[level + 1] = {"tokens": leveled_tokens, 433 | "labels": leveled_labels} 434 | 435 | target_sentence = " ".join(leveled_target_tokens[max_level]["tokens"]) 436 | return leveled_target_tokens, target_sentence 437 | 438 | 439 | def replace_merge_transforms(tokens): 440 | if all(not x.startswith("$MERGE_") for x in tokens): 441 | return tokens 442 | target_tokens = tokens[:] 443 | allowed_range = (1, len(tokens) - 1) 444 | for i in range(len(tokens)): 445 | target_token = tokens[i] 446 | if target_token.startswith("$MERGE"): 447 | if target_token.startswith("$MERGE_SWAP") and i in allowed_range: 448 | target_tokens[i - 1] = tokens[i + 1] 449 | target_tokens[i + 1] = tokens[i - 1] 450 | target_tokens[i: i + 1] = [] 451 | target_line = " ".join(target_tokens) 452 | target_line = target_line.replace(" $MERGE_HYPHEN ", "-") 453 | target_line = target_line.replace(" $MERGE_SPACE ", "") 454 | return target_line.split() 455 | 456 | 457 | def convert_tagged_line(line, delimeters=SEQ_DELIMETERS): 458 | label_del = delimeters['labels'] 459 | source_tokens = [x.split(label_del)[0] 460 | for x in 
line.split(delimeters['tokens'])][1:] 461 | labels = [x.split(label_del)[1].split(delimeters['operations']) 462 | for x in line.split(delimeters['tokens'])] 463 | assert len(source_tokens) + 1 == len(labels) 464 | levels_dict, target_line = get_target_sent_by_levels(source_tokens, labels) 465 | return target_line 466 | 467 | 468 | def main(args): 469 | convert_data_from_raw_files(args.source, args.target, args.output_file, args.chunk_size) 470 | 471 | 472 | if __name__ == '__main__': 473 | parser = argparse.ArgumentParser() 474 | parser.add_argument('-s', '--source', 475 | help='Path to the source file', 476 | required=True) 477 | parser.add_argument('-t', '--target', 478 | help='Path to the target file', 479 | required=True) 480 | parser.add_argument('-o', '--output_file', 481 | help='Path to the output file', 482 | required=True) 483 | parser.add_argument('--chunk_size', 484 | type=int, 485 | help='Dump each chunk size.', 486 | default=1000000) 487 | args = parser.parse_args() 488 | main(args) 489 | -------------------------------------------------------------------------------- /gector/wordpiece_indexer.py: -------------------------------------------------------------------------------- 1 | """Tweaked version of corresponding AllenNLP file""" 2 | import logging 3 | from collections import defaultdict 4 | from typing import Dict, List, Callable 5 | 6 | from allennlp.common.util import pad_sequence_to_length 7 | from allennlp.data.token_indexers.token_indexer import TokenIndexer 8 | from allennlp.data.tokenizers.token import Token 9 | from allennlp.data.vocabulary import Vocabulary 10 | from overrides import overrides 11 | from transformers import AutoTokenizer 12 | 13 | from utils.helpers import START_TOKEN 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | # TODO(joelgrus): Figure out how to generate token_type_ids out of this token indexer. 18 | 19 | # This is the default list of tokens that should not be lowercased. 20 | _NEVER_LOWERCASE = ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'] 21 | 22 | 23 | class WordpieceIndexer(TokenIndexer[int]): 24 | """ 25 | A token indexer that does the wordpiece-tokenization (e.g. for BERT embeddings). 26 | If you are using one of the pretrained BERT models, you'll want to use the ``PretrainedBertIndexer`` 27 | subclass rather than this base class. 28 | 29 | Parameters 30 | ---------- 31 | vocab : ``Dict[str, int]`` 32 | The mapping {wordpiece -> id}. Note this is not an AllenNLP ``Vocabulary``. 33 | wordpiece_tokenizer : ``Callable[[str], List[str]]`` 34 | A function that does the actual tokenization. 35 | namespace : str, optional (default: "wordpiece") 36 | The namespace in the AllenNLP ``Vocabulary`` into which the wordpieces 37 | will be loaded. 38 | use_starting_offsets : bool, optional (default: False) 39 | By default, the "offsets" created by the token indexer correspond to the 40 | last wordpiece in each word. If ``use_starting_offsets`` is specified, 41 | they will instead correspond to the first wordpiece in each word. 42 | max_pieces : int, optional (default: 512) 43 | The BERT embedder uses positional embeddings and so has a corresponding 44 | maximum length for its input ids. Any inputs longer than this will 45 | either be truncated (default), or be split apart and batched using a 46 | sliding window. 47 | do_lowercase : ``bool``, optional (default=``False``) 48 | Should we lowercase the provided tokens before getting the indices? 
49 | You would need to do this if you are using an -uncased BERT model 50 | but your DatasetReader is not lowercasing tokens (which might be the 51 | case if you're also using other embeddings based on cased tokens). 52 | never_lowercase: ``List[str]``, optional 53 | Tokens that should never be lowercased. Default is 54 | ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']. 55 | start_tokens : ``List[str]``, optional (default=``None``) 56 | These are prepended to the tokens provided to ``tokens_to_indices``. 57 | end_tokens : ``List[str]``, optional (default=``None``) 58 | These are appended to the tokens provided to ``tokens_to_indices``. 59 | separator_token : ``str``, optional (default=``[SEP]``) 60 | This token indicates the segments in the sequence. 61 | truncate_long_sequences : ``bool``, optional (default=``True``) 62 | By default, long sequences will be truncated to the maximum sequence 63 | length. Otherwise, they will be split apart and batched using a 64 | sliding window. 65 | token_min_padding_length : ``int``, optional (default=``0``) 66 | See :class:`TokenIndexer`. 67 | """ 68 | 69 | def __init__(self, 70 | vocab: Dict[str, int], 71 | bpe_ranks: Dict, 72 | byte_encoder: Dict, 73 | wordpiece_tokenizer: Callable[[str], List[str]], 74 | namespace: str = "wordpiece", 75 | use_starting_offsets: bool = False, 76 | max_pieces: int = 512, 77 | max_pieces_per_token: int = 3, 78 | is_test=False, 79 | do_lowercase: bool = False, 80 | never_lowercase: List[str] = None, 81 | start_tokens: List[str] = None, 82 | end_tokens: List[str] = None, 83 | truncate_long_sequences: bool = True, 84 | token_min_padding_length: int = 0) -> None: 85 | super().__init__(token_min_padding_length) 86 | self.vocab = vocab 87 | 88 | # The BERT code itself does a two-step tokenization: 89 | # sentence -> [words], and then word -> [wordpieces] 90 | # In AllenNLP, the first step is implemented as the ``BertBasicWordSplitter``, 91 | # and this token indexer handles the second. 92 | self.wordpiece_tokenizer = wordpiece_tokenizer 93 | self.max_pieces_per_token = max_pieces_per_token 94 | self._namespace = namespace 95 | self._added_to_vocabulary = False 96 | self.max_pieces = max_pieces 97 | self.use_starting_offsets = use_starting_offsets 98 | self._do_lowercase = do_lowercase 99 | self._truncate_long_sequences = truncate_long_sequences 100 | self.max_pieces_per_sentence = 80 101 | self.is_test = is_test 102 | self.cache = {} 103 | self.bpe_ranks = bpe_ranks 104 | self.byte_encoder = byte_encoder 105 | 106 | if self.is_test: 107 | self.max_pieces_per_token = None 108 | 109 | if never_lowercase is None: 110 | # Use the defaults 111 | self._never_lowercase = set(_NEVER_LOWERCASE) 112 | else: 113 | self._never_lowercase = set(never_lowercase) 114 | 115 | # Convert the start_tokens and end_tokens to wordpiece_ids 116 | self._start_piece_ids = [vocab[wordpiece] 117 | for token in (start_tokens or []) 118 | for wordpiece in wordpiece_tokenizer(token)] 119 | self._end_piece_ids = [vocab[wordpiece] 120 | for token in (end_tokens or []) 121 | for wordpiece in wordpiece_tokenizer(token)] 122 | 123 | @overrides 124 | def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]): 125 | # If we only use pretrained models, we don't need to do anything here. 
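# (The wordpiece vocabulary is instead injected wholesale into the AllenNLP Vocabulary
# by _add_encoding_to_vocabulary below, the first time tokens_to_indices is called.)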
126 | pass 127 | 128 | def _add_encoding_to_vocabulary(self, vocabulary: Vocabulary) -> None: 129 | # pylint: disable=protected-access 130 | for word, idx in self.vocab.items(): 131 | vocabulary._token_to_index[self._namespace][word] = idx 132 | vocabulary._index_to_token[self._namespace][idx] = word 133 | 134 | def get_pairs(self, word): 135 | """Return set of symbol pairs in a word. 136 | 137 | Word is represented as tuple of symbols (symbols being variable-length strings). 138 | """ 139 | pairs = set() 140 | prev_char = word[0] 141 | for char in word[1:]: 142 | pairs.add((prev_char, char)) 143 | prev_char = char 144 | return pairs 145 | 146 | def bpe(self, token): 147 | if token in self.cache: 148 | return self.cache[token] 149 | word = tuple(token) 150 | pairs = self.get_pairs(word) 151 | 152 | if not pairs: 153 | return token 154 | 155 | while True: 156 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, 157 | float( 158 | 'inf'))) 159 | if bigram not in self.bpe_ranks: 160 | break 161 | first, second = bigram 162 | new_word = [] 163 | i = 0 164 | while i < len(word): 165 | try: 166 | j = word.index(first, i) 167 | new_word.extend(word[i:j]) 168 | i = j 169 | except: 170 | new_word.extend(word[i:]) 171 | break 172 | 173 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 174 | new_word.append(first + second) 175 | i += 2 176 | else: 177 | new_word.append(word[i]) 178 | i += 1 179 | new_word = tuple(new_word) 180 | word = new_word 181 | if len(word) == 1: 182 | break 183 | else: 184 | pairs = self.get_pairs(word) 185 | word = ' '.join(word) 186 | self.cache[token] = word 187 | return word 188 | 189 | def bpe_tokenize(self, text): 190 | """ Tokenize a string.""" 191 | bpe_tokens = [] 192 | for token in text.split(): 193 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 194 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) 195 | return bpe_tokens 196 | 197 | @overrides 198 | def tokens_to_indices(self, 199 | tokens: List[Token], 200 | vocabulary: Vocabulary, 201 | index_name: str) -> Dict[str, List[int]]: 202 | if not self._added_to_vocabulary: 203 | self._add_encoding_to_vocabulary(vocabulary) 204 | self._added_to_vocabulary = True 205 | 206 | # This lowercases tokens if necessary 207 | text = (token.text.lower() 208 | if self._do_lowercase and token.text not in self._never_lowercase 209 | else token.text 210 | for token in tokens) 211 | 212 | # Obtain a nested sequence of wordpieces, each represented by a list of wordpiece ids 213 | token_wordpiece_ids = [] 214 | for token in text: 215 | if self.bpe_ranks != {}: 216 | wps = self.bpe_tokenize(token) 217 | else: 218 | wps = self.wordpiece_tokenizer(token) 219 | limited_wps = [self.vocab[wordpiece] for wordpiece in wps][:self.max_pieces_per_token] 220 | token_wordpiece_ids.append(limited_wps) 221 | 222 | # Flattened list of wordpieces. In the end, the output of the model (e.g., BERT) should 223 | # have a sequence length equal to the length of this list. However, it will first be split into 224 | # chunks of length `self.max_pieces` so that they can be fit through the model. After packing 225 | # and passing through the model, it should be unpacked to represent the wordpieces in this list. 
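# Hedged illustration of the flattening below: if token_wordpiece_ids were
# [[10], [11, 12], [13]] (the middle token split into two wordpieces), then
# flat_wordpiece_ids becomes [10, 11, 12, 13], and with use_starting_offsets=True the
# offsets computed further down come out as [s, s + 1, s + 3],
# where s = len(self._start_piece_ids).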
226 | flat_wordpiece_ids = [wordpiece for token in token_wordpiece_ids for wordpiece in token] 227 | 228 | # reduce max_pieces_per_token if piece length of sentence is bigger than max_pieces_per_sentence 229 | # helps to fix CUDA out of memory errors meanwhile increasing batch size 230 | while not self.is_test and len(flat_wordpiece_ids) > \ 231 | self.max_pieces_per_sentence - len(self._start_piece_ids) - len(self._end_piece_ids): 232 | max_pieces = max([len(row) for row in token_wordpiece_ids]) 233 | token_wordpiece_ids = [row[:max_pieces - 1] for row in token_wordpiece_ids] 234 | flat_wordpiece_ids = [wordpiece for token in token_wordpiece_ids for wordpiece in token] 235 | 236 | # The code below will (possibly) pack the wordpiece sequence into multiple sub-sequences by using a sliding 237 | # window `window_length` that overlaps with previous windows according to the `stride`. Suppose we have 238 | # the following sentence: "I went to the store to buy some milk". Then a sliding window of length 4 and 239 | # stride of length 2 will split them up into: 240 | 241 | # "[I went to the] [to the store to] [store to buy some] [buy some milk [PAD]]". 242 | 243 | # This is to ensure that the model has context of as much of the sentence as possible to get accurate 244 | # embeddings. Finally, the sequences will be padded with any start/end piece ids, e.g., 245 | 246 | # "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ...". 247 | 248 | # The embedder should then be able to split this token sequence by the window length, 249 | # pass them through the model, and recombine them. 250 | 251 | # Specify the stride to be half of `self.max_pieces`, minus any additional start/end wordpieces 252 | window_length = self.max_pieces - len(self._start_piece_ids) - len(self._end_piece_ids) 253 | stride = window_length // 2 254 | 255 | # offsets[i] will give us the index into wordpiece_ids 256 | # for the wordpiece "corresponding to" the i-th input token. 257 | offsets = [] 258 | 259 | # If we're using initial offsets, we want to start at offset = len(text_tokens) 260 | # so that the first offset is the index of the first wordpiece of tokens[0]. 261 | # Otherwise, we want to start at len(text_tokens) - 1, so that the "previous" 262 | # offset is the last wordpiece of "tokens[-1]". 263 | offset = len(self._start_piece_ids) if self.use_starting_offsets else len(self._start_piece_ids) - 1 264 | 265 | for token in token_wordpiece_ids: 266 | # Truncate the sequence if specified, which depends on where the offsets are 267 | next_offset = 1 if self.use_starting_offsets else 0 268 | if self._truncate_long_sequences and offset >= window_length + next_offset: 269 | break 270 | 271 | # For initial offsets, the current value of ``offset`` is the start of 272 | # the current wordpiece, so add it to ``offsets`` and then increment it. 273 | if self.use_starting_offsets: 274 | offsets.append(offset) 275 | offset += len(token) 276 | # For final offsets, the current value of ``offset`` is the end of 277 | # the previous wordpiece, so increment it and then add it to ``offsets``. 278 | else: 279 | offset += len(token) 280 | offsets.append(offset) 281 | 282 | if len(flat_wordpiece_ids) <= window_length: 283 | # If all the wordpieces fit, then we don't need to do anything special 284 | wordpiece_windows = [self._add_start_and_end(flat_wordpiece_ids)] 285 | elif self._truncate_long_sequences: 286 | logger.warning("Too many wordpieces, truncating sequence. 
If you would like a sliding window, set" 287 | "`truncate_long_sequences` to False %s", str([token.text for token in tokens])) 288 | wordpiece_windows = [self._add_start_and_end(flat_wordpiece_ids[:window_length])] 289 | else: 290 | # Create a sliding window of wordpieces of length `max_pieces` that advances by `stride` steps and 291 | # add start/end wordpieces to each window 292 | # TODO: this currently does not respect word boundaries, so words may be cut in half between windows 293 | # However, this would increase complexity, as sequences would need to be padded/unpadded in the middle 294 | wordpiece_windows = [self._add_start_and_end(flat_wordpiece_ids[i:i + window_length]) 295 | for i in range(0, len(flat_wordpiece_ids), stride)] 296 | 297 | # Check for overlap in the last window. Throw it away if it is redundant. 298 | last_window = wordpiece_windows[-1][1:] 299 | penultimate_window = wordpiece_windows[-2] 300 | if last_window == penultimate_window[-len(last_window):]: 301 | wordpiece_windows = wordpiece_windows[:-1] 302 | 303 | # Flatten the wordpiece windows 304 | wordpiece_ids = [wordpiece for sequence in wordpiece_windows for wordpiece in sequence] 305 | 306 | # Our mask should correspond to the original tokens, 307 | # because calling util.get_text_field_mask on the 308 | # "wordpiece_id" tokens will produce the wrong shape. 309 | # However, because of the max_pieces constraint, we may 310 | # have truncated the wordpieces; accordingly, we want the mask 311 | # to correspond to the remaining tokens after truncation, which 312 | # is captured by the offsets. 313 | mask = [1 for _ in offsets] 314 | 315 | return {index_name: wordpiece_ids, 316 | f"{index_name}-offsets": offsets, 317 | "mask": mask} 318 | 319 | def _add_start_and_end(self, wordpiece_ids: List[int]) -> List[int]: 320 | return self._start_piece_ids + wordpiece_ids + self._end_piece_ids 321 | 322 | def _extend(self, token_type_ids: List[int]) -> List[int]: 323 | """ 324 | Extend the token type ids by len(start_piece_ids) on the left 325 | and len(end_piece_ids) on the right. 326 | """ 327 | first = token_type_ids[0] 328 | last = token_type_ids[-1] 329 | return ([first for _ in self._start_piece_ids] + 330 | token_type_ids + 331 | [last for _ in self._end_piece_ids]) 332 | 333 | @overrides 334 | def get_padding_token(self) -> int: 335 | return 0 336 | 337 | @overrides 338 | def get_padding_lengths(self, token: int) -> Dict[str, int]: # pylint: disable=unused-argument 339 | return {} 340 | 341 | @overrides 342 | def pad_token_sequence(self, 343 | tokens: Dict[str, List[int]], 344 | desired_num_tokens: Dict[str, int], 345 | padding_lengths: Dict[str, int]) -> Dict[str, List[int]]: # pylint: disable=unused-argument 346 | return {key: pad_sequence_to_length(val, desired_num_tokens[key]) 347 | for key, val in tokens.items()} 348 | 349 | @overrides 350 | def get_keys(self, index_name: str) -> List[str]: 351 | """ 352 | We need to override this because the indexer generates multiple keys. 353 | """ 354 | # pylint: disable=no-self-use 355 | return [index_name, f"{index_name}-offsets", f"{index_name}-type-ids", "mask"] 356 | 357 | 358 | class PretrainedBertIndexer(WordpieceIndexer): 359 | # pylint: disable=line-too-long 360 | """ 361 | A ``TokenIndexer`` corresponding to a pretrained BERT model. 362 | 363 | Parameters 364 | ---------- 365 | pretrained_model: ``str`` 366 | Either the name of the pretrained model to use (e.g. 'bert-base-uncased'), 367 | or the path to the .txt file with its vocabulary. 
368 | 369 | If the name is a key in the list of pretrained models at 370 | https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py#L33 371 | the corresponding path will be used; otherwise it will be interpreted as a path or URL. 372 | use_starting_offsets: bool, optional (default: False) 373 | By default, the "offsets" created by the token indexer correspond to the 374 | last wordpiece in each word. If ``use_starting_offsets`` is specified, 375 | they will instead correspond to the first wordpiece in each word. 376 | do_lowercase: ``bool``, optional (default = True) 377 | Whether to lowercase the tokens before converting to wordpiece ids. 378 | never_lowercase: ``List[str]``, optional 379 | Tokens that should never be lowercased. Default is 380 | ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']. 381 | max_pieces: int, optional (default: 512) 382 | The BERT embedder uses positional embeddings and so has a corresponding 383 | maximum length for its input ids. Any inputs longer than this will 384 | either be truncated (default), or be split apart and batched using a 385 | sliding window. 386 | truncate_long_sequences : ``bool``, optional (default=``True``) 387 | By default, long sequences will be truncated to the maximum sequence 388 | length. Otherwise, they will be split apart and batched using a 389 | sliding window. 390 | """ 391 | 392 | def __init__(self, 393 | pretrained_model: str, 394 | use_starting_offsets: bool = False, 395 | do_lowercase: bool = True, 396 | never_lowercase: List[str] = None, 397 | max_pieces: int = 512, 398 | max_pieces_per_token=5, 399 | is_test=False, 400 | truncate_long_sequences: bool = True, 401 | special_tokens_fix: int = 0) -> None: 402 | if pretrained_model.endswith("-cased") and do_lowercase: 403 | logger.warning("Your BERT model appears to be cased, " 404 | "but your indexer is lowercasing tokens.") 405 | elif pretrained_model.endswith("-uncased") and not do_lowercase: 406 | logger.warning("Your BERT model appears to be uncased, " 407 | "but your indexer is not lowercasing tokens.") 408 | 409 | bert_tokenizer = AutoTokenizer.from_pretrained( 410 | pretrained_model, do_lower_case=do_lowercase, do_basic_tokenize=False) 411 | 412 | # to adjust all tokenizers 413 | if hasattr(bert_tokenizer, 'encoder'): 414 | bert_tokenizer.vocab = bert_tokenizer.encoder 415 | if hasattr(bert_tokenizer, 'sp_model'): 416 | bert_tokenizer.vocab = defaultdict(lambda: 1) 417 | for i in range(bert_tokenizer.sp_model.get_piece_size()): 418 | bert_tokenizer.vocab[bert_tokenizer.sp_model.id_to_piece(i)] = i 419 | 420 | if special_tokens_fix: 421 | bert_tokenizer.add_tokens([START_TOKEN]) 422 | bert_tokenizer.vocab[START_TOKEN] = len(bert_tokenizer) - 1 423 | 424 | if "roberta" in pretrained_model: 425 | bpe_ranks = bert_tokenizer.bpe_ranks 426 | byte_encoder = bert_tokenizer.byte_encoder 427 | else: 428 | bpe_ranks = {} 429 | byte_encoder = None 430 | 431 | super().__init__(vocab=bert_tokenizer.vocab, 432 | bpe_ranks=bpe_ranks, 433 | byte_encoder=byte_encoder, 434 | wordpiece_tokenizer=bert_tokenizer.tokenize, 435 | namespace="bert", 436 | use_starting_offsets=use_starting_offsets, 437 | max_pieces=max_pieces, 438 | max_pieces_per_token=max_pieces_per_token, 439 | is_test=is_test, 440 | do_lowercase=do_lowercase, 441 | never_lowercase=never_lowercase, 442 | start_tokens=["[CLS]"] if not special_tokens_fix else [], 443 | end_tokens=["[SEP]"] if not special_tokens_fix else [], 444 | truncate_long_sequences=truncate_long_sequences) 445 | 
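# A minimal usage sketch (not part of the original module); the model name and the
# example tokens are illustrative assumptions only. It shows the keys this indexer
# produces for a tokenized sentence; Token, Vocabulary and START_TOKEN are the imports
# already made at the top of this file.
def _example_indexer_usage():
    indexer = PretrainedBertIndexer(pretrained_model="roberta-base",
                                    do_lowercase=False,
                                    use_starting_offsets=True,
                                    special_tokens_fix=1)
    tokens = [Token(t) for t in [START_TOKEN, "He", "go", "home"]]
    fields = indexer.tokens_to_indices(tokens, Vocabulary(), "bert")
    # fields contains "bert" (flattened wordpiece ids), "bert-offsets" (one offset per
    # original token) and "mask" (all ones, one entry per kept token).
    return fields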
-------------------------------------------------------------------------------- /gector/trainer.py: -------------------------------------------------------------------------------- 1 | """Tweaked version of corresponding AllenNLP file""" 2 | import datetime 3 | import logging 4 | import math 5 | import os 6 | import time 7 | import traceback 8 | from typing import Dict, Optional, List, Tuple, Union, Iterable, Any 9 | 10 | import torch 11 | import torch.optim.lr_scheduler 12 | from allennlp.common import Params 13 | from allennlp.common.checks import ConfigurationError, parse_cuda_device 14 | from allennlp.common.tqdm import Tqdm 15 | from allennlp.common.util import dump_metrics, gpu_memory_mb, peak_memory_mb, lazy_groups_of 16 | from allennlp.data.instance import Instance 17 | from allennlp.data.iterators.data_iterator import DataIterator, TensorDict 18 | from allennlp.models.model import Model 19 | from allennlp.nn import util as nn_util 20 | from allennlp.training import util as training_util 21 | from allennlp.training.checkpointer import Checkpointer 22 | from allennlp.training.learning_rate_schedulers import LearningRateScheduler 23 | from allennlp.training.metric_tracker import MetricTracker 24 | from allennlp.training.momentum_schedulers import MomentumScheduler 25 | from allennlp.training.moving_average import MovingAverage 26 | from allennlp.training.optimizers import Optimizer 27 | from allennlp.training.tensorboard_writer import TensorboardWriter 28 | from allennlp.training.trainer_base import TrainerBase 29 | 30 | logger = logging.getLogger(__name__) 31 | 32 | 33 | class Trainer(TrainerBase): 34 | def __init__( 35 | self, 36 | model: Model, 37 | optimizer: torch.optim.Optimizer, 38 | scheduler: torch.optim.lr_scheduler, 39 | iterator: DataIterator, 40 | train_dataset: Iterable[Instance], 41 | validation_dataset: Optional[Iterable[Instance]] = None, 42 | patience: Optional[int] = None, 43 | validation_metric: str = "-loss", 44 | validation_iterator: DataIterator = None, 45 | shuffle: bool = True, 46 | num_epochs: int = 20, 47 | accumulated_batch_count: int = 1, 48 | serialization_dir: Optional[str] = None, 49 | num_serialized_models_to_keep: int = 20, 50 | keep_serialized_model_every_num_seconds: int = None, 51 | checkpointer: Checkpointer = None, 52 | model_save_interval: float = None, 53 | cuda_device: Union[int, List] = -1, 54 | grad_norm: Optional[float] = None, 55 | grad_clipping: Optional[float] = None, 56 | learning_rate_scheduler: Optional[LearningRateScheduler] = None, 57 | momentum_scheduler: Optional[MomentumScheduler] = None, 58 | summary_interval: int = 100, 59 | histogram_interval: int = None, 60 | should_log_parameter_statistics: bool = True, 61 | should_log_learning_rate: bool = False, 62 | log_batch_size_period: Optional[int] = None, 63 | moving_average: Optional[MovingAverage] = None, 64 | cold_step_count: int = 0, 65 | cold_lr: float = 1e-3, 66 | cuda_verbose_step=None, 67 | ) -> None: 68 | """ 69 | A trainer for doing supervised learning. It just takes a labeled dataset 70 | and a ``DataIterator``, and uses the supplied ``Optimizer`` to learn the weights 71 | for your model over some fixed number of epochs. You can also pass in a validation 72 | dataset and enable early stopping. There are many other bells and whistles as well. 73 | 74 | Parameters 75 | ---------- 76 | model : ``Model``, required. 77 | An AllenNLP model to be optimized. 
PyTorch modules can also be optimized if 78 | their ``forward`` method returns a dictionary with a "loss" key, containing a
79 | scalar tensor representing the loss function to be optimized.
80 |
81 | If you are training your model using GPUs, your model should already be
82 | on the correct device. (If you use `Trainer.from_params` this will be
83 | handled for you.)
84 | optimizer : ``torch.optim.Optimizer``, required.
85 | An instance of a PyTorch Optimizer, instantiated with the parameters of the
86 | model to be optimized.
87 | iterator : ``DataIterator``, required.
88 | A method for iterating over a ``Dataset``, yielding padded indexed batches.
89 | train_dataset : ``Dataset``, required.
90 | A ``Dataset`` to train on. The dataset should have already been indexed.
91 | validation_dataset : ``Dataset``, optional, (default = None).
92 | A ``Dataset`` to evaluate on. The dataset should have already been indexed.
93 | patience : Optional[int] > 0, optional (default=None)
94 | Number of epochs to be patient before early stopping: the training is stopped
95 | after ``patience`` epochs with no improvement. If given, it must be ``> 0``.
96 | If None, early stopping is disabled.
97 | validation_metric : str, optional (default="-loss")
98 | Validation metric to measure for whether to stop training using patience
99 | and whether to serialize an ``is_best`` model each epoch. The metric name
100 | must be prepended with either "+" or "-", which specifies whether the metric
101 | is an increasing or decreasing function.
102 | validation_iterator : ``DataIterator``, optional (default=None)
103 | An iterator to use for the validation set. If ``None``, then
104 | use the training `iterator`.
105 | shuffle: ``bool``, optional (default=True)
106 | Whether to shuffle the instances in the iterator or not.
107 | num_epochs : int, optional (default = 20)
108 | Number of training epochs.
109 | serialization_dir : str, optional (default=None)
110 | Path to directory for saving and loading model files. Models will not be saved if
111 | this parameter is not passed.
112 | num_serialized_models_to_keep : ``int``, optional (default=20)
113 | Number of previous model checkpoints to retain. Default is to keep 20 checkpoints.
114 | A value of None or -1 means all checkpoints will be kept.
115 | keep_serialized_model_every_num_seconds : ``int``, optional (default=None)
116 | If num_serialized_models_to_keep is not None, then occasionally it's useful to
117 | save models at a given interval in addition to the last num_serialized_models_to_keep.
118 | To do so, specify keep_serialized_model_every_num_seconds as the number of seconds
119 | between permanently saved checkpoints. Note that this option is only used if
120 | num_serialized_models_to_keep is not None, otherwise all checkpoints are kept.
121 | checkpointer : ``Checkpointer``, optional (default=None)
122 | An instance of class Checkpointer to use instead of the default. If a checkpointer is specified,
123 | the arguments num_serialized_models_to_keep and keep_serialized_model_every_num_seconds should
124 | not be specified. The caller is responsible for initializing the checkpointer so that it is
125 | consistent with serialization_dir.
126 | model_save_interval : ``float``, optional (default=None)
127 | If provided, then serialize models every ``model_save_interval``
128 | seconds within single epochs. In all cases, models are also saved
129 | at the end of every epoch if ``serialization_dir`` is provided.
130 | cuda_device : ``Union[int, List[int]]``, optional (default = -1) 131 | An integer or list of integers specifying the CUDA device(s) to use. If -1, the CPU is used. 132 | grad_norm : ``float``, optional, (default = None). 133 | If provided, gradient norms will be rescaled to have a maximum of this value. 134 | grad_clipping : ``float``, optional (default = ``None``). 135 | If provided, gradients will be clipped `during the backward pass` to have an (absolute) 136 | maximum of this value. If you are getting ``NaNs`` in your gradients during training 137 | that are not solved by using ``grad_norm``, you may need this. 138 | learning_rate_scheduler : ``LearningRateScheduler``, optional (default = None) 139 | If specified, the learning rate will be decayed with respect to 140 | this schedule at the end of each epoch (or batch, if the scheduler implements 141 | the ``step_batch`` method). If you use :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`, 142 | this will use the ``validation_metric`` provided to determine if learning has plateaued. 143 | To support updating the learning rate on every batch, this can optionally implement 144 | ``step_batch(batch_num_total)`` which updates the learning rate given the batch number. 145 | momentum_scheduler : ``MomentumScheduler``, optional (default = None) 146 | If specified, the momentum will be updated at the end of each batch or epoch 147 | according to the schedule. 148 | summary_interval: ``int``, optional, (default = 100) 149 | Number of batches between logging scalars to tensorboard 150 | histogram_interval : ``int``, optional, (default = ``None``) 151 | If not None, then log histograms to tensorboard every ``histogram_interval`` batches. 152 | When this parameter is specified, the following additional logging is enabled: 153 | * Histograms of model parameters 154 | * The ratio of parameter update norm to parameter norm 155 | * Histogram of layer activations 156 | We log histograms of the parameters returned by 157 | ``model.get_parameters_for_histogram_tensorboard_logging``. 158 | The layer activations are logged for any modules in the ``Model`` that have 159 | the attribute ``should_log_activations`` set to ``True``. Logging 160 | histograms requires a number of GPU-CPU copies during training and is typically 161 | slow, so we recommend logging histograms relatively infrequently. 162 | Note: only Modules that return tensors, tuples of tensors or dicts 163 | with tensors as values currently support activation logging. 164 | should_log_parameter_statistics : ``bool``, optional, (default = True) 165 | Whether to send parameter statistics (mean and standard deviation 166 | of parameters and gradients) to tensorboard. 167 | should_log_learning_rate : ``bool``, optional, (default = False) 168 | Whether to send parameter specific learning rate to tensorboard. 169 | log_batch_size_period : ``int``, optional, (default = ``None``) 170 | If defined, how often to log the average batch size. 171 | moving_average: ``MovingAverage``, optional, (default = None) 172 | If provided, we will maintain moving averages for all parameters. During training, we 173 | employ a shadow variable for each parameter, which maintains the moving average. During 174 | evaluation, we backup the original parameters and assign the moving averages to corresponding 175 | parameters. Be careful that when saving the checkpoint, we will save the moving averages of 176 | parameters. 
This is necessary because we want the saved model to perform as well as the validated 177 | model if we load it later. But this may cause problems if you restart the training from checkpoint. 178 | """ 179 | super().__init__(serialization_dir, cuda_device) 180 | 181 | # I am not calling move_to_gpu here, because if the model is 182 | # not already on the GPU then the optimizer is going to be wrong. 183 | self.model = model 184 | 185 | self.iterator = iterator 186 | self._validation_iterator = validation_iterator 187 | self.shuffle = shuffle 188 | self.optimizer = optimizer 189 | self.scheduler = scheduler 190 | self.train_data = train_dataset 191 | self._validation_data = validation_dataset 192 | self.accumulated_batch_count = accumulated_batch_count 193 | self.cold_step_count = cold_step_count 194 | self.cold_lr = cold_lr 195 | self.cuda_verbose_step = cuda_verbose_step 196 | 197 | if patience is None: # no early stopping 198 | if validation_dataset: 199 | logger.warning( 200 | "You provided a validation dataset but patience was set to None, " 201 | "meaning that early stopping is disabled" 202 | ) 203 | elif (not isinstance(patience, int)) or patience <= 0: 204 | raise ConfigurationError( 205 | '{} is an invalid value for "patience": it must be a positive integer ' 206 | "or None (if you want to disable early stopping)".format(patience) 207 | ) 208 | 209 | # For tracking is_best_so_far and should_stop_early 210 | self._metric_tracker = MetricTracker(patience, validation_metric) 211 | # Get rid of + or - 212 | self._validation_metric = validation_metric[1:] 213 | 214 | self._num_epochs = num_epochs 215 | 216 | if checkpointer is not None: 217 | # We can't easily check if these parameters were passed in, so check against their default values. 218 | # We don't check against serialization_dir since it is also used by the parent class. 219 | if num_serialized_models_to_keep != 20 \ 220 | or keep_serialized_model_every_num_seconds is not None: 221 | raise ConfigurationError( 222 | "When passing a custom Checkpointer, you may not also pass in separate checkpointer " 223 | "args 'num_serialized_models_to_keep' or 'keep_serialized_model_every_num_seconds'." 224 | ) 225 | self._checkpointer = checkpointer 226 | else: 227 | self._checkpointer = Checkpointer( 228 | serialization_dir, 229 | keep_serialized_model_every_num_seconds, 230 | num_serialized_models_to_keep, 231 | ) 232 | 233 | self._model_save_interval = model_save_interval 234 | 235 | self._grad_norm = grad_norm 236 | self._grad_clipping = grad_clipping 237 | 238 | self._learning_rate_scheduler = learning_rate_scheduler 239 | self._momentum_scheduler = momentum_scheduler 240 | self._moving_average = moving_average 241 | 242 | # We keep the total batch number as an instance variable because it 243 | # is used inside a closure for the hook which logs activations in 244 | # ``_enable_activation_logging``. 245 | self._batch_num_total = 0 246 | 247 | self._tensorboard = TensorboardWriter( 248 | get_batch_num_total=lambda: self._batch_num_total, 249 | serialization_dir=serialization_dir, 250 | summary_interval=summary_interval, 251 | histogram_interval=histogram_interval, 252 | should_log_parameter_statistics=should_log_parameter_statistics, 253 | should_log_learning_rate=should_log_learning_rate, 254 | ) 255 | 256 | self._log_batch_size_period = log_batch_size_period 257 | 258 | self._last_log = 0.0 # time of last logging 259 | 260 | # Enable activation logging. 
261 | if histogram_interval is not None: 262 | self._tensorboard.enable_activation_logging(self.model) 263 | 264 | def rescale_gradients(self) -> Optional[float]: 265 | return training_util.rescale_gradients(self.model, self._grad_norm) 266 | 267 | def batch_loss(self, batch_group: List[TensorDict], for_training: bool) -> torch.Tensor: 268 | """ 269 | Does a forward pass on the given batches and returns the ``loss`` value in the result. 270 | If ``for_training`` is `True` also applies regularization penalty. 271 | """ 272 | if self._multiple_gpu: 273 | output_dict = training_util.data_parallel(batch_group, self.model, self._cuda_devices) 274 | else: 275 | assert len(batch_group) == 1 276 | batch = batch_group[0] 277 | batch = nn_util.move_to_device(batch, self._cuda_devices[0]) 278 | output_dict = self.model(**batch) 279 | 280 | try: 281 | loss = output_dict["loss"] 282 | if for_training: 283 | loss += self.model.get_regularization_penalty() 284 | except KeyError: 285 | if for_training: 286 | raise RuntimeError( 287 | "The model you are trying to optimize does not contain a" 288 | " 'loss' key in the output of model.forward(inputs)." 289 | ) 290 | loss = None 291 | 292 | return loss 293 | 294 | def _train_epoch(self, epoch: int) -> Dict[str, float]: 295 | """ 296 | Trains one epoch and returns metrics. 297 | """ 298 | logger.info("Epoch %d/%d", epoch, self._num_epochs - 1) 299 | peak_cpu_usage = peak_memory_mb() 300 | logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}") 301 | gpu_usage = [] 302 | for gpu, memory in gpu_memory_mb().items(): 303 | gpu_usage.append((gpu, memory)) 304 | logger.info(f"GPU {gpu} memory usage MB: {memory}") 305 | 306 | train_loss = 0.0 307 | # Set the model to "train" mode. 308 | self.model.train() 309 | 310 | num_gpus = len(self._cuda_devices) 311 | 312 | # Get tqdm for the training batches 313 | raw_train_generator = self.iterator(self.train_data, num_epochs=1, shuffle=self.shuffle) 314 | train_generator = lazy_groups_of(raw_train_generator, num_gpus) 315 | num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data) / num_gpus) 316 | residue = num_training_batches % self.accumulated_batch_count 317 | self._last_log = time.time() 318 | last_save_time = time.time() 319 | 320 | batches_this_epoch = 0 321 | if self._batch_num_total is None: 322 | self._batch_num_total = 0 323 | 324 | histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging()) 325 | 326 | logger.info("Training") 327 | train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_training_batches) 328 | cumulative_batch_size = 0 329 | self.optimizer.zero_grad() 330 | for batch_group in train_generator_tqdm: 331 | batches_this_epoch += 1 332 | self._batch_num_total += 1 333 | batch_num_total = self._batch_num_total 334 | 335 | iter_len = self.accumulated_batch_count \ 336 | if batches_this_epoch <= (num_training_batches - residue) else residue 337 | 338 | if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0: 339 | print(f'Before forward pass - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}') 340 | print(f'Before forward pass - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}') 341 | try: 342 | loss = self.batch_loss(batch_group, for_training=True) / iter_len 343 | except RuntimeError as e: 344 | print(e) 345 | for x in batch_group: 346 | all_words = [len(y['words']) for y in x['metadata']] 347 | print(f"Total sents: {len(all_words)}. " 348 | f"Min {min(all_words)}. 
Max {max(all_words)}") 349 | for elem in ['labels', 'd_tags']: 350 | tt = x[elem] 351 | print( 352 | f"{elem} shape {list(tt.shape)} and min {tt.min().item()} and {tt.max().item()}") 353 | for elem in ["bert", "mask", "bert-offsets"]: 354 | tt = x['tokens'][elem] 355 | print( 356 | f"{elem} shape {list(tt.shape)} and min {tt.min().item()} and {tt.max().item()}") 357 | raise e 358 | 359 | if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0: 360 | print(f'After forward pass - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}') 361 | print(f'After forward pass - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}') 362 | 363 | if torch.isnan(loss): 364 | raise ValueError("nan loss encountered") 365 | 366 | loss.backward() 367 | 368 | if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0: 369 | print(f'After backprop - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}') 370 | print(f'After backprop - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}') 371 | 372 | train_loss += loss.item() * iter_len 373 | 374 | del batch_group, loss 375 | torch.cuda.empty_cache() 376 | 377 | if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0: 378 | print(f'After collecting garbage - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}') 379 | print(f'After collecting garbage - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}') 380 | 381 | batch_grad_norm = self.rescale_gradients() 382 | 383 | # This does nothing if batch_num_total is None or you are using a 384 | # scheduler which doesn't update per batch. 385 | if self._learning_rate_scheduler: 386 | self._learning_rate_scheduler.step_batch(batch_num_total) 387 | if self._momentum_scheduler: 388 | self._momentum_scheduler.step_batch(batch_num_total) 389 | 390 | if self._tensorboard.should_log_histograms_this_batch(): 391 | # get the magnitude of parameter updates for logging 392 | # We need a copy of current parameters to compute magnitude of updates, 393 | # and copy them to CPU so large models won't go OOM on the GPU. 
394 | param_updates = { 395 | name: param.detach().cpu().clone() 396 | for name, param in self.model.named_parameters() 397 | } 398 | if batches_this_epoch % self.accumulated_batch_count == 0 or \ 399 | batches_this_epoch == num_training_batches: 400 | self.optimizer.step() 401 | self.optimizer.zero_grad() 402 | for name, param in self.model.named_parameters(): 403 | param_updates[name].sub_(param.detach().cpu()) 404 | update_norm = torch.norm(param_updates[name].view(-1)) 405 | param_norm = torch.norm(param.view(-1)).cpu() 406 | self._tensorboard.add_train_scalar( 407 | "gradient_update/" + name, update_norm / (param_norm + 1e-7) 408 | ) 409 | else: 410 | if batches_this_epoch % self.accumulated_batch_count == 0 or \ 411 | batches_this_epoch == num_training_batches: 412 | self.optimizer.step() 413 | self.optimizer.zero_grad() 414 | 415 | # Update moving averages 416 | if self._moving_average is not None: 417 | self._moving_average.apply(batch_num_total) 418 | 419 | # Update the description with the latest metrics 420 | metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch) 421 | description = training_util.description_from_metrics(metrics) 422 | 423 | train_generator_tqdm.set_description(description, refresh=False) 424 | 425 | # Log parameter values to Tensorboard 426 | if self._tensorboard.should_log_this_batch(): 427 | self._tensorboard.log_parameter_and_gradient_statistics(self.model, batch_grad_norm) 428 | self._tensorboard.log_learning_rates(self.model, self.optimizer) 429 | 430 | self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"]) 431 | self._tensorboard.log_metrics({"epoch_metrics/" + k: v for k, v in metrics.items()}) 432 | 433 | if self._tensorboard.should_log_histograms_this_batch(): 434 | self._tensorboard.log_histograms(self.model, histogram_parameters) 435 | 436 | if self._log_batch_size_period: 437 | cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group]) 438 | cumulative_batch_size += cur_batch 439 | if (batches_this_epoch - 1) % self._log_batch_size_period == 0: 440 | average = cumulative_batch_size / batches_this_epoch 441 | logger.info(f"current batch size: {cur_batch} mean batch size: {average}") 442 | self._tensorboard.add_train_scalar("current_batch_size", cur_batch) 443 | self._tensorboard.add_train_scalar("mean_batch_size", average) 444 | 445 | # Save model if needed. 446 | if self._model_save_interval is not None and ( 447 | time.time() - last_save_time > self._model_save_interval 448 | ): 449 | last_save_time = time.time() 450 | self._save_checkpoint( 451 | "{0}.{1}".format(epoch, training_util.time_to_str(int(last_save_time))) 452 | ) 453 | 454 | metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch, reset=True) 455 | metrics["cpu_memory_MB"] = peak_cpu_usage 456 | for (gpu_num, memory) in gpu_usage: 457 | metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory 458 | return metrics 459 | 460 | def _validation_loss(self) -> Tuple[float, int]: 461 | """ 462 | Computes the validation loss. Returns it and the number of batches. 463 | """ 464 | logger.info("Validating") 465 | 466 | self.model.eval() 467 | 468 | # Replace parameter values with the shadow values from the moving averages. 
469 | if self._moving_average is not None: 470 | self._moving_average.assign_average_value() 471 | 472 | if self._validation_iterator is not None: 473 | val_iterator = self._validation_iterator 474 | else: 475 | val_iterator = self.iterator 476 | 477 | num_gpus = len(self._cuda_devices) 478 | 479 | raw_val_generator = val_iterator(self._validation_data, num_epochs=1, shuffle=False) 480 | val_generator = lazy_groups_of(raw_val_generator, num_gpus) 481 | num_validation_batches = math.ceil( 482 | val_iterator.get_num_batches(self._validation_data) / num_gpus 483 | ) 484 | val_generator_tqdm = Tqdm.tqdm(val_generator, total=num_validation_batches) 485 | batches_this_epoch = 0 486 | val_loss = 0 487 | for batch_group in val_generator_tqdm: 488 | 489 | loss = self.batch_loss(batch_group, for_training=False) 490 | if loss is not None: 491 | # You shouldn't necessarily have to compute a loss for validation, so we allow for 492 | # `loss` to be None. We need to be careful, though - `batches_this_epoch` is 493 | # currently only used as the divisor for the loss function, so we can safely only 494 | # count those batches for which we actually have a loss. If this variable ever 495 | # gets used for something else, we might need to change things around a bit. 496 | batches_this_epoch += 1 497 | val_loss += loss.detach().cpu().numpy() 498 | 499 | # Update the description with the latest metrics 500 | val_metrics = training_util.get_metrics(self.model, val_loss, batches_this_epoch) 501 | description = training_util.description_from_metrics(val_metrics) 502 | val_generator_tqdm.set_description(description, refresh=False) 503 | 504 | # Now restore the original parameter values. 505 | if self._moving_average is not None: 506 | self._moving_average.restore() 507 | 508 | return val_loss, batches_this_epoch 509 | 510 | def train(self) -> Dict[str, Any]: 511 | """ 512 | Trains the supplied model with the supplied parameters. 513 | """ 514 | try: 515 | epoch_counter = self._restore_checkpoint() 516 | except RuntimeError: 517 | traceback.print_exc() 518 | raise ConfigurationError( 519 | "Could not recover training from the checkpoint. Did you mean to output to " 520 | "a different serialization directory or delete the existing serialization " 521 | "directory?" 
522 | ) 523 | 524 | training_util.enable_gradient_clipping(self.model, self._grad_clipping) 525 | 526 | logger.info("Beginning training.") 527 | 528 | train_metrics: Dict[str, float] = {} 529 | val_metrics: Dict[str, float] = {} 530 | this_epoch_val_metric: float = None 531 | metrics: Dict[str, Any] = {} 532 | epochs_trained = 0 533 | training_start_time = time.time() 534 | 535 | if self.cold_step_count > 0: 536 | base_lr = self.optimizer.param_groups[0]['lr'] 537 | for param_group in self.optimizer.param_groups: 538 | param_group['lr'] = self.cold_lr 539 | self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=True) 540 | 541 | metrics["best_epoch"] = self._metric_tracker.best_epoch 542 | for key, value in self._metric_tracker.best_epoch_metrics.items(): 543 | metrics["best_validation_" + key] = value 544 | 545 | for epoch in range(epoch_counter, self._num_epochs): 546 | if epoch == self.cold_step_count and epoch != 0: 547 | for param_group in self.optimizer.param_groups: 548 | param_group['lr'] = base_lr 549 | self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=False) 550 | 551 | epoch_start_time = time.time() 552 | train_metrics = self._train_epoch(epoch) 553 | 554 | # get peak of memory usage 555 | if "cpu_memory_MB" in train_metrics: 556 | metrics["peak_cpu_memory_MB"] = max( 557 | metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"] 558 | ) 559 | for key, value in train_metrics.items(): 560 | if key.startswith("gpu_"): 561 | metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) 562 | 563 | # clear cache before validation 564 | torch.cuda.empty_cache() 565 | if self._validation_data is not None: 566 | with torch.no_grad(): 567 | # We have a validation set, so compute all the metrics on it. 568 | val_loss, num_batches = self._validation_loss() 569 | val_metrics = training_util.get_metrics( 570 | self.model, val_loss, num_batches, reset=True 571 | ) 572 | 573 | # Check validation metric for early stopping 574 | this_epoch_val_metric = val_metrics[self._validation_metric] 575 | self._metric_tracker.add_metric(this_epoch_val_metric) 576 | 577 | if self._metric_tracker.should_stop_early(): 578 | logger.info("Ran out of patience. Stopping training.") 579 | break 580 | 581 | self._tensorboard.log_metrics( 582 | train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1 583 | ) # +1 because tensorboard doesn't like 0 584 | 585 | # Create overall metrics dict 586 | training_elapsed_time = time.time() - training_start_time 587 | metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time)) 588 | metrics["training_start_epoch"] = epoch_counter 589 | metrics["training_epochs"] = epochs_trained 590 | metrics["epoch"] = epoch 591 | 592 | for key, value in train_metrics.items(): 593 | metrics["training_" + key] = value 594 | for key, value in val_metrics.items(): 595 | metrics["validation_" + key] = value 596 | 597 | # if self.cold_step_count <= epoch: 598 | self.scheduler.step(metrics['validation_loss']) 599 | 600 | if self._metric_tracker.is_best_so_far(): 601 | # Update all the best_ metrics. 602 | # (Otherwise they just stay the same as they were.) 
603 | metrics["best_epoch"] = epoch 604 | for key, value in val_metrics.items(): 605 | metrics["best_validation_" + key] = value 606 | 607 | self._metric_tracker.best_epoch_metrics = val_metrics 608 | 609 | if self._serialization_dir: 610 | dump_metrics( 611 | os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics 612 | ) 613 | 614 | # The Scheduler API is agnostic to whether your schedule requires a validation metric - 615 | # if it doesn't, the validation metric passed here is ignored. 616 | if self._learning_rate_scheduler: 617 | self._learning_rate_scheduler.step(this_epoch_val_metric, epoch) 618 | if self._momentum_scheduler: 619 | self._momentum_scheduler.step(this_epoch_val_metric, epoch) 620 | 621 | self._save_checkpoint(epoch) 622 | 623 | epoch_elapsed_time = time.time() - epoch_start_time 624 | logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time)) 625 | 626 | if epoch < self._num_epochs - 1: 627 | training_elapsed_time = time.time() - training_start_time 628 | estimated_time_remaining = training_elapsed_time * ( 629 | (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1 630 | ) 631 | formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining))) 632 | logger.info("Estimated training time remaining: %s", formatted_time) 633 | 634 | epochs_trained += 1 635 | 636 | # make sure pending events are flushed to disk and files are closed properly 637 | # self._tensorboard.close() 638 | 639 | # Load the best model state before returning 640 | best_model_state = self._checkpointer.best_model_state() 641 | if best_model_state: 642 | self.model.load_state_dict(best_model_state) 643 | 644 | return metrics 645 | 646 | def _save_checkpoint(self, epoch: Union[int, str]) -> None: 647 | """ 648 | Saves a checkpoint of the model to self._serialization_dir. 649 | Is a no-op if self._serialization_dir is None. 650 | 651 | Parameters 652 | ---------- 653 | epoch : Union[int, str], required. 654 | The epoch of training. If the checkpoint is saved in the middle 655 | of an epoch, the parameter is a string with the epoch and timestamp. 656 | """ 657 | # If moving averages are used for parameters, we save 658 | # the moving average values into checkpoint, instead of the current values. 659 | if self._moving_average is not None: 660 | self._moving_average.assign_average_value() 661 | 662 | # These are the training states we need to persist. 663 | training_states = { 664 | "metric_tracker": self._metric_tracker.state_dict(), 665 | "optimizer": self.optimizer.state_dict(), 666 | "batch_num_total": self._batch_num_total, 667 | } 668 | 669 | # If we have a learning rate or momentum scheduler, we should persist them too. 670 | if self._learning_rate_scheduler is not None: 671 | training_states["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict() 672 | if self._momentum_scheduler is not None: 673 | training_states["momentum_scheduler"] = self._momentum_scheduler.state_dict() 674 | 675 | self._checkpointer.save_checkpoint( 676 | model_state=self.model.state_dict(), 677 | epoch=epoch, 678 | training_states=training_states, 679 | is_best_so_far=self._metric_tracker.is_best_so_far(), 680 | ) 681 | 682 | # Restore the original values for parameters so that training will not be affected. 683 | if self._moving_average is not None: 684 | self._moving_average.restore() 685 | 686 | def _restore_checkpoint(self) -> int: 687 | """ 688 | Restores the model and training state from the last saved checkpoint. 
689 | This includes an epoch count and optimizer state, which is serialized separately 690 | from model parameters. This function should only be used to continue training - 691 | if you wish to load a model for inference/load parts of a model into a new 692 | computation graph, you should use the native Pytorch functions: 693 | `` model.load_state_dict(torch.load("/path/to/model/weights.th"))`` 694 | 695 | If ``self._serialization_dir`` does not exist or does not contain any checkpointed weights, 696 | this function will do nothing and return 0. 697 | 698 | Returns 699 | ------- 700 | epoch: int 701 | The epoch at which to resume training, which should be one after the epoch 702 | in the saved training state. 703 | """ 704 | model_state, training_state = self._checkpointer.restore_checkpoint() 705 | 706 | if not training_state: 707 | # No checkpoint to restore, start at 0 708 | return 0 709 | 710 | self.model.load_state_dict(model_state) 711 | self.optimizer.load_state_dict(training_state["optimizer"]) 712 | if self._learning_rate_scheduler is not None \ 713 | and "learning_rate_scheduler" in training_state: 714 | self._learning_rate_scheduler.load_state_dict(training_state["learning_rate_scheduler"]) 715 | if self._momentum_scheduler is not None and "momentum_scheduler" in training_state: 716 | self._momentum_scheduler.load_state_dict(training_state["momentum_scheduler"]) 717 | training_util.move_optimizer_to_cuda(self.optimizer) 718 | 719 | # Currently the ``training_state`` contains a serialized ``MetricTracker``. 720 | if "metric_tracker" in training_state: 721 | self._metric_tracker.load_state_dict(training_state["metric_tracker"]) 722 | # It used to be the case that we tracked ``val_metric_per_epoch``. 723 | elif "val_metric_per_epoch" in training_state: 724 | self._metric_tracker.clear() 725 | self._metric_tracker.add_metrics(training_state["val_metric_per_epoch"]) 726 | # And before that we didn't track anything. 727 | else: 728 | self._metric_tracker.clear() 729 | 730 | if isinstance(training_state["epoch"], int): 731 | epoch_to_return = training_state["epoch"] + 1 732 | else: 733 | epoch_to_return = int(training_state["epoch"].split(".")[0]) + 1 734 | 735 | # For older checkpoints with batch_num_total missing, default to old behavior where 736 | # it is unchanged. 737 | batch_num_total = training_state.get("batch_num_total") 738 | if batch_num_total is not None: 739 | self._batch_num_total = batch_num_total 740 | 741 | return epoch_to_return 742 | 743 | # Requires custom from_params. 
744 | @classmethod 745 | def from_params( # type: ignore 746 | cls, 747 | model: Model, 748 | serialization_dir: str, 749 | iterator: DataIterator, 750 | train_data: Iterable[Instance], 751 | validation_data: Optional[Iterable[Instance]], 752 | params: Params, 753 | validation_iterator: DataIterator = None, 754 | ) -> "Trainer": 755 | 756 | patience = params.pop_int("patience", None) 757 | validation_metric = params.pop("validation_metric", "-loss") 758 | shuffle = params.pop_bool("shuffle", True) 759 | num_epochs = params.pop_int("num_epochs", 20) 760 | cuda_device = parse_cuda_device(params.pop("cuda_device", -1)) 761 | grad_norm = params.pop_float("grad_norm", None) 762 | grad_clipping = params.pop_float("grad_clipping", None) 763 | lr_scheduler_params = params.pop("learning_rate_scheduler", None) 764 | momentum_scheduler_params = params.pop("momentum_scheduler", None) 765 | 766 | if isinstance(cuda_device, list): 767 | model_device = cuda_device[0] 768 | else: 769 | model_device = cuda_device 770 | if model_device >= 0: 771 | # Moving model to GPU here so that the optimizer state gets constructed on 772 | # the right device. 773 | model = model.cuda(model_device) 774 | 775 | parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad] 776 | optimizer = Optimizer.from_params(parameters, params.pop("optimizer")) 777 | if "moving_average" in params: 778 | moving_average = MovingAverage.from_params( 779 | params.pop("moving_average"), parameters=parameters 780 | ) 781 | else: 782 | moving_average = None 783 | 784 | if lr_scheduler_params: 785 | lr_scheduler = LearningRateScheduler.from_params(optimizer, lr_scheduler_params) 786 | else: 787 | lr_scheduler = None 788 | if momentum_scheduler_params: 789 | momentum_scheduler = MomentumScheduler.from_params(optimizer, momentum_scheduler_params) 790 | else: 791 | momentum_scheduler = None 792 | 793 | if "checkpointer" in params: 794 | if "keep_serialized_model_every_num_seconds" in params \ 795 | or "num_serialized_models_to_keep" in params: 796 | raise ConfigurationError( 797 | "Checkpointer may be initialized either from the 'checkpointer' key or from the " 798 | "keys 'num_serialized_models_to_keep' and 'keep_serialized_model_every_num_seconds'" 799 | " but the passed config uses both methods." 
800 | ) 801 | checkpointer = Checkpointer.from_params(params.pop("checkpointer")) 802 | else: 803 | num_serialized_models_to_keep = params.pop_int("num_serialized_models_to_keep", 20) 804 | keep_serialized_model_every_num_seconds = params.pop_int( 805 | "keep_serialized_model_every_num_seconds", None 806 | ) 807 | checkpointer = Checkpointer( 808 | serialization_dir=serialization_dir, 809 | num_serialized_models_to_keep=num_serialized_models_to_keep, 810 | keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds, 811 | ) 812 | model_save_interval = params.pop_float("model_save_interval", None) 813 | summary_interval = params.pop_int("summary_interval", 100) 814 | histogram_interval = params.pop_int("histogram_interval", None) 815 | should_log_parameter_statistics = params.pop_bool("should_log_parameter_statistics", True) 816 | should_log_learning_rate = params.pop_bool("should_log_learning_rate", False) 817 | log_batch_size_period = params.pop_int("log_batch_size_period", None) 818 | 819 | params.assert_empty(cls.__name__) 820 | return cls( 821 | model, 822 | optimizer, 823 | iterator, 824 | train_data, 825 | validation_data, 826 | patience=patience, 827 | validation_metric=validation_metric, 828 | validation_iterator=validation_iterator, 829 | shuffle=shuffle, 830 | num_epochs=num_epochs, 831 | serialization_dir=serialization_dir, 832 | cuda_device=cuda_device, 833 | grad_norm=grad_norm, 834 | grad_clipping=grad_clipping, 835 | learning_rate_scheduler=lr_scheduler, 836 | momentum_scheduler=momentum_scheduler, 837 | checkpointer=checkpointer, 838 | model_save_interval=model_save_interval, 839 | summary_interval=summary_interval, 840 | histogram_interval=histogram_interval, 841 | should_log_parameter_statistics=should_log_parameter_statistics, 842 | should_log_learning_rate=should_log_learning_rate, 843 | log_batch_size_period=log_batch_size_period, 844 | moving_average=moving_average, 845 | ) 846 | --------------------------------------------------------------------------------
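# Usage sketch (illustrative, not part of the repository): one way to wire up
# the tweaked Trainer defined in gector/trainer.py. The builder function, the
# hyperparameter values, and the assumption that `model` is this repo's
# Seq2Labels model (needed for the cold-start BERT freezing) are example
# choices; only the keyword names come from the constructor above.
import torch
from allennlp.data.iterators import BucketIterator

from gector.trainer import Trainer


def build_trainer(model, train_data, dev_data, serialization_dir):
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    # train() calls scheduler.step(validation_loss) once per epoch, so a
    # plateau-based scheduler is a natural fit for the `scheduler` argument.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.1, patience=10)
    iterator = BucketIterator(batch_size=32,
                              sorting_keys=[("tokens", "num_tokens")])
    iterator.index_with(model.vocab)
    return Trainer(model=model,
                   optimizer=optimizer,
                   scheduler=scheduler,
                   iterator=iterator,
                   train_dataset=train_data,
                   validation_dataset=dev_data,
                   patience=3,
                   num_epochs=20,
                   accumulated_batch_count=4,  # accumulate gradients over 4 batches
                   cold_step_count=2,          # keep BERT frozen for the first 2 epochs
                   cold_lr=1e-3,
                   serialization_dir=serialization_dir,
                   cuda_device=0 if torch.cuda.is_available() else -1)

# metrics = build_trainer(model, train_data, dev_data, "model_dir").train()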