├── data
│   └── output_vocabulary
│       ├── non_padded_namespaces.txt
│       └── d_tags.txt
├── requirements.txt
├── .gitignore
├── README.md
├── utils
│   ├── prepare_clc_fce_data.py
│   ├── helpers.py
│   └── preprocess_data.py
├── predict.py
├── gector
│   ├── datareader.py
│   ├── seq2labels_model.py
│   ├── bert_token_embedder.py
│   ├── gec_model.py
│   ├── wordpiece_indexer.py
│   └── trainer.py
├── LICENSE
└── train.py
/data/output_vocabulary/non_padded_namespaces.txt:
--------------------------------------------------------------------------------
1 | *tags
2 | *labels
3 |
--------------------------------------------------------------------------------
/data/output_vocabulary/d_tags.txt:
--------------------------------------------------------------------------------
1 | CORRECT
2 | INCORRECT
3 | @@UNKNOWN@@
4 | @@PADDING@@
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.3.0
2 | allennlp==0.8.4
3 | python-Levenshtein==0.12.0
4 | transformers==2.2.2
5 | scikit-learn==0.20.0
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | .DS_Store
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
132 | # PyCharm
133 | .idea
134 |
135 | *.sh
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GECToR – Grammatical Error Correction: Tag, Not Rewrite
2 |
3 | This repository provides code for training and testing state-of-the-art models for grammatical error correction with the official PyTorch implementation of the following paper:
4 | > [GECToR – Grammatical Error Correction: Tag, Not Rewrite](https://arxiv.org/abs/2005.12592)
5 | > [Kostiantyn Omelianchuk](https://github.com/komelianchuk), [Vitaliy Atrasevych](https://github.com/atrasevych), [Artem Chernodub](https://github.com/achernodub), [Oleksandr Skurzhanskyi](https://github.com/skurzhanskyi)
6 | > Grammarly
7 | > [15th Workshop on Innovative Use of NLP for Building Educational Applications (co-located with ACL 2020)](https://sig-edu.org/bea/current)
8 |
9 | It is mainly based on `AllenNLP` and `transformers`.
10 | ## Installation
11 | The following command installs all necessary packages:
12 | ```.bash
13 | pip install -r requirements.txt
14 | ```
15 | The project was tested using Python 3.7.
16 |
17 | ## Datasets
18 | All the public GEC datasets used in the paper can be downloaded from [here](https://www.cl.cam.ac.uk/research/nl/bea2019st/#data).
19 | Synthetically created datasets can be generated/downloaded [here](https://github.com/awasthiabhijeet/PIE/tree/master/errorify).
20 | To train the model, the data has to be preprocessed and converted to a special format with the following command:
21 | ```.bash
22 | python utils/preprocess_data.py -s SOURCE -t TARGET -o OUTPUT_FILE
23 | ```
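
For illustration (inferred from the delimiters in `utils/helpers.py` and the reader in `gector/datareader.py`, not an official sample), each line of the preprocessed file pairs every source token with an edit tag, with `$START` as the first token (the reader adds it if missing):
```
$STARTSEPL|||SEPR$KEEP SheSEPL|||SEPR$KEEP areSEPL|||SEPR$REPLACE_is happySEPL|||SEPR$KEEP .SEPL|||SEPR$KEEP
```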
24 | ## Pretrained models
25 |
26 |
27 | | Pretrained encoder     | Confidence bias | Min error prob | CoNLL-2014 (test) | BEA-2019 (test) |
28 | |:-----------------------|:---------------:|:--------------:|:-----------------:|:---------------:|
29 | | BERT [link]            | 0.10            | 0.41           | 63.0              | 67.6            |
30 | | RoBERTa [link]         | 0.20            | 0.50           | 64.0              | 71.5            |
31 | | XLNet [link]           | 0.35            | 0.66           | 65.3              | 72.4            |
32 | | RoBERTa + XLNet        | 0.24            | 0.45           | 66.0              | 73.7            |
33 | | BERT + RoBERTa + XLNet | 0.16            | 0.40           | 66.5              | 73.6            |
68 |
69 |
70 | ## Train model
71 | To train the model, simply run:
72 | ```.bash
73 | python train.py --train_set TRAIN_SET --dev_set DEV_SET \
74 | --model_dir MODEL_DIR
75 | ```
76 | There are many parameters you can specify; among them:
77 | - `cold_steps_count` - the number of epochs during which only the last linear layer is trained
78 | - `transformer_model {bert,distilbert,gpt2,roberta,transformerxl,xlnet,albert}` - the encoder model
79 | - `tn_prob` - the probability of keeping sentences with no errors; helps to balance precision/recall
80 | - `pieces_per_token` - the maximum number of subword pieces per token; helps to avoid CUDA out-of-memory errors
81 |
82 | In our experiments we used a 98/2 train/dev split.
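
For illustration only, a possible run that sets some of the parameters above (the flag values are placeholders, not recommended settings):
```.bash
python train.py --train_set TRAIN_SET --dev_set DEV_SET \
    --model_dir MODEL_DIR \
    --transformer_model xlnet \
    --cold_steps_count 2 \
    --tn_prob 0 \
    --pieces_per_token 5
```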
83 | ## Model inference
84 | To run your model on an input file, use the following command:
85 | ```.bash
86 | python predict.py --model_path MODEL_PATH [MODEL_PATH ...] \
87 | --vocab_path VOCAB_PATH --input_file INPUT_FILE \
88 | --output_file OUTPUT_FILE
89 | ```
90 | Among the parameters:
91 | - `min_error_probability` - minimum error probability (as in the paper)
92 | - `additional_confidence` - confidence bias (as in the paper)
93 | - `special_tokens_fix` - needed to reproduce some of the reported results of the pretrained models
94 |
95 | For evaluation use [M^2Scorer](https://github.com/nusnlp/m2scorer) and [ERRANT](https://github.com/chrisjbryant/errant).
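
For example (a sketch; model and file paths are placeholders), inference settings matching the XLNet row of the table above would look like:
```.bash
python predict.py --model_path MODEL_PATH \
    --vocab_path data/output_vocabulary \
    --input_file INPUT_FILE --output_file OUTPUT_FILE \
    --transformer_model xlnet \
    --additional_confidence 0.35 \
    --min_error_probability 0.66 \
    --special_tokens_fix 0
```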
96 | ## Citation
97 | If you find this work useful for your research, please cite our paper:
98 | ```
99 | @misc{omelianchuk2020gector,
100 | title={GECToR -- Grammatical Error Correction: Tag, Not Rewrite},
101 | author={Kostiantyn Omelianchuk and Vitaliy Atrasevych and Artem Chernodub and Oleksandr Skurzhanskyi},
102 | year={2020},
103 | eprint={2005.12592},
104 | archivePrefix={arXiv},
105 | primaryClass={cs.CL}
106 | }
107 | ```
108 |
--------------------------------------------------------------------------------
/utils/prepare_clc_fce_data.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Convert CLC-FCE dataset (The Cambridge Learner Corpus) to the parallel sentences format.
4 | """
5 |
6 | import argparse
7 | import glob
8 | import os
9 | import re
10 | from xml.etree import cElementTree
11 |
12 | from nltk.tokenize import sent_tokenize, word_tokenize
13 | from tqdm import tqdm
14 |
15 |
16 | def annotate_fce_doc(xml):
17 | """Takes a FCE xml document and yields sentences with annotated errors."""
18 | result = []
19 | doc = cElementTree.fromstring(xml)
20 | paragraphs = doc.findall('head/text/*/coded_answer/p')
21 | for p in paragraphs:
22 | text = _get_formatted_text(p)
23 | result.append(text)
24 |
25 | return '\n'.join(result)
26 |
27 |
28 | def _get_formatted_text(elem, ignore_tags=None):
29 | text = elem.text or ''
30 | ignore_tags = [tag.upper() for tag in (ignore_tags or [])]
31 | correct = None
32 | mistake = None
33 |
34 | for child in elem.getchildren():
35 | tag = child.tag.upper()
36 | if tag == 'NS':
37 | text += _get_formatted_text(child)
38 |
39 | elif tag == 'UNKNOWN':
40 | text += ' UNKNOWN '
41 |
42 | elif tag == 'C':
43 | assert correct is None
44 | correct = _get_formatted_text(child)
45 |
46 | elif tag == 'I':
47 | assert mistake is None
48 | mistake = _get_formatted_text(child)
49 |
50 | elif tag in ignore_tags:
51 | pass
52 |
53 | else:
54 | raise ValueError(f"Unknown tag `{child.tag}`", text)
55 |
56 | if correct or mistake:
57 | correct = correct or ''
58 | mistake = mistake or ''
59 | if '=>' not in mistake:
60 | text += f'{{{mistake}=>{correct}}}'
61 | else:
62 | text += mistake
63 |
64 | text += elem.tail or ''
65 | return text
66 |
67 |
68 | def convert_fce(fce_dir):
69 | """Processes the whole FCE directory. Yields annotated documents (strings)."""
70 |
71 | # Ensure we got the valid dataset path
72 | if not os.path.isdir(fce_dir):
73 | raise UserWarning(
74 | f"{fce_dir} is not a valid path")
75 |
76 | dataset_dir = os.path.join(fce_dir, 'dataset')
77 | if not os.path.exists(dataset_dir):
78 | raise UserWarning(
79 | f"{fce_dir} doesn't point to a dataset's root dir")
80 |
81 | # Convert XML docs to the corpora format
82 | filenames = sorted(glob.glob(os.path.join(dataset_dir, '*/*.xml')))
83 |
84 | docs = []
85 | for filename in filenames:
86 | with open(filename, encoding='utf-8') as f:
87 | doc = annotate_fce_doc(f.read())
88 | docs.append(doc)
89 | return docs
90 |
91 |
92 | def main():
93 | fce = convert_fce(args.fce_dataset_path)
94 | with open(args.output + "/fce-original.txt", 'w', encoding='utf-8') as out_original, \
95 | open(args.output + "/fce-applied.txt", 'w', encoding='utf-8') as out_applied:
96 | for doc in tqdm(fce, unit='doc'):
97 | sents = re.split(r"\n +\n", doc)
98 | for sent in sents:
99 | tokenized_sents = sent_tokenize(sent)
100 | for i in range(len(tokenized_sents)):
101 | if re.search(r"[{>][.?!]$", tokenized_sents[i]):
102 | tokenized_sents[i + 1] = tokenized_sents[i] + " " + tokenized_sents[i + 1]
103 | tokenized_sents[i] = ""
104 | regexp = r'{([^{}]*?)=>([^{}]*?)}'
105 | original = re.sub(regexp, r"\1", tokenized_sents[i])
106 | applied = re.sub(regexp, r"\2", tokenized_sents[i])
107 | # filter out nested alerts
108 | if original != "" and applied != "" and not re.search(r"[{}=]", original) \
109 | and not re.search(r"[{}=]", applied):
110 | out_original.write(" ".join(word_tokenize(original)) + "\n")
111 | out_applied.write(" ".join(word_tokenize(applied)) + "\n")
112 |
113 |
114 | if __name__ == '__main__':
115 | parser = argparse.ArgumentParser(description=(
116 | "Convert CLC-FCE dataset to the parallel sentences format."))
117 | parser.add_argument('fce_dataset_path',
118 | help='Path to the folder with the FCE dataset')
119 | parser.add_argument('--output',
120 | help='Path to the output folder')
121 | args = parser.parse_args()
122 |
123 | main()
124 |
--------------------------------------------------------------------------------
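
For illustration, a hypothetical invocation of the script above (paths are placeholders; `nltk` with the `punkt` tokenizer data must be available). It writes `fce-original.txt` and `fce-applied.txt` into the output folder:

```.bash
python utils/prepare_clc_fce_data.py PATH_TO_FCE_FOLDER --output OUTPUT_DIR
```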
/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from utils.helpers import read_lines
4 | from gector.gec_model import GecBERTModel
5 |
6 |
7 | def predict_for_file(input_file, output_file, model, batch_size=32):
8 | test_data = read_lines(input_file)
9 | predictions = []
10 | cnt_corrections = 0
11 | batch = []
12 | for sent in test_data:
13 | batch.append(sent.split())
14 | if len(batch) == batch_size:
15 | preds, cnt = model.handle_batch(batch)
16 | predictions.extend(preds)
17 | cnt_corrections += cnt
18 | batch = []
19 | if batch:
20 | preds, cnt = model.handle_batch(batch)
21 | predictions.extend(preds)
22 | cnt_corrections += cnt
23 |
24 | with open(output_file, 'w') as f:
25 | f.write("\n".join([" ".join(x) for x in predictions]) + '\n')
26 | return cnt_corrections
27 |
28 |
29 | def main(args):
30 | # get all paths
31 | model = GecBERTModel(vocab_path=args.vocab_path,
32 | model_paths=args.model_path,
33 | max_len=args.max_len, min_len=args.min_len,
34 | iterations=args.iteration_count,
35 | min_error_probability=args.min_error_probability,
36 | min_probability=args.min_error_probability,
37 | lowercase_tokens=args.lowercase_tokens,
38 | model_name=args.transformer_model,
39 | special_tokens_fix=args.special_tokens_fix,
40 | log=False,
41 | confidence=args.additional_confidence,
42 | is_ensemble=args.is_ensemble,
43 | weigths=args.weights)
44 |
45 | cnt_corrections = predict_for_file(args.input_file, args.output_file, model,
46 | batch_size=args.batch_size)
47 | # evaluate with m2 or ERRANT
48 | print(f"Produced overall corrections: {cnt_corrections}")
49 |
50 |
51 | if __name__ == '__main__':
52 | # read parameters
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument('--model_path',
55 | help='Path to the model file.', nargs='+',
56 | required=True)
57 | parser.add_argument('--vocab_path',
58 | help='Path to the output vocabulary directory.',
59 | default='data/output_vocabulary' # to use pretrained models
60 | )
61 | parser.add_argument('--input_file',
62 | help='Path to the evalset file',
63 | required=True)
64 | parser.add_argument('--output_file',
65 | help='Path to the output file',
66 | required=True)
67 | parser.add_argument('--max_len',
68 | type=int,
69 | help='The max sentence length'
70 | '(all longer will be truncated)',
71 | default=50)
72 | parser.add_argument('--min_len',
73 | type=int,
74 | help='The minimum sentence length'
75 | '(all shorter will be returned w/o changes)',
76 | default=3)
77 | parser.add_argument('--batch_size',
78 | type=int,
79 | help='The number of sentences in a batch.',
80 | default=128)
81 | parser.add_argument('--lowercase_tokens',
82 | type=int,
83 | help='Whether to lowercase tokens.',
84 | default=0)
85 | parser.add_argument('--transformer_model',
86 | choices=['bert', 'gpt2', 'transformerxl', 'xlnet', 'distilbert', 'roberta', 'albert'],
87 | help='Name of the transformer model.',
88 | default='roberta')
89 | parser.add_argument('--iteration_count',
90 | type=int,
91 | help='The number of iterations of the model.',
92 | default=5)
93 | parser.add_argument('--additional_confidence',
94 | type=float,
95 | help='How much probability to add to the $KEEP token.',
96 | default=0)
97 | parser.add_argument('--min_probability',
98 | type=float,
99 | default=0.0)
100 | parser.add_argument('--min_error_probability',
101 | type=float,
102 | default=0.0)
103 | parser.add_argument('--special_tokens_fix',
104 | type=int,
105 | help='Whether to fix problem with [CLS], [SEP] tokens tokenization. '
106 | 'For reproducing reported results it should be 0 for BERT/XLNet and 1 for RoBERTa.',
107 | default=1)
108 | parser.add_argument('--is_ensemble',
109 | type=int,
110 | help='Whether to do ensembling.',
111 | default=0)
112 | parser.add_argument('--weights',
113 | help='Used to calculate weighted average', nargs='+',
114 | default=None)
115 | args = parser.parse_args()
116 | main(args)
117 |
--------------------------------------------------------------------------------
/gector/datareader.py:
--------------------------------------------------------------------------------
1 | """Tweaked AllenNLP dataset reader."""
2 | import logging
3 | import re
4 | from random import random
5 | from typing import Dict, List
6 |
7 | from allennlp.common.file_utils import cached_path
8 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader
9 | from allennlp.data.fields import TextField, SequenceLabelField, MetadataField, Field
10 | from allennlp.data.instance import Instance
11 | from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
12 | from allennlp.data.tokenizers import Token
13 | from overrides import overrides
14 |
15 | from utils.helpers import SEQ_DELIMETERS, START_TOKEN
16 |
17 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name
18 |
19 |
20 | @DatasetReader.register("seq2labels_datareader")
21 | class Seq2LabelsDatasetReader(DatasetReader):
22 | """
23 | Reads instances from a pretokenised file where each line is in the following format:
24 |
25 | WORD###TAG [TAB] WORD###TAG [TAB] ..... \n
26 |
27 | and converts it into a ``Dataset`` suitable for sequence tagging. The delimiters default to
28 | ``SEQ_DELIMETERS`` from ``utils.helpers``; alternative delimiters can be specified in the constructor.
29 |
30 | Parameters
31 | ----------
32 | delimeters: ``dict``
33 | The dictionary with all delimiters.
34 | token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
35 | We use this to define the input representation for the text. See :class:`TokenIndexer`.
36 | Note that the `output` tags will always correspond to single token IDs based on how they
37 | are pre-tokenised in the data file.
38 | max_len: if set, longer sentences will be truncated to this length
39 | """
40 | # fix broken sentences mostly in Lang8
41 | BROKEN_SENTENCES_REGEXP = re.compile(r'\.[a-zA-RT-Z]')
42 |
43 | def __init__(self,
44 | token_indexers: Dict[str, TokenIndexer] = None,
45 | delimeters: dict = SEQ_DELIMETERS,
46 | skip_correct: bool = False,
47 | skip_complex: int = 0,
48 | lazy: bool = False,
49 | max_len: int = None,
50 | test_mode: bool = False,
51 | tag_strategy: str = "keep_one",
52 | tn_prob: float = 0,
53 | tp_prob: float = 0,
54 | broken_dot_strategy: str = "keep") -> None:
55 | super().__init__(lazy)
56 | self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
57 | self._delimeters = delimeters
58 | self._max_len = max_len
59 | self._skip_correct = skip_correct
60 | self._skip_complex = skip_complex
61 | self._tag_strategy = tag_strategy
62 | self._broken_dot_strategy = broken_dot_strategy
63 | self._test_mode = test_mode
64 | self._tn_prob = tn_prob
65 | self._tp_prob = tp_prob
66 |
67 | @overrides
68 | def _read(self, file_path):
69 | # if `file_path` is a URL, redirect to the cache
70 | file_path = cached_path(file_path)
71 | with open(file_path, "r") as data_file:
72 | logger.info("Reading instances from lines in file at: %s", file_path)
73 | for line in data_file:
74 | line = line.strip("\n")
75 | # skip blank and broken lines
76 | if not line or (not self._test_mode and self._broken_dot_strategy == 'skip'
77 | and self.BROKEN_SENTENCES_REGEXP.search(line) is not None):
78 | continue
79 |
80 | tokens_and_tags = [pair.rsplit(self._delimeters['labels'], 1)
81 | for pair in line.split(self._delimeters['tokens'])]
82 | try:
83 | tokens = [Token(token) for token, tag in tokens_and_tags]
84 | tags = [tag for token, tag in tokens_and_tags]
85 | except ValueError:
86 | tokens = [Token(token[0]) for token in tokens_and_tags]
87 | tags = None
88 |
89 | if tokens and tokens[0] != Token(START_TOKEN):
90 | tokens = [Token(START_TOKEN)] + tokens
91 |
92 | words = [x.text for x in tokens]
93 | if self._max_len is not None:
94 | tokens = tokens[:self._max_len]
95 | tags = None if tags is None else tags[:self._max_len]
96 | instance = self.text_to_instance(tokens, tags, words)
97 | if instance:
98 | yield instance
99 |
100 | def extract_tags(self, tags: List[str]):
101 | op_del = self._delimeters['operations']
102 |
103 | labels = [x.split(op_del) for x in tags]
104 |
105 | comlex_flag_dict = {}
106 | # get flags
107 | for i in range(5):
108 | idx = i + 1
109 | comlex_flag_dict[idx] = sum([len(x) > idx for x in labels])
110 |
111 | if self._tag_strategy == "keep_one":
112 | # keep only the first tag candidate for each token
113 | labels = [x[0] for x in labels]
114 | elif self._tag_strategy == "merge_all":
115 | # consider phrases as single words
116 | pass
117 | else:
118 | raise Exception("Incorrect tag strategy")
119 |
120 | detect_tags = ["CORRECT" if label == "$KEEP" else "INCORRECT" for label in labels]
121 | return labels, detect_tags, comlex_flag_dict
122 |
123 | def text_to_instance(self, tokens: List[Token], tags: List[str] = None,
124 | words: List[str] = None) -> Instance: # type: ignore
125 | """
126 | We take `pre-tokenized` input here, because we don't have a tokenizer in this class.
127 | """
128 | # pylint: disable=arguments-differ
129 | fields: Dict[str, Field] = {}
130 | sequence = TextField(tokens, self._token_indexers)
131 | fields["tokens"] = sequence
132 | fields["metadata"] = MetadataField({"words": words})
133 | if tags is not None:
134 | labels, detect_tags, complex_flag_dict = self.extract_tags(tags)
135 | if self._skip_complex and complex_flag_dict[self._skip_complex] > 0:
136 | return None
137 | rnd = random()
138 | # skip TN
139 | if self._skip_correct and all(x == "CORRECT" for x in detect_tags):
140 | if rnd > self._tn_prob:
141 | return None
142 | # skip TP
143 | else:
144 | if rnd > self._tp_prob:
145 | return None
146 |
147 | fields["labels"] = SequenceLabelField(labels, sequence,
148 | label_namespace="labels")
149 | fields["d_tags"] = SequenceLabelField(detect_tags, sequence,
150 | label_namespace="d_tags")
151 | return Instance(fields)
152 |
--------------------------------------------------------------------------------
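
For illustration, a minimal sketch (not part of the repo) of reading a preprocessed file with the reader above. The file path is a placeholder, and the repository's data files (e.g. the verb-form vocabulary loaded when `utils/helpers.py` is imported) are assumed to be in place. Note that `tn_prob`/`tp_prob` default to 0, so they are raised here to keep every sentence:

```python
from gector.datareader import Seq2LabelsDatasetReader

# Keep all correct (tn) and erroneous (tp) sentences instead of sampling them away.
reader = Seq2LabelsDatasetReader(tn_prob=1.0, tp_prob=1.0)

# "data/train_preprocessed.txt" is a hypothetical output of utils/preprocess_data.py.
for instance in reader.read("data/train_preprocessed.txt"):
    print(instance.fields["tokens"].tokens)   # [$START, token_1, token_2, ...]
    print(instance.fields["labels"].labels)   # one edit tag per token
    break
```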
/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 |
5 | VOCAB_DIR = Path(__file__).resolve().parent.parent / "data"
6 | PAD = "@@PADDING@@"
7 | UNK = "@@UNKNOWN@@"
8 | START_TOKEN = "$START"
9 | SEQ_DELIMETERS = {"tokens": " ",
10 | "labels": "SEPL|||SEPR",
11 | "operations": "SEPL__SEPR"}
12 |
13 |
14 | def get_verb_form_dicts():
15 | path_to_dict = os.path.join(VOCAB_DIR, "verb-form-vocab.txt")
16 | encode, decode = {}, {}
17 | with open(path_to_dict, encoding="utf-8") as f:
18 | for line in f:
19 | words, tags = line.split(":")
20 | word1, word2 = words.split("_")
21 | tag1, tag2 = tags.split("_")
22 | decode_key = f"{word1}_{tag1}_{tag2.strip()}"
23 | if decode_key not in decode:
24 | encode[words] = tags
25 | decode[decode_key] = word2
26 | return encode, decode
27 |
28 |
29 | ENCODE_VERB_DICT, DECODE_VERB_DICT = get_verb_form_dicts()
30 |
31 |
32 | def get_target_sent_by_edits(source_tokens, edits):
33 | target_tokens = source_tokens[:]
34 | shift_idx = 0
35 | for edit in edits:
36 | start, end, label, _ = edit
37 | target_pos = start + shift_idx
38 | source_token = target_tokens[target_pos] if target_pos >= 0 else ''
39 | if label == "":
40 | del target_tokens[target_pos]
41 | shift_idx -= 1
42 | elif start == end:
43 | word = label.replace("$APPEND_", "")
44 | target_tokens[target_pos: target_pos] = [word]
45 | shift_idx += 1
46 | elif label.startswith("$TRANSFORM_"):
47 | word = apply_reverse_transformation(source_token, label)
48 | if word is None:
49 | word = source_token
50 | target_tokens[target_pos] = word
51 | elif start == end - 1:
52 | word = label.replace("$REPLACE_", "")
53 | target_tokens[target_pos] = word
54 | elif label.startswith("$MERGE_"):
55 | target_tokens[target_pos + 1: target_pos + 1] = [label]
56 | shift_idx += 1
57 |
58 | return replace_merge_transforms(target_tokens)
59 |
60 |
61 | def replace_merge_transforms(tokens):
62 | if all(not x.startswith("$MERGE_") for x in tokens):
63 | return tokens
64 |
65 | target_line = " ".join(tokens)
66 | target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
67 | target_line = target_line.replace(" $MERGE_SPACE ", "")
68 | return target_line.split()
69 |
70 |
71 | def convert_using_case(token, smart_action):
72 | if not smart_action.startswith("$TRANSFORM_CASE_"):
73 | return token
74 | if smart_action.endswith("LOWER"):
75 | return token.lower()
76 | elif smart_action.endswith("UPPER"):
77 | return token.upper()
78 | elif smart_action.endswith("CAPITAL"):
79 | return token.capitalize()
80 | elif smart_action.endswith("CAPITAL_1"):
81 | return token[0] + token[1:].capitalize()
82 | elif smart_action.endswith("UPPER_-1"):
83 | return token[:-1].upper() + token[-1]
84 | else:
85 | return token
86 |
87 |
88 | def convert_using_verb(token, smart_action):
89 | key_word = "$TRANSFORM_VERB_"
90 | if not smart_action.startswith(key_word):
91 | raise Exception(f"Unknown action type {smart_action}")
92 | encoding_part = f"{token}_{smart_action[len(key_word):]}"
93 | decoded_target_word = decode_verb_form(encoding_part)
94 | return decoded_target_word
95 |
96 |
97 | def convert_using_split(token, smart_action):
98 | key_word = "$TRANSFORM_SPLIT"
99 | if not smart_action.startswith(key_word):
100 | raise Exception(f"Unknown action type {smart_action}")
101 | target_words = token.split("-")
102 | return " ".join(target_words)
103 |
104 |
105 | def convert_using_plural(token, smart_action):
106 | if smart_action.endswith("PLURAL"):
107 | return token + "s"
108 | elif smart_action.endswith("SINGULAR"):
109 | return token[:-1]
110 | else:
111 | raise Exception(f"Unknown action type {smart_action}")
112 |
113 |
114 | def apply_reverse_transformation(source_token, transform):
115 | if transform.startswith("$TRANSFORM"):
116 | # deal with equal
117 | if transform == "$KEEP":
118 | return source_token
119 | # deal with case
120 | if transform.startswith("$TRANSFORM_CASE"):
121 | return convert_using_case(source_token, transform)
122 | # deal with verb
123 | if transform.startswith("$TRANSFORM_VERB"):
124 | return convert_using_verb(source_token, transform)
125 | # deal with split
126 | if transform.startswith("$TRANSFORM_SPLIT"):
127 | return convert_using_split(source_token, transform)
128 | # deal with single/plural
129 | if transform.startswith("$TRANSFORM_AGREEMENT"):
130 | return convert_using_plural(source_token, transform)
131 | # raise an exception if no valid transform type was found
132 | raise Exception(f"Unknown action type {transform}")
133 | else:
134 | return source_token
135 |
136 |
137 | def read_parallel_lines(fn1, fn2):
138 | lines1 = read_lines(fn1, skip_strip=True)
139 | lines2 = read_lines(fn2, skip_strip=True)
140 | assert len(lines1) == len(lines2)
141 | out_lines1, out_lines2 = [], []
142 | for line1, line2 in zip(lines1, lines2):
143 | if not line1.strip() or not line2.strip():
144 | continue
145 | else:
146 | out_lines1.append(line1)
147 | out_lines2.append(line2)
148 | return out_lines1, out_lines2
149 |
150 |
151 | def read_lines(fn, skip_strip=False):
152 | if not os.path.exists(fn):
153 | return []
154 | with open(fn, 'r', encoding='utf-8') as f:
155 | lines = f.readlines()
156 | return [s.strip() for s in lines if s.strip() or skip_strip]
157 |
158 |
159 | def write_lines(fn, lines, mode='w'):
160 | if mode == 'w' and os.path.exists(fn):
161 | os.remove(fn)
162 | with open(fn, encoding='utf-8', mode=mode) as f:
163 | f.writelines(['%s\n' % s for s in lines])
164 |
165 |
166 | def decode_verb_form(original):
167 | return DECODE_VERB_DICT.get(original)
168 |
169 |
170 | def encode_verb_form(original_word, corrected_word):
171 | decoding_request = original_word + "_" + corrected_word
172 | decoding_response = ENCODE_VERB_DICT.get(decoding_request, "").strip()
173 | if original_word and decoding_response:
174 | answer = decoding_response
175 | else:
176 | answer = None
177 | return answer
178 |
179 |
180 | def get_weights_name(transformer_name, lowercase):
181 | if transformer_name == 'bert' and lowercase:
182 | return 'bert-base-uncased'
183 | if transformer_name == 'bert' and not lowercase:
184 | return 'bert-base-cased'
185 | if transformer_name == 'distilbert':
186 | if not lowercase:
187 | print('Warning! This model was trained only on uncased sentences.')
188 | return 'distilbert-base-uncased'
189 | if transformer_name == 'albert':
190 | if not lowercase:
191 | print('Warning! This model was trained only on uncased sentences.')
192 | return 'albert-base-v1'
193 | if lowercase:
194 | print('Warning! This model was trained only on cased sentences.')
195 | if transformer_name == 'roberta':
196 | return 'roberta-base'
197 | if transformer_name == 'gpt2':
198 | return 'gpt2'
199 | if transformer_name == 'transformerxl':
200 | return 'transfo-xl-wt103'
201 | if transformer_name == 'xlnet':
202 | return 'xlnet-base-cased'
203 |
--------------------------------------------------------------------------------
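
For illustration, a minimal sketch (not part of the repo) of how the edit-application helpers above behave. Edits are `(start, end, label, probability)` tuples; importing `utils.helpers` assumes `data/verb-form-vocab.txt` is present, since it is loaded at import time:

```python
from utils.helpers import get_target_sent_by_edits

source = ["the", "cat", "are", "sitting", "on", "mat"]
edits = [
    (0, 1, "$TRANSFORM_CASE_CAPITAL", None),  # "the" -> "The"
    (2, 3, "$REPLACE_is", None),              # "are" -> "is"
    (5, 5, "$APPEND_the", None),              # insert "the" before "mat"
]
print(get_target_sent_by_edits(source, edits))
# ['The', 'cat', 'is', 'sitting', 'on', 'the', 'mat']
```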
/gector/seq2labels_model.py:
--------------------------------------------------------------------------------
1 | """Basic model. Predicts tags for every token"""
2 | from typing import Dict, Optional, List, Any
3 |
4 | import numpy
5 | import torch
6 | import torch.nn.functional as F
7 | from allennlp.data import Vocabulary
8 | from allennlp.models.model import Model
9 | from allennlp.modules import TimeDistributed, TextFieldEmbedder
10 | from allennlp.nn import InitializerApplicator, RegularizerApplicator
11 | from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
12 | from allennlp.training.metrics import CategoricalAccuracy
13 | from overrides import overrides
14 | from torch.nn.modules.linear import Linear
15 |
16 |
17 | @Model.register("seq2labels")
18 | class Seq2Labels(Model):
19 | """
20 | This ``Seq2Labels`` model embeds a sequence of text with a pretrained transformer ``TextFieldEmbedder``,
21 | then predicts two tags for each token in the sequence: a correction label and an error-detection label.
22 |
23 | Parameters
24 | ----------
25 | vocab : ``Vocabulary``, required
26 | A Vocabulary, required in order to compute sizes for input/output projections.
27 | text_field_embedder : ``TextFieldEmbedder``, required
28 | Used to embed the ``tokens`` ``TextField`` we get as input to the model.
29 | predictor_dropout : ``float``, optional (default=``0.0``)
30 | Dropout applied to the embedded text before the correction-tag projection layer.
31 | labels_namespace : ``str``, optional (default=``labels``)
32 | Vocabulary namespace for the correction (edit) tags.
33 | detect_namespace : ``str``, optional (default=``d_tags``)
34 | Vocabulary namespace for the error-detection tags (CORRECT/INCORRECT).
35 | label_smoothing : ``float``, optional (default=``0.0``)
36 | Label smoothing applied to the correction-tag cross-entropy loss.
37 | confidence : ``float``, optional (default=``0.0``)
38 | Confidence bias added to the probability of the first label class
39 | (the $KEEP tag; see ``--additional_confidence`` in predict.py).
45 | verbose_metrics : ``bool``, optional (default = False)
46 | If true, metrics will be returned per label class in addition
47 | to the overall statistics.
48 | initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
49 | Used to initialize the model parameters.
50 | regularizer : ``RegularizerApplicator``, optional (default=``None``)
51 | If provided, will be used to calculate the regularization penalty during training.
52 | """
53 |
54 | def __init__(self, vocab: Vocabulary,
55 | text_field_embedder: TextFieldEmbedder,
56 | predictor_dropout=0.0,
57 | labels_namespace: str = "labels",
58 | detect_namespace: str = "d_tags",
59 | verbose_metrics: bool = False,
60 | label_smoothing: float = 0.0,
61 | confidence: float = 0.0,
62 | initializer: InitializerApplicator = InitializerApplicator(),
63 | regularizer: Optional[RegularizerApplicator] = None) -> None:
64 | super(Seq2Labels, self).__init__(vocab, regularizer)
65 |
66 | self.label_namespaces = [labels_namespace,
67 | detect_namespace]
68 | self.text_field_embedder = text_field_embedder
69 | self.num_labels_classes = self.vocab.get_vocab_size(labels_namespace)
70 | self.num_detect_classes = self.vocab.get_vocab_size(detect_namespace)
71 | self.label_smoothing = label_smoothing
72 | self.confidence = confidence
73 | self.incorr_index = self.vocab.get_token_index("INCORRECT",
74 | namespace=detect_namespace)
75 |
76 | self._verbose_metrics = verbose_metrics
77 | self.predictor_dropout = TimeDistributed(torch.nn.Dropout(predictor_dropout))
78 |
79 | self.tag_labels_projection_layer = TimeDistributed(
80 | Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_labels_classes))
81 |
82 | self.tag_detect_projection_layer = TimeDistributed(
83 | Linear(text_field_embedder._token_embedders['bert'].get_output_dim(), self.num_detect_classes))
84 |
85 | self.metrics = {"accuracy": CategoricalAccuracy()}
86 |
87 | initializer(self)
88 |
89 | @overrides
90 | def forward(self, # type: ignore
91 | tokens: Dict[str, torch.LongTensor],
92 | labels: torch.LongTensor = None,
93 | d_tags: torch.LongTensor = None,
94 | metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
95 | # pylint: disable=arguments-differ
96 | """
97 | Parameters
98 | ----------
99 | tokens : Dict[str, torch.LongTensor], required
100 | The output of ``TextField.as_array()``, which should typically be passed directly to a
101 | ``TextFieldEmbedder``. This output is a dictionary mapping keys to ``TokenIndexer``
102 | tensors. At its most basic, using a ``SingleIdTokenIndexer`` this is: ``{"tokens":
103 | Tensor(batch_size, num_tokens)}``. This dictionary will have the same keys as were used
104 | for the ``TokenIndexers`` when you created the ``TextField`` representing your
105 | sequence. The dictionary is designed to be passed directly to a ``TextFieldEmbedder``,
106 | which knows how to combine different word representations into a single vector per
107 | token in your input.
108 | labels : torch.LongTensor, optional (default = None)
109 | A torch tensor representing the sequence of integer gold class labels of shape
110 | ``(batch_size, num_tokens)``.
111 | d_tags : torch.LongTensor, optional (default = None)
112 | A torch tensor representing the sequence of integer gold class labels of shape
113 | ``(batch_size, num_tokens)``.
114 | metadata : ``List[Dict[str, Any]]``, optional, (default = None)
115 | metadata containing the original words in the sentence to be tagged under a 'words' key.
116 |
117 | Returns
118 | -------
119 | An output dictionary consisting of:
120 | logits : torch.FloatTensor
121 | A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
122 | unnormalised log probabilities of the tag classes.
123 | class_probabilities : torch.FloatTensor
124 | A tensor of shape ``(batch_size, num_tokens, tag_vocab_size)`` representing
125 | a distribution of the tag classes per word.
126 | loss : torch.FloatTensor, optional
127 | A scalar loss to be optimised.
128 |
129 | """
130 | encoded_text = self.text_field_embedder(tokens)
131 | batch_size, sequence_length, _ = encoded_text.size()
132 | mask = get_text_field_mask(tokens)
133 | logits_labels = self.tag_labels_projection_layer(self.predictor_dropout(encoded_text))
134 | logits_d = self.tag_detect_projection_layer(encoded_text)
135 |
136 | class_probabilities_labels = F.softmax(logits_labels, dim=-1).view(
137 | [batch_size, sequence_length, self.num_labels_classes])
138 | class_probabilities_d = F.softmax(logits_d, dim=-1).view(
139 | [batch_size, sequence_length, self.num_detect_classes])
140 | error_probs = class_probabilities_d[:, :, self.incorr_index] * mask
141 | incorr_prob = torch.max(error_probs, dim=-1)[0]
142 |
143 | if self.confidence > 0:
144 | probability_change = [self.confidence] + [0] * (self.num_labels_classes - 1)
145 | class_probabilities_labels += torch.cuda.FloatTensor(probability_change).repeat(
146 | (batch_size, sequence_length, 1))
147 |
148 | output_dict = {"logits_labels": logits_labels,
149 | "logits_d_tags": logits_d,
150 | "class_probabilities_labels": class_probabilities_labels,
151 | "class_probabilities_d_tags": class_probabilities_d,
152 | "max_error_probability": incorr_prob}
153 | if labels is not None and d_tags is not None:
154 | loss_labels = sequence_cross_entropy_with_logits(logits_labels, labels, mask,
155 | label_smoothing=self.label_smoothing)
156 | loss_d = sequence_cross_entropy_with_logits(logits_d, d_tags, mask)
157 | for metric in self.metrics.values():
158 | metric(logits_labels, labels, mask.float())
159 | metric(logits_d, d_tags, mask.float())
160 | output_dict["loss"] = loss_labels + loss_d
161 |
162 | if metadata is not None:
163 | output_dict["words"] = [x["words"] for x in metadata]
164 | return output_dict
165 |
166 | @overrides
167 | def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
168 | """
169 | Does a simple position-wise argmax over each token, converts indices to string labels, and
170 | adds a ``"tags"`` key to the dictionary with the result.
171 | """
172 | for label_namespace in self.label_namespaces:
173 | all_predictions = output_dict[f'class_probabilities_{label_namespace}']
174 | all_predictions = all_predictions.cpu().data.numpy()
175 | if all_predictions.ndim == 3:
176 | predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])]
177 | else:
178 | predictions_list = [all_predictions]
179 | all_tags = []
180 |
181 | for predictions in predictions_list:
182 | argmax_indices = numpy.argmax(predictions, axis=-1)
183 | tags = [self.vocab.get_token_from_index(x, namespace=label_namespace)
184 | for x in argmax_indices]
185 | all_tags.append(tags)
186 | output_dict[f'{label_namespace}'] = all_tags
187 | return output_dict
188 |
189 | @overrides
190 | def get_metrics(self, reset: bool = False) -> Dict[str, float]:
191 | metrics_to_return = {metric_name: metric.get_metric(reset) for
192 | metric_name, metric in self.metrics.items()}
193 | return metrics_to_return
194 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
--------------------------------------------------------------------------------
/gector/bert_token_embedder.py:
--------------------------------------------------------------------------------
1 | """Tweaked version of corresponding AllenNLP file"""
2 | import logging
3 | from copy import deepcopy
4 | from typing import Dict
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
9 | from allennlp.nn import util
10 | from transformers import AutoModel, PreTrainedModel
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class PretrainedBertModel:
16 | """
17 | In some instances you may want to load the same BERT model twice
18 | (e.g. to use as a token embedder and also as a pooling layer).
19 | This factory provides a cache so that you don't actually have to load the model twice.
20 | """
21 |
22 | _cache: Dict[str, PreTrainedModel] = {}
23 |
24 | @classmethod
25 | def load(cls, model_name: str, cache_model: bool = True) -> PreTrainedModel:
26 | if model_name in cls._cache:
27 | return PretrainedBertModel._cache[model_name]
28 |
29 | model = AutoModel.from_pretrained(model_name)
30 | if cache_model:
31 | cls._cache[model_name] = model
32 |
33 | return model
34 |
35 |
36 | class BertEmbedder(TokenEmbedder):
37 | """
38 | A ``TokenEmbedder`` that produces BERT embeddings for your tokens.
39 | Should be paired with a ``BertIndexer``, which produces wordpiece ids.
40 | Most likely you want to use ``PretrainedBertEmbedder``
41 | for one of the named pretrained models, not this base class.
42 | Parameters
43 | ----------
44 | bert_model: ``BertModel``
45 | The BERT model being wrapped.
46 | top_layer_only: ``bool``, optional (default = ``False``)
47 | If ``True``, then only return the top layer instead of applying the scalar mix.
48 | max_pieces : int, optional (default: 512)
49 | The BERT embedder uses positional embeddings and so has a corresponding
50 | maximum length for its input ids. Assuming the inputs are windowed
51 | and padded appropriately by this length, the embedder will split them into a
52 | large batch, feed them into BERT, and recombine the output as if it was a
53 | longer sequence.
54 | num_start_tokens : int, optional (default: 1)
55 | The number of starting special tokens input to BERT (usually 1, i.e., [CLS])
56 | num_end_tokens : int, optional (default: 1)
57 | The number of ending tokens input to BERT (usually 1, i.e., [SEP])
58 | scalar_mix_parameters: ``List[float]``, optional, (default = None)
59 | If not ``None``, use these scalar mix parameters to weight the representations
60 | produced by different layers. These mixing weights are not updated during
61 | training.
62 | """
63 |
64 | def __init__(
65 | self,
66 | bert_model: PreTrainedModel,
67 | top_layer_only: bool = False,
68 | max_pieces: int = 512,
69 | num_start_tokens: int = 1,
70 | num_end_tokens: int = 1
71 | ) -> None:
72 | super().__init__()
73 | # self.bert_model = bert_model
74 | self.bert_model = deepcopy(bert_model)
75 | self.output_dim = bert_model.config.hidden_size
76 | self.max_pieces = max_pieces
77 | self.num_start_tokens = num_start_tokens
78 | self.num_end_tokens = num_end_tokens
79 | self._scalar_mix = None
80 |
81 | def set_weights(self, freeze):
82 | for param in self.bert_model.parameters():
83 | param.requires_grad = not freeze
84 | return
85 |
86 | def get_output_dim(self) -> int:
87 | return self.output_dim
88 |
89 | def forward(
90 | self,
91 | input_ids: torch.LongTensor,
92 | offsets: torch.LongTensor = None
93 | ) -> torch.Tensor:
94 | """
95 | Parameters
96 | ----------
97 | input_ids : ``torch.LongTensor``
98 | The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
99 | offsets : ``torch.LongTensor``, optional
100 | The BERT embeddings are one per wordpiece. However it's possible/likely
101 | you might want one per original token. In that case, ``offsets``
102 | represents the indices of the desired wordpiece for each original token.
103 | Depending on how your token indexer is configured, this could be the
104 | position of the last wordpiece for each token, or it could be the position
105 | of the first wordpiece for each token.
106 | For example, if you had the sentence "Definitely not", and if the corresponding
107 | wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
108 | would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
109 | If offsets are provided, the returned tensor will contain only the wordpiece
110 | embeddings at those positions, and (in particular) will contain one embedding
111 | per token. If offsets are not provided, the entire tensor of wordpiece embeddings
112 | will be returned.
113 | """
114 |
115 | batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
116 | initial_dims = list(input_ids.shape[:-1])
117 |
118 | # The embedder may receive an input tensor that has a sequence length longer than can
119 | # be fit. In that case, we should expect the wordpiece indexer to create padded windows
120 | # of length `self.max_pieces` for us, and have them concatenated into one long sequence.
121 | # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
122 | # We can then split the sequence into sub-sequences of that length, and concatenate them
123 | # along the batch dimension so we effectively have one huge batch of partial sentences.
124 | # This can then be fed into BERT without any sentence length issues. Keep in mind
125 | # that the memory consumption can dramatically increase for large batches with extremely
126 | # long sentences.
127 | needs_split = full_seq_len > self.max_pieces
128 | last_window_size = 0
129 | if needs_split:
130 | # Split the flattened list by the window size, `max_pieces`
131 | split_input_ids = list(input_ids.split(self.max_pieces, dim=-1))
132 |
133 | # We want all sequences to be the same length, so pad the last sequence
134 | last_window_size = split_input_ids[-1].size(-1)
135 | padding_amount = self.max_pieces - last_window_size
136 | split_input_ids[-1] = F.pad(split_input_ids[-1], pad=[0, padding_amount], value=0)
137 |
138 | # Now combine the sequences along the batch dimension
139 | input_ids = torch.cat(split_input_ids, dim=0)
140 |
141 | input_mask = (input_ids != 0).long()
142 | # input_ids may have extra dimensions, so we reshape down to 2-d
143 | # before calling the BERT model and then reshape back at the end.
144 | all_encoder_layers = self.bert_model(
145 | input_ids=util.combine_initial_dims(input_ids),
146 | attention_mask=util.combine_initial_dims(input_mask),
147 | )[0]
148 | if len(all_encoder_layers[0].shape) == 3:
149 | all_encoder_layers = torch.stack(all_encoder_layers)
150 | elif len(all_encoder_layers[0].shape) == 2:
151 | all_encoder_layers = torch.unsqueeze(all_encoder_layers, dim=0)
152 |
153 | if needs_split:
154 | # First, unpack the output embeddings into one long sequence again
155 | unpacked_embeddings = torch.split(all_encoder_layers, batch_size, dim=1)
156 | unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)
157 |
158 | # Next, select indices of the sequence such that it will result in embeddings representing the original
159 | # sentence. To capture maximal context, the indices will be the middle part of each embedded window
160 | # sub-sequence (plus any leftover start and final edge windows), e.g.,
161 | # 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
162 | # "[CLS] I went to the very fine [SEP] [CLS] the very fine store to eat [SEP]"
163 | # with max_pieces = 8 should produce max context indices [2, 3, 4, 10, 11, 12] with additional start
164 | # and final windows with indices [0, 1] and [14, 15] respectively.
165 |
166 | # Find the stride as half the max pieces, ignoring the special start and end tokens
167 | # Calculate an offset to extract the centermost embeddings of each window
168 | stride = (self.max_pieces - self.num_start_tokens - self.num_end_tokens) // 2
169 | stride_offset = stride // 2 + self.num_start_tokens
170 |
171 | first_window = list(range(stride_offset))
172 |
173 | max_context_windows = [
174 | i
175 | for i in range(full_seq_len)
176 | if stride_offset - 1 < i % self.max_pieces < stride_offset + stride
177 | ]
178 |
179 | # Lookback what's left, unless it's the whole self.max_pieces window
180 | if full_seq_len % self.max_pieces == 0:
181 | lookback = self.max_pieces
182 | else:
183 | lookback = full_seq_len % self.max_pieces
184 |
185 | final_window_start = full_seq_len - lookback + stride_offset + stride
186 | final_window = list(range(final_window_start, full_seq_len))
187 |
188 | select_indices = first_window + max_context_windows + final_window
189 |
190 | initial_dims.append(len(select_indices))
191 |
192 | recombined_embeddings = unpacked_embeddings[:, :, select_indices]
193 | else:
194 | recombined_embeddings = all_encoder_layers
195 |
196 | # Recombine the outputs of all layers
197 | # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
198 | # recombined = torch.cat(combined, dim=2)
199 | input_mask = (recombined_embeddings != 0).long()
200 |
201 | if self._scalar_mix is not None:
202 | mix = self._scalar_mix(recombined_embeddings, input_mask)
203 | else:
204 | mix = recombined_embeddings[-1]
205 |
206 | # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)
207 |
208 | if offsets is None:
209 | # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
210 | dims = initial_dims if needs_split else input_ids.size()
211 | return util.uncombine_initial_dims(mix, dims)
212 | else:
213 | # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
214 | offsets2d = util.combine_initial_dims(offsets)
215 | # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
216 | range_vector = util.get_range_vector(
217 | offsets2d.size(0), device=util.get_device_of(mix)
218 | ).unsqueeze(1)
219 | # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length)
220 | selected_embeddings = mix[range_vector, offsets2d]
221 |
222 | return util.uncombine_initial_dims(selected_embeddings, offsets.size())
223 |
224 |
225 | # @TokenEmbedder.register("bert-pretrained")
226 | class PretrainedBertEmbedder(BertEmbedder):
227 |
228 | """
229 | Parameters
230 | ----------
231 | pretrained_model: ``str``
232 | Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
233 | or the path to the .tar.gz file with the model weights.
234 | If the name is a key in the list of pretrained models at
235 | https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py#L41
236 | the corresponding path will be used; otherwise it will be interpreted as a path or URL.
237 | requires_grad : ``bool``, optional (default = False)
238 | If True, compute gradient of BERT parameters for fine tuning.
239 |     top_layer_only: ``bool``, optional (default = ``False``)
240 |         If ``True``, then only return the top layer instead of applying the scalar mix.
241 |     special_tokens_fix: ``int``, optional (default = 0)
242 |         If nonzero, the token embedding matrix of the loaded model is resized to
243 |         reserve an extra row, so that an additional special token can be added to
244 |         the vocabulary.
245 | """
246 |
247 | def __init__(
248 | self,
249 | pretrained_model: str,
250 | requires_grad: bool = False,
251 | top_layer_only: bool = False,
252 | special_tokens_fix: int = 0,
253 | ) -> None:
254 | model = PretrainedBertModel.load(pretrained_model)
255 |
256 | for param in model.parameters():
257 | param.requires_grad = requires_grad
258 |
259 | super().__init__(
260 | bert_model=model,
261 | top_layer_only=top_layer_only
262 | )
263 |
264 | if special_tokens_fix:
265 | try:
266 | vocab_size = self.bert_model.embeddings.word_embeddings.num_embeddings
267 | except AttributeError:
268 | # reserve more space
269 | vocab_size = self.bert_model.word_embedding.num_embeddings + 5
270 | self.bert_model.resize_token_embeddings(vocab_size + 1)
271 |
--------------------------------------------------------------------------------
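A minimal sketch of how the embedder defined above can be instantiated on its own (illustrative only; the weights name is one of those returned by get_weights_name in gector/gec_model.py, and the flag values mirror the ones used by GecBERTModel._get_embedder):

# Illustrative sketch, not part of the repository.
from gector.bert_token_embedder import PretrainedBertEmbedder

embedder = PretrainedBertEmbedder(
    pretrained_model='bert-base-cased',  # weights name, e.g. get_weights_name('bert', False)
    requires_grad=False,                 # keep the transformer frozen
    top_layer_only=True,                 # skip the scalar mix over layers
    special_tokens_fix=1,                # reserve an extra embedding row for an added special token
)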
/gector/gec_model.py:
--------------------------------------------------------------------------------
1 | """Wrapper of AllenNLP model. Fixes errors based on model predictions"""
2 | import logging
3 | import os
4 | import sys
5 | from time import time
6 |
7 | import torch
8 | from allennlp.data.dataset import Batch
9 | from allennlp.data.fields import TextField
10 | from allennlp.data.instance import Instance
11 | from allennlp.data.tokenizers import Token
12 | from allennlp.data.vocabulary import Vocabulary
13 | from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
14 | from allennlp.nn import util
15 |
16 | from gector.bert_token_embedder import PretrainedBertEmbedder
17 | from gector.seq2labels_model import Seq2Labels
18 | from gector.wordpiece_indexer import PretrainedBertIndexer
19 | from utils.helpers import PAD, UNK, get_target_sent_by_edits
20 |
21 | logging.getLogger("werkzeug").setLevel(logging.ERROR)
22 | logger = logging.getLogger(__file__)
23 |
24 |
25 | def get_weights_name(transformer_name, lowercase):
26 | if transformer_name == 'bert' and lowercase:
27 | return 'bert-base-uncased'
28 | if transformer_name == 'bert' and not lowercase:
29 | return 'bert-base-cased'
30 | if transformer_name == 'distilbert':
31 | if not lowercase:
32 | print('Warning! This model was trained only on uncased sentences.')
33 | return 'distilbert-base-uncased'
34 | if transformer_name == 'albert':
35 | if not lowercase:
36 | print('Warning! This model was trained only on uncased sentences.')
37 | return 'albert-base-v1'
38 | if lowercase:
39 | print('Warning! This model was trained only on cased sentences.')
40 | if transformer_name == 'roberta':
41 | return 'roberta-base'
42 | if transformer_name == 'gpt2':
43 | return 'gpt2'
44 | if transformer_name == 'transformerxl':
45 | return 'transfo-xl-wt103'
46 | if transformer_name == 'xlnet':
47 | return 'xlnet-base-cased'
48 |
49 |
50 | class GecBERTModel(object):
51 | def __init__(self, vocab_path=None, model_paths=None,
52 | weigths=None,
53 | max_len=50,
54 | min_len=3,
55 | lowercase_tokens=False,
56 | log=False,
57 | iterations=3,
58 | min_probability=0.0,
59 | model_name='roberta',
60 | special_tokens_fix=1,
61 | is_ensemble=True,
62 | min_error_probability=0.0,
63 | confidence=0,
64 | resolve_cycles=False,
65 | ):
66 | self.model_weights = list(map(float, weigths)) if weigths else [1] * len(model_paths)
67 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
68 | self.max_len = max_len
69 | self.min_len = min_len
70 | self.lowercase_tokens = lowercase_tokens
71 | self.min_probability = min_probability
72 | self.min_error_probability = min_error_probability
73 | self.vocab = Vocabulary.from_files(vocab_path)
74 | self.log = log
75 | self.iterations = iterations
76 | self.confidence = confidence
77 | self.resolve_cycles = resolve_cycles
78 | # set training parameters and operations
79 |
80 | self.indexers = []
81 | self.models = []
82 | for model_path in model_paths:
83 | if is_ensemble:
84 | model_name, special_tokens_fix = self._get_model_data(model_path)
85 | weights_name = get_weights_name(model_name, lowercase_tokens)
86 | self.indexers.append(self._get_indexer(weights_name, special_tokens_fix))
87 | model = Seq2Labels(vocab=self.vocab,
88 |                                    text_field_embedder=self._get_embedder(weights_name, special_tokens_fix),
89 | confidence=self.confidence
90 | ).to(self.device)
91 | if torch.cuda.is_available():
92 | model.load_state_dict(torch.load(model_path))
93 | else:
94 | model.load_state_dict(torch.load(model_path,
95 | map_location=torch.device('cpu')))
96 | model.eval()
97 | self.models.append(model)
98 |
99 | @staticmethod
100 | def _get_model_data(model_path):
101 | model_name = model_path.split('/')[-1]
102 | tr_model, stf = model_name.split('_')[:2]
103 | return tr_model, int(stf)
104 |
105 | def _restore_model(self, input_path):
106 | if os.path.isdir(input_path):
107 | print("Model could not be restored from directory", file=sys.stderr)
108 | filenames = []
109 | else:
110 | filenames = [input_path]
111 | for model_path in filenames:
112 | try:
113 | if torch.cuda.is_available():
114 | loaded_model = torch.load(model_path)
115 | else:
116 | loaded_model = torch.load(model_path,
117 | map_location=lambda storage,
118 | loc: storage)
119 |             except Exception:
120 |                 print(f"{model_path} is not a valid model", file=sys.stderr)
121 | own_state = self.model.state_dict()
122 | for name, weights in loaded_model.items():
123 | if name not in own_state:
124 | continue
125 | try:
126 | if len(filenames) == 1:
127 | own_state[name].copy_(weights)
128 | else:
129 | own_state[name] += weights
130 | except RuntimeError:
131 | continue
132 | print("Model is restored", file=sys.stderr)
133 |
134 | def predict(self, batches):
135 | t11 = time()
136 | predictions = []
137 | for batch, model in zip(batches, self.models):
138 | batch = util.move_to_device(batch.as_tensor_dict(), 0 if torch.cuda.is_available() else -1)
139 | with torch.no_grad():
140 | prediction = model.forward(**batch)
141 | predictions.append(prediction)
142 |
143 | preds, idx, error_probs = self._convert(predictions)
144 | t55 = time()
145 | if self.log:
146 | print(f"Inference time {t55 - t11}")
147 | return preds, idx, error_probs
148 |
149 | def get_token_action(self, token, index, prob, sugg_token):
150 | """Get lost of suggested actions for token."""
151 | # cases when we don't need to do anything
152 | if prob < self.min_probability or sugg_token in [UNK, PAD, '$KEEP']:
153 | return None
154 |
155 | if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE':
156 | start_pos = index
157 | end_pos = index + 1
158 | elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
159 | start_pos = index + 1
160 | end_pos = index + 1
161 |
162 | if sugg_token == "$DELETE":
163 | sugg_token_clear = ""
164 | elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"):
165 | sugg_token_clear = sugg_token[:]
166 | else:
167 | sugg_token_clear = sugg_token[sugg_token.index('_') + 1:]
168 |
169 | return start_pos - 1, end_pos - 1, sugg_token_clear, prob
170 |
171 |     def _get_embedder(self, weights_name, special_tokens_fix):
172 |         embedders = {'bert': PretrainedBertEmbedder(
173 |             pretrained_model=weights_name,
174 | requires_grad=False,
175 | top_layer_only=True,
176 | special_tokens_fix=special_tokens_fix)
177 | }
178 | text_field_embedder = BasicTextFieldEmbedder(
179 | token_embedders=embedders,
180 | embedder_to_indexer_map={"bert": ["bert", "bert-offsets"]},
181 | allow_unmatched_keys=True)
182 | return text_field_embedder
183 |
184 | def _get_indexer(self, weights_name, special_tokens_fix):
185 | bert_token_indexer = PretrainedBertIndexer(
186 | pretrained_model=weights_name,
187 | do_lowercase=self.lowercase_tokens,
188 | max_pieces_per_token=5,
189 | use_starting_offsets=True,
190 | truncate_long_sequences=True,
191 | special_tokens_fix=special_tokens_fix,
192 | is_test=True
193 | )
194 | return {'bert': bert_token_indexer}
195 |
196 | def preprocess(self, token_batch):
197 | seq_lens = [len(sequence) for sequence in token_batch if sequence]
198 | if not seq_lens:
199 | return []
200 | max_len = min(max(seq_lens), self.max_len)
201 | batches = []
202 | for indexer in self.indexers:
203 | batch = []
204 | for sequence in token_batch:
205 | tokens = sequence[:max_len]
206 | tokens = [Token(token) for token in ['$START'] + tokens]
207 | batch.append(Instance({'tokens': TextField(tokens, indexer)}))
208 | batch = Batch(batch)
209 | batch.index_instances(self.vocab)
210 | batches.append(batch)
211 |
212 | return batches
213 |
214 | def _convert(self, data):
215 | all_class_probs = torch.zeros_like(data[0]['class_probabilities_labels'])
216 | error_probs = torch.zeros_like(data[0]['max_error_probability'])
217 | for output, weight in zip(data, self.model_weights):
218 | all_class_probs += weight * output['class_probabilities_labels'] / sum(self.model_weights)
219 | error_probs += weight * output['max_error_probability'] / sum(self.model_weights)
220 |
221 | max_vals = torch.max(all_class_probs, dim=-1)
222 | probs = max_vals[0].tolist()
223 | idx = max_vals[1].tolist()
224 | return probs, idx, error_probs.tolist()
225 |
226 | def update_final_batch(self, final_batch, pred_ids, pred_batch,
227 | prev_preds_dict):
228 | new_pred_ids = []
229 | total_updated = 0
230 | for i, orig_id in enumerate(pred_ids):
231 | orig = final_batch[orig_id]
232 | pred = pred_batch[i]
233 | prev_preds = prev_preds_dict[orig_id]
234 | if orig != pred and pred not in prev_preds:
235 | final_batch[orig_id] = pred
236 | new_pred_ids.append(orig_id)
237 | prev_preds_dict[orig_id].append(pred)
238 | total_updated += 1
239 | elif orig != pred and pred in prev_preds:
240 | # update final batch, but stop iterations
241 | final_batch[orig_id] = pred
242 | total_updated += 1
243 | else:
244 | continue
245 | return final_batch, new_pred_ids, total_updated
246 |
247 | def postprocess_batch(self, batch, all_probabilities, all_idxs,
248 | error_probs,
249 | max_len=50):
250 | all_results = []
251 | noop_index = self.vocab.get_token_index("$KEEP", "labels")
252 | for tokens, probabilities, idxs, error_prob in zip(batch,
253 | all_probabilities,
254 | all_idxs,
255 | error_probs):
256 | length = min(len(tokens), max_len)
257 | edits = []
258 |
259 |             # skip whole sentences if there are no errors
260 | if max(idxs) == 0:
261 | all_results.append(tokens)
262 | continue
263 |
264 |             # skip the whole sentence if its error probability is below the threshold
265 | if error_prob < self.min_error_probability:
266 | all_results.append(tokens)
267 | continue
268 |
269 | for i in range(length):
270 | token = tokens[i - 1] # because of START token
271 | # skip if there is no error
272 | if idxs[i] == noop_index:
273 | continue
274 |
275 | sugg_token = self.vocab.get_token_from_index(idxs[i],
276 | namespace='labels')
277 | action = self.get_token_action(token, i, probabilities[i],
278 | sugg_token)
279 | if not action:
280 | continue
281 |
282 | edits.append(action)
283 | all_results.append(get_target_sent_by_edits(tokens, edits))
284 | return all_results
285 |
286 | def handle_batch(self, full_batch):
287 | """
288 | Handle batch of requests.
289 | """
290 | final_batch = full_batch[:]
291 | batch_size = len(full_batch)
292 | prev_preds_dict = {i: [final_batch[i]] for i in range(len(final_batch))}
293 | short_ids = [i for i in range(len(full_batch))
294 | if len(full_batch[i]) < self.min_len]
295 | pred_ids = [i for i in range(len(full_batch)) if i not in short_ids]
296 | total_updates = 0
297 |
298 | for n_iter in range(self.iterations):
299 | orig_batch = [final_batch[i] for i in pred_ids]
300 |
301 | sequences = self.preprocess(orig_batch)
302 |
303 | if not sequences:
304 | break
305 | probabilities, idxs, error_probs = self.predict(sequences)
306 |
307 | pred_batch = self.postprocess_batch(orig_batch, probabilities,
308 | idxs, error_probs)
309 | if self.log:
310 | print(f"Iteration {n_iter + 1}. Predicted {round(100*len(pred_ids)/batch_size, 1)}% of sentences.")
311 |
312 | final_batch, pred_ids, cnt = \
313 | self.update_final_batch(final_batch, pred_ids, pred_batch,
314 | prev_preds_dict)
315 | total_updates += cnt
316 |
317 | if not pred_ids:
318 | break
319 |
320 | return final_batch, total_updates
321 |
--------------------------------------------------------------------------------
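A minimal inference sketch for the GecBERTModel wrapper above (illustrative only; the vocabulary path and checkpoint file name below are placeholders):

# Illustrative sketch, not part of the repository; paths are placeholders.
from gector.gec_model import GecBERTModel

model = GecBERTModel(
    vocab_path='PATH/TO/output_vocabulary',
    model_paths=['PATH/TO/roberta_gector.th'],  # hypothetical checkpoint file
    model_name='roberta',
    special_tokens_fix=1,
    is_ensemble=False,
    iterations=3,
    min_error_probability=0.0,
)

batch = [sentence.split() for sentence in ['She are my best friend .']]
corrected_batch, total_updates = model.handle_batch(batch)
print([' '.join(tokens) for tokens in corrected_batch])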
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from random import seed
4 |
5 | import torch
6 | from allennlp.data.iterators import BucketIterator
7 | from allennlp.data.vocabulary import DEFAULT_OOV_TOKEN, DEFAULT_PADDING_TOKEN
8 | from allennlp.data.vocabulary import Vocabulary
9 | from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
10 |
11 | from gector.bert_token_embedder import PretrainedBertEmbedder
12 | from gector.datareader import Seq2LabelsDatasetReader
13 | from gector.seq2labels_model import Seq2Labels
14 | from gector.trainer import Trainer
15 | from gector.wordpiece_indexer import PretrainedBertIndexer
16 | from utils.helpers import get_weights_name
17 |
18 |
19 | def fix_seed():
20 | torch.manual_seed(1)
21 | torch.backends.cudnn.enabled = False
22 | torch.backends.cudnn.benchmark = False
23 | torch.backends.cudnn.deterministic = True
24 | seed(43)
25 |
26 |
27 | def get_token_indexers(model_name, max_pieces_per_token=5, lowercase_tokens=True, special_tokens_fix=0, is_test=False):
28 | bert_token_indexer = PretrainedBertIndexer(
29 | pretrained_model=model_name,
30 | max_pieces_per_token=max_pieces_per_token,
31 | do_lowercase=lowercase_tokens,
32 | use_starting_offsets=True,
33 | special_tokens_fix=special_tokens_fix,
34 | is_test=is_test
35 | )
36 | return {'bert': bert_token_indexer}
37 |
38 |
39 | def get_token_embedders(model_name, tune_bert=False, special_tokens_fix=0):
40 |     take_grads = tune_bert > 0
41 | bert_token_emb = PretrainedBertEmbedder(
42 | pretrained_model=model_name,
43 | top_layer_only=True, requires_grad=take_grads,
44 | special_tokens_fix=special_tokens_fix)
45 |
46 | token_embedders = {'bert': bert_token_emb}
47 | embedder_to_indexer_map = {"bert": ["bert", "bert-offsets"]}
48 |
49 |     text_field_embedder = BasicTextFieldEmbedder(token_embedders=token_embedders,
50 | embedder_to_indexer_map=embedder_to_indexer_map,
51 | allow_unmatched_keys=True)
52 |     return text_field_embedder
53 |
54 |
55 | def get_data_reader(model_name, max_len, skip_correct=False, skip_complex=0,
56 | test_mode=False, tag_strategy="keep_one",
57 | broken_dot_strategy="keep", lowercase_tokens=True,
58 | max_pieces_per_token=3, tn_prob=0, tp_prob=1, special_tokens_fix=0,):
59 | token_indexers = get_token_indexers(model_name,
60 | max_pieces_per_token=max_pieces_per_token,
61 | lowercase_tokens=lowercase_tokens,
62 | special_tokens_fix=special_tokens_fix,
63 | is_test=test_mode)
64 | reader = Seq2LabelsDatasetReader(token_indexers=token_indexers,
65 | max_len=max_len,
66 | skip_correct=skip_correct,
67 | skip_complex=skip_complex,
68 | test_mode=test_mode,
69 | tag_strategy=tag_strategy,
70 | broken_dot_strategy=broken_dot_strategy,
71 | lazy=True,
72 | tn_prob=tn_prob,
73 | tp_prob=tp_prob)
74 | return reader
75 |
76 |
77 | def get_model(model_name, vocab, tune_bert=False,
78 | predictor_dropout=0,
79 | label_smoothing=0.0,
80 | confidence=0,
81 | special_tokens_fix=0):
82 | token_embs = get_token_embedders(model_name, tune_bert=tune_bert, special_tokens_fix=special_tokens_fix)
83 | model = Seq2Labels(vocab=vocab,
84 | text_field_embedder=token_embs,
85 | predictor_dropout=predictor_dropout,
86 | label_smoothing=label_smoothing,
87 | confidence=confidence)
88 | return model
89 |
90 |
91 | def main(args):
92 | fix_seed()
93 | if not os.path.exists(args.model_dir):
94 | os.mkdir(args.model_dir)
95 |
96 | weights_name = get_weights_name(args.transformer_model, args.lowercase_tokens)
97 | # read datasets
98 | reader = get_data_reader(weights_name, args.max_len, skip_correct=bool(args.skip_correct),
99 | skip_complex=args.skip_complex,
100 | test_mode=False,
101 | tag_strategy=args.tag_strategy,
102 | lowercase_tokens=args.lowercase_tokens,
103 | max_pieces_per_token=args.pieces_per_token,
104 | tn_prob=args.tn_prob,
105 | tp_prob=args.tp_prob,
106 | special_tokens_fix=args.special_tokens_fix)
107 | train_data = reader.read(args.train_set)
108 | dev_data = reader.read(args.dev_set)
109 |
110 | default_tokens = [DEFAULT_OOV_TOKEN, DEFAULT_PADDING_TOKEN]
111 | namespaces = ['labels', 'd_tags']
112 | tokens_to_add = {x: default_tokens for x in namespaces}
113 | # build vocab
114 | if args.vocab_path:
115 | vocab = Vocabulary.from_files(args.vocab_path)
116 | else:
117 | vocab = Vocabulary.from_instances(train_data,
118 | max_vocab_size={'tokens': 30000,
119 | 'labels': args.target_vocab_size,
120 | 'd_tags': 2},
121 | tokens_to_add=tokens_to_add)
122 | vocab.save_to_files(os.path.join(args.model_dir, 'vocabulary'))
123 |
124 | print("Data is loaded")
125 | model = get_model(weights_name, vocab,
126 | tune_bert=args.tune_bert,
127 | predictor_dropout=args.predictor_dropout,
128 | label_smoothing=args.label_smoothing,
129 | special_tokens_fix=args.special_tokens_fix)
130 |
131 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
132 | if torch.cuda.is_available():
133 | if torch.cuda.device_count() > 1:
134 | cuda_device = list(range(torch.cuda.device_count()))
135 | else:
136 | cuda_device = 0
137 | else:
138 | cuda_device = -1
139 |
140 | if args.pretrain:
141 | model.load_state_dict(torch.load(os.path.join(args.pretrain_folder, args.pretrain + '.th')))
142 |
143 | model = model.to(device)
144 |
145 | print("Model is set")
146 |
147 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
148 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
149 | optimizer, factor=0.1, patience=10)
150 | instances_per_epoch = None if not args.updates_per_epoch else \
151 | int(args.updates_per_epoch * args.batch_size * args.accumulation_size)
152 | iterator = BucketIterator(batch_size=args.batch_size,
153 | sorting_keys=[("tokens", "num_tokens")],
154 | biggest_batch_first=True,
155 | max_instances_in_memory=args.batch_size * 20000,
156 | instances_per_epoch=instances_per_epoch,
157 | )
158 | iterator.index_with(vocab)
159 | trainer = Trainer(model=model,
160 | optimizer=optimizer,
161 | scheduler=scheduler,
162 | iterator=iterator,
163 | train_dataset=train_data,
164 | validation_dataset=dev_data,
165 | serialization_dir=args.model_dir,
166 | patience=args.patience,
167 | num_epochs=args.n_epoch,
168 | cuda_device=cuda_device,
169 | shuffle=False,
170 | accumulated_batch_count=args.accumulation_size,
171 | cold_step_count=args.cold_steps_count,
172 | cold_lr=args.cold_lr,
173 | cuda_verbose_step=int(args.cuda_verbose_steps)
174 | if args.cuda_verbose_steps else None
175 | )
176 | print("Start training")
177 | trainer.train()
178 |
179 | # Here's how to save the model.
180 | out_model = os.path.join(args.model_dir, 'model.th')
181 | with open(out_model, 'wb') as f:
182 | torch.save(model.state_dict(), f)
183 | print("Model is dumped")
184 |
185 |
186 | if __name__ == '__main__':
187 | # read parameters
188 | parser = argparse.ArgumentParser()
189 | parser.add_argument('--train_set',
190 | help='Path to the train data', required=True)
191 | parser.add_argument('--dev_set',
192 | help='Path to the dev data', required=True)
193 | parser.add_argument('--model_dir',
194 | help='Path to the model dir', required=True)
195 | parser.add_argument('--vocab_path',
196 | help='Path to the model vocabulary directory.'
197 |                              ' If not set, the vocabulary is built from the data.',
198 | default='')
199 | parser.add_argument('--batch_size',
200 | type=int,
201 | help='The size of the batch.',
202 | default=32)
203 | parser.add_argument('--max_len',
204 | type=int,
205 | help='The max sentence length'
206 |                              ' (longer sentences will be truncated)',
207 | default=50)
208 | parser.add_argument('--target_vocab_size',
209 | type=int,
210 | help='The size of target vocabularies.',
211 | default=1000)
212 | parser.add_argument('--n_epoch',
213 | type=int,
214 |                         help='The number of epochs for training the model.',
215 | default=20)
216 | parser.add_argument('--patience',
217 | type=int,
218 |                         help='The number of epochs without improvement'
219 |                              ' on the validation set before early stopping.',
220 | default=3)
221 | parser.add_argument('--skip_correct',
222 | type=int,
223 |                         help='If set, then correct sentences will be skipped '
224 |                              'by the data reader.',
225 | default=1)
226 | parser.add_argument('--skip_complex',
227 | type=int,
228 |                         help='If set, then complex corrections will be skipped '
229 |                              'by the data reader.',
230 | choices=[0, 1, 2, 3, 4, 5],
231 | default=0)
232 | parser.add_argument('--tune_bert',
233 | type=int,
234 |                         help='If greater than 0, then fine-tune BERT.',
235 | default=1)
236 | parser.add_argument('--tag_strategy',
237 | choices=['keep_one', 'merge_all'],
238 | help='The type of the data reader behaviour.',
239 | default='keep_one')
240 | parser.add_argument('--accumulation_size',
241 | type=int,
242 |                         help='How many batches to accumulate before an optimizer step.',
243 | default=4)
244 | parser.add_argument('--lr',
245 | type=float,
246 | help='Set initial learning rate.',
247 | default=1e-5)
248 | parser.add_argument('--cold_steps_count',
249 | type=int,
250 |                         help='The number of cold steps during which only the classifier layers are trained.',
251 | default=4)
252 | parser.add_argument('--cold_lr',
253 | type=float,
254 | help='Learning rate during cold_steps.',
255 | default=1e-3)
256 | parser.add_argument('--predictor_dropout',
257 | type=float,
258 | help='The value of dropout for predictor.',
259 | default=0.0)
260 | parser.add_argument('--lowercase_tokens',
261 | type=int,
262 | help='Whether to lowercase tokens.',
263 | default=0)
264 | parser.add_argument('--pieces_per_token',
265 | type=int,
266 |                         help='The max number of wordpieces per token.',
267 | default=5)
268 | parser.add_argument('--cuda_verbose_steps',
269 | help='Number of steps after which CUDA memory information is printed. '
270 | 'Makes sense for local testing. Usually about 1000.',
271 | default=None)
272 | parser.add_argument('--label_smoothing',
273 | type=float,
274 | help='The value of parameter alpha for label smoothing.',
275 | default=0.0)
276 | parser.add_argument('--tn_prob',
277 | type=float,
278 |                         help='The probability to sample true negatives (error-free sentences) from the data.',
279 | default=0)
280 | parser.add_argument('--tp_prob',
281 | type=float,
282 |                         help='The probability to sample true positives (sentences with errors) from the data.',
283 | default=1)
284 | parser.add_argument('--updates_per_epoch',
285 | type=int,
286 |                         help='If set, then each epoch will contain exactly this number of updates.',
287 | default=0)
288 | parser.add_argument('--pretrain_folder',
289 |                         help='The path to the folder with the pretrained weights.')
290 | parser.add_argument('--pretrain',
291 |                         help='The name of the pretrained weights file (without the .th extension) inside pretrain_folder.',
292 | default='')
293 | parser.add_argument('--transformer_model',
294 | choices=['bert', 'distilbert', 'gpt2', 'roberta', 'transformerxl', 'xlnet', 'albert'],
295 | help='Name of the transformer model.',
296 | default='roberta')
297 | parser.add_argument('--special_tokens_fix',
298 | type=int,
299 |                         help='Whether to fix the problem with [CLS] and [SEP] token tokenization.',
300 | default=1)
301 |
302 | args = parser.parse_args()
303 | main(args)
304 |
--------------------------------------------------------------------------------
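An illustrative invocation of the training script above (the dataset and output paths are placeholders; every flag used here is defined in the argument parser above):

python train.py --train_set PATH/TO/train.tagged --dev_set PATH/TO/dev.tagged \
    --model_dir PATH/TO/model_dir --transformer_model roberta --special_tokens_fix 1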
/utils/preprocess_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from difflib import SequenceMatcher
4 |
5 | import Levenshtein
6 | import numpy as np
7 | from tqdm import tqdm
8 |
9 | from helpers import write_lines, read_parallel_lines, encode_verb_form, \
10 | apply_reverse_transformation, SEQ_DELIMETERS, START_TOKEN
11 |
12 |
13 | def perfect_align(t, T, insertions_allowed=0,
14 | cost_function=Levenshtein.distance):
15 |     # dp[i, j, k] is the minimal cost of matching the first `i` tokens of `t` with
16 |     # the first `j` tokens of `T`, after making `k` insertions since the last match
17 |     # of a token from `t`. In other words, t[:i] is aligned with T[:j].
18 |
19 | # Initialize with INFINITY (unknown)
20 | shape = (len(t) + 1, len(T) + 1, insertions_allowed + 1)
21 | dp = np.ones(shape, dtype=int) * int(1e9)
22 | come_from = np.ones(shape, dtype=int) * int(1e9)
23 | come_from_ins = np.ones(shape, dtype=int) * int(1e9)
24 |
25 | dp[0, 0, 0] = 0 # The only known starting point. Nothing matched to nothing.
26 | for i in range(len(t) + 1): # Go inclusive
27 | for j in range(len(T) + 1): # Go inclusive
28 | for q in range(insertions_allowed + 1): # Go inclusive
29 | if i < len(t):
30 | # Given matched sequence of t[:i] and T[:j], match token
31 | # t[i] with following tokens T[j:k].
32 | for k in range(j, len(T) + 1):
33 | transform = \
34 | apply_transformation(t[i], ' '.join(T[j:k]))
35 | if transform:
36 | cost = 0
37 | else:
38 | cost = cost_function(t[i], ' '.join(T[j:k]))
39 | current = dp[i, j, q] + cost
40 | if dp[i + 1, k, 0] > current:
41 | dp[i + 1, k, 0] = current
42 | come_from[i + 1, k, 0] = j
43 | come_from_ins[i + 1, k, 0] = q
44 | if q < insertions_allowed:
45 | # Given matched sequence of t[:i] and T[:j], create
46 | # insertion with following tokens T[j:k].
47 | for k in range(j, len(T) + 1):
48 | cost = len(' '.join(T[j:k]))
49 | current = dp[i, j, q] + cost
50 | if dp[i, k, q + 1] > current:
51 | dp[i, k, q + 1] = current
52 | come_from[i, k, q + 1] = j
53 | come_from_ins[i, k, q + 1] = q
54 |
55 | # Solution is in the dp[len(t), len(T), *]. Backtracking from there.
56 | alignment = []
57 | i = len(t)
58 | j = len(T)
59 | q = dp[i, j, :].argmin()
60 | while i > 0 or q > 0:
61 | is_insert = (come_from_ins[i, j, q] != q) and (q != 0)
62 | j, k, q = come_from[i, j, q], j, come_from_ins[i, j, q]
63 | if not is_insert:
64 | i -= 1
65 |
66 | if is_insert:
67 | alignment.append(['INSERT', T[j:k], (i, i)])
68 | else:
69 | alignment.append([f'REPLACE_{t[i]}', T[j:k], (i, i + 1)])
70 |
71 | assert j == 0
72 |
73 | return dp[len(t), len(T)].min(), list(reversed(alignment))
74 |
75 |
76 | def _split(token):
77 | if not token:
78 | return []
79 | parts = token.split()
80 | return parts or [token]
81 |
82 |
83 | def apply_merge_transformation(source_tokens, target_words, shift_idx):
84 | edits = []
85 | if len(source_tokens) > 1 and len(target_words) == 1:
86 | # check merge
87 | transform = check_merge(source_tokens, target_words)
88 | if transform:
89 | for i in range(len(source_tokens) - 1):
90 | edits.append([(shift_idx + i, shift_idx + i + 1), transform])
91 | return edits
92 |
93 | if len(source_tokens) == len(target_words) == 2:
94 | # check swap
95 | transform = check_swap(source_tokens, target_words)
96 | if transform:
97 | edits.append([(shift_idx, shift_idx + 1), transform])
98 | return edits
99 |
100 |
101 | def is_sent_ok(sent, delimeters=SEQ_DELIMETERS):
102 | for del_val in delimeters.values():
103 | if del_val in sent and del_val != " ":
104 | return False
105 | return True
106 |
107 |
108 | def check_casetype(source_token, target_token):
109 | if source_token.lower() != target_token.lower():
110 | return None
111 | if source_token.lower() == target_token:
112 | return "$TRANSFORM_CASE_LOWER"
113 | elif source_token.capitalize() == target_token:
114 | return "$TRANSFORM_CASE_CAPITAL"
115 | elif source_token.upper() == target_token:
116 | return "$TRANSFORM_CASE_UPPER"
117 | elif source_token[1:].capitalize() == target_token[1:] and source_token[0] == target_token[0]:
118 | return "$TRANSFORM_CASE_CAPITAL_1"
119 | elif source_token[:-1].upper() == target_token[:-1] and source_token[-1] == target_token[-1]:
120 | return "$TRANSFORM_CASE_UPPER_-1"
121 | else:
122 | return None
123 |
124 |
125 | def check_equal(source_token, target_token):
126 | if source_token == target_token:
127 | return "$KEEP"
128 | else:
129 | return None
130 |
131 |
132 | def check_split(source_token, target_tokens):
133 | if source_token.split("-") == target_tokens:
134 | return "$TRANSFORM_SPLIT_HYPHEN"
135 | else:
136 | return None
137 |
138 |
139 | def check_merge(source_tokens, target_tokens):
140 | if "".join(source_tokens) == "".join(target_tokens):
141 | return "$MERGE_SPACE"
142 | elif "-".join(source_tokens) == "-".join(target_tokens):
143 | return "$MERGE_HYPHEN"
144 | else:
145 | return None
146 |
147 |
148 | def check_swap(source_tokens, target_tokens):
149 |     if source_tokens == list(reversed(target_tokens)):
150 | return "$MERGE_SWAP"
151 | else:
152 | return None
153 |
154 |
155 | def check_plural(source_token, target_token):
156 | if source_token.endswith("s") and source_token[:-1] == target_token:
157 | return "$TRANSFORM_AGREEMENT_SINGULAR"
158 | elif target_token.endswith("s") and source_token == target_token[:-1]:
159 | return "$TRANSFORM_AGREEMENT_PLURAL"
160 | else:
161 | return None
162 |
163 |
164 | def check_verb(source_token, target_token):
165 | encoding = encode_verb_form(source_token, target_token)
166 | if encoding:
167 | return f"$TRANSFORM_VERB_{encoding}"
168 | else:
169 | return None
170 |
171 |
172 | def apply_transformation(source_token, target_token):
173 | target_tokens = target_token.split()
174 | if len(target_tokens) > 1:
175 | # check split
176 | transform = check_split(source_token, target_tokens)
177 | if transform:
178 | return transform
179 | checks = [check_equal, check_casetype, check_verb, check_plural]
180 | for check in checks:
181 | transform = check(source_token, target_token)
182 | if transform:
183 | return transform
184 | return None
185 |
186 |
187 | def align_sequences(source_sent, target_sent):
188 | # check if sent is OK
189 | if not is_sent_ok(source_sent) or not is_sent_ok(target_sent):
190 | return None
191 | source_tokens = source_sent.split()
192 | target_tokens = target_sent.split()
193 | matcher = SequenceMatcher(None, source_tokens, target_tokens)
194 | diffs = list(matcher.get_opcodes())
195 | all_edits = []
196 | for diff in diffs:
197 | tag, i1, i2, j1, j2 = diff
198 | source_part = _split(" ".join(source_tokens[i1:i2]))
199 | target_part = _split(" ".join(target_tokens[j1:j2]))
200 | if tag == 'equal':
201 | continue
202 | elif tag == 'delete':
203 |             # delete all words separately
204 | for j in range(i2 - i1):
205 | edit = [(i1 + j, i1 + j + 1), '$DELETE']
206 | all_edits.append(edit)
207 | elif tag == 'insert':
208 | # append to the previous word
209 | for target_token in target_part:
210 | edit = ((i1 - 1, i1), f"$APPEND_{target_token}")
211 | all_edits.append(edit)
212 | else:
213 | # check merge first of all
214 | edits = apply_merge_transformation(source_part, target_part,
215 | shift_idx=i1)
216 | if edits:
217 | all_edits.extend(edits)
218 | continue
219 |
220 |             # normalize alignments if needed (make them singletons)
221 | _, alignments = perfect_align(source_part, target_part,
222 | insertions_allowed=0)
223 | for alignment in alignments:
224 | new_shift = alignment[2][0]
225 | edits = convert_alignments_into_edits(alignment,
226 | shift_idx=i1 + new_shift)
227 | all_edits.extend(edits)
228 |
229 | # get labels
230 | labels = convert_edits_into_labels(source_tokens, all_edits)
231 | # match tags to source tokens
232 | sent_with_tags = add_labels_to_the_tokens(source_tokens, labels)
233 | return sent_with_tags
234 |
235 |
236 | def convert_edits_into_labels(source_tokens, all_edits):
237 | # make sure that edits are flat
238 | flat_edits = []
239 | for edit in all_edits:
240 | (start, end), edit_operations = edit
241 | if isinstance(edit_operations, list):
242 | for operation in edit_operations:
243 | new_edit = [(start, end), operation]
244 | flat_edits.append(new_edit)
245 | elif isinstance(edit_operations, str):
246 | flat_edits.append(edit)
247 | else:
248 | raise Exception("Unknown operation type")
249 | all_edits = flat_edits[:]
250 | labels = []
251 | total_labels = len(source_tokens) + 1
252 | if not all_edits:
253 | labels = [["$KEEP"] for x in range(total_labels)]
254 | else:
255 | for i in range(total_labels):
256 | edit_operations = [x[1] for x in all_edits if x[0][0] == i - 1
257 | and x[0][1] == i]
258 | if not edit_operations:
259 | labels.append(["$KEEP"])
260 | else:
261 | labels.append(edit_operations)
262 | return labels
263 |
264 |
265 | def convert_alignments_into_edits(alignment, shift_idx):
266 | edits = []
267 | action, target_tokens, new_idx = alignment
268 | source_token = action.replace("REPLACE_", "")
269 |
270 | # check if delete
271 | if not target_tokens:
272 | edit = [(shift_idx, 1 + shift_idx), "$DELETE"]
273 | return [edit]
274 |
275 | # check splits
276 | for i in range(1, len(target_tokens)):
277 | target_token = " ".join(target_tokens[:i + 1])
278 | transform = apply_transformation(source_token, target_token)
279 | if transform:
280 | edit = [(shift_idx, shift_idx + 1), transform]
281 | edits.append(edit)
282 | target_tokens = target_tokens[i + 1:]
283 | for target in target_tokens:
284 | edits.append([(shift_idx, shift_idx + 1), f"$APPEND_{target}"])
285 | return edits
286 |
287 | transform_costs = []
288 | transforms = []
289 | for target_token in target_tokens:
290 | transform = apply_transformation(source_token, target_token)
291 | if transform:
292 | cost = 0
293 | transforms.append(transform)
294 | else:
295 | cost = Levenshtein.distance(source_token, target_token)
296 | transforms.append(None)
297 | transform_costs.append(cost)
298 | min_cost_idx = transform_costs.index(min(transform_costs))
299 | # append to the previous word
300 | for i in range(0, min_cost_idx):
301 | target = target_tokens[i]
302 | edit = [(shift_idx - 1, shift_idx), f"$APPEND_{target}"]
303 | edits.append(edit)
304 | # replace/transform target word
305 | transform = transforms[min_cost_idx]
306 | target = transform if transform is not None \
307 | else f"$REPLACE_{target_tokens[min_cost_idx]}"
308 | edit = [(shift_idx, 1 + shift_idx), target]
309 | edits.append(edit)
310 | # append to this word
311 | for i in range(min_cost_idx + 1, len(target_tokens)):
312 | target = target_tokens[i]
313 | edit = [(shift_idx, 1 + shift_idx), f"$APPEND_{target}"]
314 | edits.append(edit)
315 | return edits
316 |
317 |
318 | def add_labels_to_the_tokens(source_tokens, labels, delimeters=SEQ_DELIMETERS):
319 | tokens_with_all_tags = []
320 | source_tokens_with_start = [START_TOKEN] + source_tokens
321 | for token, label_list in zip(source_tokens_with_start, labels):
322 | all_tags = delimeters['operations'].join(label_list)
323 | comb_record = token + delimeters['labels'] + all_tags
324 | tokens_with_all_tags.append(comb_record)
325 | return delimeters['tokens'].join(tokens_with_all_tags)
326 |
327 |
328 | def convert_data_from_raw_files(source_file, target_file, output_file, chunk_size):
329 | tagged = []
330 | source_data, target_data = read_parallel_lines(source_file, target_file)
331 | print(f"The size of raw dataset is {len(source_data)}")
332 | cnt_total, cnt_all, cnt_tp = 0, 0, 0
333 | for source_sent, target_sent in tqdm(zip(source_data, target_data)):
334 | try:
335 | aligned_sent = align_sequences(source_sent, target_sent)
336 | except Exception:
337 |             aligned_sent = align_sequences(source_sent, target_sent)  # retry so one can step in with a debugger
338 | if source_sent != target_sent:
339 | cnt_tp += 1
340 | alignments = [aligned_sent]
341 | cnt_all += len(alignments)
342 | try:
343 | check_sent = convert_tagged_line(aligned_sent)
344 | except Exception:
345 | # debug mode
346 | aligned_sent = align_sequences(source_sent, target_sent)
347 | check_sent = convert_tagged_line(aligned_sent)
348 |
349 | if "".join(check_sent.split()) != "".join(
350 | target_sent.split()):
351 | # do it again for debugging
352 | aligned_sent = align_sequences(source_sent, target_sent)
353 | check_sent = convert_tagged_line(aligned_sent)
354 | print(f"Incorrect pair: \n{target_sent}\n{check_sent}")
355 | continue
356 | if alignments:
357 | cnt_total += len(alignments)
358 | tagged.extend(alignments)
359 | if len(tagged) > chunk_size:
360 | write_lines(output_file, tagged, mode='a')
361 | tagged = []
362 |
363 | print(f"Overall extracted {cnt_total}. "
364 | f"Original TP {cnt_tp}."
365 | f" Original TN {cnt_all - cnt_tp}")
366 | if tagged:
367 | write_lines(output_file, tagged, 'a')
368 |
369 |
370 | def convert_labels_into_edits(labels):
371 | all_edits = []
372 | for i, label_list in enumerate(labels):
373 | if label_list == ["$KEEP"]:
374 | continue
375 | else:
376 | edit = [(i - 1, i), label_list]
377 | all_edits.append(edit)
378 | return all_edits
379 |
380 |
381 | def get_target_sent_by_levels(source_tokens, labels):
382 | relevant_edits = convert_labels_into_edits(labels)
383 | target_tokens = source_tokens[:]
384 | leveled_target_tokens = {}
385 | if not relevant_edits:
386 | target_sentence = " ".join(target_tokens)
387 | return leveled_target_tokens, target_sentence
388 | max_level = max([len(x[1]) for x in relevant_edits])
389 | for level in range(max_level):
390 | rest_edits = []
391 | shift_idx = 0
392 | for edits in relevant_edits:
393 | (start, end), label_list = edits
394 | label = label_list[0]
395 | target_pos = start + shift_idx
396 | source_token = target_tokens[target_pos] if target_pos >= 0 else START_TOKEN
397 | if label == "$DELETE":
398 | del target_tokens[target_pos]
399 | shift_idx -= 1
400 | elif label.startswith("$APPEND_"):
401 | word = label.replace("$APPEND_", "")
402 | target_tokens[target_pos + 1: target_pos + 1] = [word]
403 | shift_idx += 1
404 | elif label.startswith("$REPLACE_"):
405 | word = label.replace("$REPLACE_", "")
406 | target_tokens[target_pos] = word
407 | elif label.startswith("$TRANSFORM"):
408 | word = apply_reverse_transformation(source_token, label)
409 | if word is None:
410 | word = source_token
411 | target_tokens[target_pos] = word
412 | elif label.startswith("$MERGE_"):
413 | # apply merge only on last stage
414 | if level == (max_level - 1):
415 | target_tokens[target_pos + 1: target_pos + 1] = [label]
416 | shift_idx += 1
417 | else:
418 | rest_edit = [(start + shift_idx, end + shift_idx), [label]]
419 | rest_edits.append(rest_edit)
420 | rest_labels = label_list[1:]
421 | if rest_labels:
422 | rest_edit = [(start + shift_idx, end + shift_idx), rest_labels]
423 | rest_edits.append(rest_edit)
424 |
425 | leveled_tokens = target_tokens[:]
426 | # update next step
427 | relevant_edits = rest_edits[:]
428 | if level == (max_level - 1):
429 | leveled_tokens = replace_merge_transforms(leveled_tokens)
430 | leveled_labels = convert_edits_into_labels(leveled_tokens,
431 | relevant_edits)
432 | leveled_target_tokens[level + 1] = {"tokens": leveled_tokens,
433 | "labels": leveled_labels}
434 |
435 | target_sentence = " ".join(leveled_target_tokens[max_level]["tokens"])
436 | return leveled_target_tokens, target_sentence
437 |
438 |
439 | def replace_merge_transforms(tokens):
440 | if all(not x.startswith("$MERGE_") for x in tokens):
441 | return tokens
442 | target_tokens = tokens[:]
443 | allowed_range = (1, len(tokens) - 1)
444 | for i in range(len(tokens)):
445 | target_token = tokens[i]
446 | if target_token.startswith("$MERGE"):
447 | if target_token.startswith("$MERGE_SWAP") and i in allowed_range:
448 | target_tokens[i - 1] = tokens[i + 1]
449 | target_tokens[i + 1] = tokens[i - 1]
450 | target_tokens[i: i + 1] = []
451 | target_line = " ".join(target_tokens)
452 | target_line = target_line.replace(" $MERGE_HYPHEN ", "-")
453 | target_line = target_line.replace(" $MERGE_SPACE ", "")
454 | return target_line.split()
455 |
456 |
457 | def convert_tagged_line(line, delimeters=SEQ_DELIMETERS):
458 | label_del = delimeters['labels']
459 | source_tokens = [x.split(label_del)[0]
460 | for x in line.split(delimeters['tokens'])][1:]
461 | labels = [x.split(label_del)[1].split(delimeters['operations'])
462 | for x in line.split(delimeters['tokens'])]
463 | assert len(source_tokens) + 1 == len(labels)
464 | levels_dict, target_line = get_target_sent_by_levels(source_tokens, labels)
465 | return target_line
466 |
467 |
468 | def main(args):
469 | convert_data_from_raw_files(args.source, args.target, args.output_file, args.chunk_size)
470 |
471 |
472 | if __name__ == '__main__':
473 | parser = argparse.ArgumentParser()
474 | parser.add_argument('-s', '--source',
475 | help='Path to the source file',
476 | required=True)
477 | parser.add_argument('-t', '--target',
478 | help='Path to the target file',
479 | required=True)
480 | parser.add_argument('-o', '--output_file',
481 | help='Path to the output file',
482 | required=True)
483 | parser.add_argument('--chunk_size',
484 | type=int,
485 |                         help='Write the output in chunks of this many lines.',
486 | default=1000000)
487 | args = parser.parse_args()
488 | main(args)
489 |
--------------------------------------------------------------------------------
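A minimal sketch of the tags produced by the aligner above for simple edits (illustrative only; the actual output joins each source token with its tag using the delimiters from SEQ_DELIMETERS in utils/helpers.py, which are not reproduced here):

# Illustrative sketch, not part of the repository; run from within utils/ so that
# the `from helpers import ...` statement in preprocess_data.py resolves.
from preprocess_data import align_sequences

# Case change: "paris" -> "Paris" is caught by check_casetype and tagged
# $TRANSFORM_CASE_CAPITAL; every other token (and the $START slot) gets $KEEP.
tagged = align_sequences('the paris trip was great .',
                         'the Paris trip was great .')

# Insertion: the extra target word "back" is attached to the preceding source
# word ("went") as $APPEND_back; a deleted source word would be tagged $DELETE.
tagged = align_sequences('I went home .', 'I went back home .')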
/gector/wordpiece_indexer.py:
--------------------------------------------------------------------------------
1 | """Tweaked version of corresponding AllenNLP file"""
2 | import logging
3 | from collections import defaultdict
4 | from typing import Dict, List, Callable
5 |
6 | from allennlp.common.util import pad_sequence_to_length
7 | from allennlp.data.token_indexers.token_indexer import TokenIndexer
8 | from allennlp.data.tokenizers.token import Token
9 | from allennlp.data.vocabulary import Vocabulary
10 | from overrides import overrides
11 | from transformers import AutoTokenizer
12 |
13 | from utils.helpers import START_TOKEN
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 | # TODO(joelgrus): Figure out how to generate token_type_ids out of this token indexer.
18 |
19 | # This is the default list of tokens that should not be lowercased.
20 | _NEVER_LOWERCASE = ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
21 |
22 |
23 | class WordpieceIndexer(TokenIndexer[int]):
24 | """
25 | A token indexer that does the wordpiece-tokenization (e.g. for BERT embeddings).
26 | If you are using one of the pretrained BERT models, you'll want to use the ``PretrainedBertIndexer``
27 | subclass rather than this base class.
28 |
29 | Parameters
30 | ----------
31 | vocab : ``Dict[str, int]``
32 | The mapping {wordpiece -> id}. Note this is not an AllenNLP ``Vocabulary``.
33 | wordpiece_tokenizer : ``Callable[[str], List[str]]``
34 | A function that does the actual tokenization.
35 | namespace : str, optional (default: "wordpiece")
36 | The namespace in the AllenNLP ``Vocabulary`` into which the wordpieces
37 | will be loaded.
38 | use_starting_offsets : bool, optional (default: False)
39 | By default, the "offsets" created by the token indexer correspond to the
40 | last wordpiece in each word. If ``use_starting_offsets`` is specified,
41 | they will instead correspond to the first wordpiece in each word.
42 | max_pieces : int, optional (default: 512)
43 | The BERT embedder uses positional embeddings and so has a corresponding
44 | maximum length for its input ids. Any inputs longer than this will
45 | either be truncated (default), or be split apart and batched using a
46 | sliding window.
47 | do_lowercase : ``bool``, optional (default=``False``)
48 | Should we lowercase the provided tokens before getting the indices?
49 | You would need to do this if you are using an -uncased BERT model
50 | but your DatasetReader is not lowercasing tokens (which might be the
51 | case if you're also using other embeddings based on cased tokens).
52 | never_lowercase: ``List[str]``, optional
53 | Tokens that should never be lowercased. Default is
54 | ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'].
55 | start_tokens : ``List[str]``, optional (default=``None``)
56 | These are prepended to the tokens provided to ``tokens_to_indices``.
57 | end_tokens : ``List[str]``, optional (default=``None``)
58 | These are appended to the tokens provided to ``tokens_to_indices``.
59 |     bpe_ranks : ``Dict``
60 |         BPE merge ranks for byte-level BPE models (e.g. RoBERTa, GPT-2); if empty, ``wordpiece_tokenizer`` is used instead.
61 | truncate_long_sequences : ``bool``, optional (default=``True``)
62 | By default, long sequences will be truncated to the maximum sequence
63 | length. Otherwise, they will be split apart and batched using a
64 | sliding window.
65 | token_min_padding_length : ``int``, optional (default=``0``)
66 | See :class:`TokenIndexer`.
67 | """
68 |
69 | def __init__(self,
70 | vocab: Dict[str, int],
71 | bpe_ranks: Dict,
72 | byte_encoder: Dict,
73 | wordpiece_tokenizer: Callable[[str], List[str]],
74 | namespace: str = "wordpiece",
75 | use_starting_offsets: bool = False,
76 | max_pieces: int = 512,
77 | max_pieces_per_token: int = 3,
78 | is_test=False,
79 | do_lowercase: bool = False,
80 | never_lowercase: List[str] = None,
81 | start_tokens: List[str] = None,
82 | end_tokens: List[str] = None,
83 | truncate_long_sequences: bool = True,
84 | token_min_padding_length: int = 0) -> None:
85 | super().__init__(token_min_padding_length)
86 | self.vocab = vocab
87 |
88 | # The BERT code itself does a two-step tokenization:
89 | # sentence -> [words], and then word -> [wordpieces]
90 | # In AllenNLP, the first step is implemented as the ``BertBasicWordSplitter``,
91 | # and this token indexer handles the second.
92 | self.wordpiece_tokenizer = wordpiece_tokenizer
93 | self.max_pieces_per_token = max_pieces_per_token
94 | self._namespace = namespace
95 | self._added_to_vocabulary = False
96 | self.max_pieces = max_pieces
97 | self.use_starting_offsets = use_starting_offsets
98 | self._do_lowercase = do_lowercase
99 | self._truncate_long_sequences = truncate_long_sequences
100 | self.max_pieces_per_sentence = 80
101 | self.is_test = is_test
102 | self.cache = {}
103 | self.bpe_ranks = bpe_ranks
104 | self.byte_encoder = byte_encoder
105 |
106 | if self.is_test:
107 | self.max_pieces_per_token = None
108 |
109 | if never_lowercase is None:
110 | # Use the defaults
111 | self._never_lowercase = set(_NEVER_LOWERCASE)
112 | else:
113 | self._never_lowercase = set(never_lowercase)
114 |
115 | # Convert the start_tokens and end_tokens to wordpiece_ids
116 | self._start_piece_ids = [vocab[wordpiece]
117 | for token in (start_tokens or [])
118 | for wordpiece in wordpiece_tokenizer(token)]
119 | self._end_piece_ids = [vocab[wordpiece]
120 | for token in (end_tokens or [])
121 | for wordpiece in wordpiece_tokenizer(token)]
122 |
123 | @overrides
124 | def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
125 | # If we only use pretrained models, we don't need to do anything here.
126 | pass
127 |
128 | def _add_encoding_to_vocabulary(self, vocabulary: Vocabulary) -> None:
129 | # pylint: disable=protected-access
130 | for word, idx in self.vocab.items():
131 | vocabulary._token_to_index[self._namespace][word] = idx
132 | vocabulary._index_to_token[self._namespace][idx] = word
133 |
134 | def get_pairs(self, word):
135 | """Return set of symbol pairs in a word.
136 |
137 | Word is represented as tuple of symbols (symbols being variable-length strings).
138 | """
139 | pairs = set()
140 | prev_char = word[0]
141 | for char in word[1:]:
142 | pairs.add((prev_char, char))
143 | prev_char = char
144 | return pairs
145 |
146 | def bpe(self, token):
147 | if token in self.cache:
148 | return self.cache[token]
149 | word = tuple(token)
150 | pairs = self.get_pairs(word)
151 |
152 | if not pairs:
153 | return token
154 |
155 | while True:
156 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair,
157 | float(
158 | 'inf')))
159 | if bigram not in self.bpe_ranks:
160 | break
161 | first, second = bigram
162 | new_word = []
163 | i = 0
164 | while i < len(word):
165 | try:
166 | j = word.index(first, i)
167 | new_word.extend(word[i:j])
168 | i = j
169 |                 except ValueError:
170 | new_word.extend(word[i:])
171 | break
172 |
173 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
174 | new_word.append(first + second)
175 | i += 2
176 | else:
177 | new_word.append(word[i])
178 | i += 1
179 | new_word = tuple(new_word)
180 | word = new_word
181 | if len(word) == 1:
182 | break
183 | else:
184 | pairs = self.get_pairs(word)
185 | word = ' '.join(word)
186 | self.cache[token] = word
187 | return word
188 |
189 | def bpe_tokenize(self, text):
190 | """ Tokenize a string."""
191 | bpe_tokens = []
192 | for token in text.split():
193 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
194 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
195 | return bpe_tokens
196 |
197 | @overrides
198 | def tokens_to_indices(self,
199 | tokens: List[Token],
200 | vocabulary: Vocabulary,
201 | index_name: str) -> Dict[str, List[int]]:
202 | if not self._added_to_vocabulary:
203 | self._add_encoding_to_vocabulary(vocabulary)
204 | self._added_to_vocabulary = True
205 |
206 | # This lowercases tokens if necessary
207 | text = (token.text.lower()
208 | if self._do_lowercase and token.text not in self._never_lowercase
209 | else token.text
210 | for token in tokens)
211 |
212 | # Obtain a nested sequence of wordpieces, each represented by a list of wordpiece ids
213 | token_wordpiece_ids = []
214 | for token in text:
215 | if self.bpe_ranks != {}:
216 | wps = self.bpe_tokenize(token)
217 | else:
218 | wps = self.wordpiece_tokenizer(token)
219 | limited_wps = [self.vocab[wordpiece] for wordpiece in wps][:self.max_pieces_per_token]
220 | token_wordpiece_ids.append(limited_wps)
221 |
222 | # Flattened list of wordpieces. In the end, the output of the model (e.g., BERT) should
223 | # have a sequence length equal to the length of this list. However, it will first be split into
224 | # chunks of length `self.max_pieces` so that they can be fit through the model. After packing
225 | # and passing through the model, it should be unpacked to represent the wordpieces in this list.
226 | flat_wordpiece_ids = [wordpiece for token in token_wordpiece_ids for wordpiece in token]
227 |
228 | # reduce max_pieces_per_token if piece length of sentence is bigger than max_pieces_per_sentence
229 | # helps to fix CUDA out of memory errors meanwhile increasing batch size
230 | while not self.is_test and len(flat_wordpiece_ids) > \
231 | self.max_pieces_per_sentence - len(self._start_piece_ids) - len(self._end_piece_ids):
232 | max_pieces = max([len(row) for row in token_wordpiece_ids])
233 | token_wordpiece_ids = [row[:max_pieces - 1] for row in token_wordpiece_ids]
234 | flat_wordpiece_ids = [wordpiece for token in token_wordpiece_ids for wordpiece in token]
235 |
236 | # The code below will (possibly) pack the wordpiece sequence into multiple sub-sequences by using a sliding
237 | # window `window_length` that overlaps with previous windows according to the `stride`. Suppose we have
238 | # the following sentence: "I went to the store to buy some milk". Then a sliding window of length 4 and
239 | # stride of length 2 will split them up into:
240 |
241 | # "[I went to the] [to the store to] [store to buy some] [buy some milk [PAD]]".
242 |
243 | # This is to ensure that the model has context of as much of the sentence as possible to get accurate
244 | # embeddings. Finally, the sequences will be padded with any start/end piece ids, e.g.,
245 |
246 | # "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ...".
247 |
248 | # The embedder should then be able to split this token sequence by the window length,
249 | # pass them through the model, and recombine them.
250 |
251 | # Specify the stride to be half of `self.max_pieces`, minus any additional start/end wordpieces
252 | window_length = self.max_pieces - len(self._start_piece_ids) - len(self._end_piece_ids)
253 | stride = window_length // 2
254 |
255 | # offsets[i] will give us the index into wordpiece_ids
256 | # for the wordpiece "corresponding to" the i-th input token.
257 | offsets = []
258 |
259 | # If we're using initial offsets, we want to start at offset = len(text_tokens)
260 | # so that the first offset is the index of the first wordpiece of tokens[0].
261 | # Otherwise, we want to start at len(text_tokens) - 1, so that the "previous"
262 | # offset is the last wordpiece of "tokens[-1]".
263 | offset = len(self._start_piece_ids) if self.use_starting_offsets else len(self._start_piece_ids) - 1
264 |
265 | for token in token_wordpiece_ids:
266 | # Truncate the sequence if specified, which depends on where the offsets are
267 | next_offset = 1 if self.use_starting_offsets else 0
268 | if self._truncate_long_sequences and offset >= window_length + next_offset:
269 | break
270 |
271 | # For initial offsets, the current value of ``offset`` is the start of
272 | # the current wordpiece, so add it to ``offsets`` and then increment it.
273 | if self.use_starting_offsets:
274 | offsets.append(offset)
275 | offset += len(token)
276 | # For final offsets, the current value of ``offset`` is the end of
277 | # the previous wordpiece, so increment it and then add it to ``offsets``.
278 | else:
279 | offset += len(token)
280 | offsets.append(offset)
281 |
282 | if len(flat_wordpiece_ids) <= window_length:
283 | # If all the wordpieces fit, then we don't need to do anything special
284 | wordpiece_windows = [self._add_start_and_end(flat_wordpiece_ids)]
285 | elif self._truncate_long_sequences:
286 |             logger.warning("Too many wordpieces, truncating sequence. If you would like a sliding window, set "
287 | "`truncate_long_sequences` to False %s", str([token.text for token in tokens]))
288 | wordpiece_windows = [self._add_start_and_end(flat_wordpiece_ids[:window_length])]
289 | else:
290 | # Create a sliding window of wordpieces of length `max_pieces` that advances by `stride` steps and
291 | # add start/end wordpieces to each window
292 | # TODO: this currently does not respect word boundaries, so words may be cut in half between windows
293 | # However, this would increase complexity, as sequences would need to be padded/unpadded in the middle
294 | wordpiece_windows = [self._add_start_and_end(flat_wordpiece_ids[i:i + window_length])
295 | for i in range(0, len(flat_wordpiece_ids), stride)]
296 |
297 | # Check for overlap in the last window. Throw it away if it is redundant.
298 | last_window = wordpiece_windows[-1][1:]
299 | penultimate_window = wordpiece_windows[-2]
300 | if last_window == penultimate_window[-len(last_window):]:
301 | wordpiece_windows = wordpiece_windows[:-1]
302 |
303 | # Flatten the wordpiece windows
304 | wordpiece_ids = [wordpiece for sequence in wordpiece_windows for wordpiece in sequence]
305 |
306 | # Our mask should correspond to the original tokens,
307 | # because calling util.get_text_field_mask on the
308 | # "wordpiece_id" tokens will produce the wrong shape.
309 | # However, because of the max_pieces constraint, we may
310 | # have truncated the wordpieces; accordingly, we want the mask
311 | # to correspond to the remaining tokens after truncation, which
312 | # is captured by the offsets.
313 | mask = [1 for _ in offsets]
314 |
315 | return {index_name: wordpiece_ids,
316 | f"{index_name}-offsets": offsets,
317 | "mask": mask}
318 |
319 | def _add_start_and_end(self, wordpiece_ids: List[int]) -> List[int]:
320 | return self._start_piece_ids + wordpiece_ids + self._end_piece_ids
321 |
322 | def _extend(self, token_type_ids: List[int]) -> List[int]:
323 | """
324 | Extend the token type ids by len(start_piece_ids) on the left
325 | and len(end_piece_ids) on the right.
326 | """
327 | first = token_type_ids[0]
328 | last = token_type_ids[-1]
329 | return ([first for _ in self._start_piece_ids] +
330 | token_type_ids +
331 | [last for _ in self._end_piece_ids])
332 |
333 | @overrides
334 | def get_padding_token(self) -> int:
335 | return 0
336 |
337 | @overrides
338 | def get_padding_lengths(self, token: int) -> Dict[str, int]: # pylint: disable=unused-argument
339 | return {}
340 |
341 | @overrides
342 | def pad_token_sequence(self,
343 | tokens: Dict[str, List[int]],
344 | desired_num_tokens: Dict[str, int],
345 | padding_lengths: Dict[str, int]) -> Dict[str, List[int]]: # pylint: disable=unused-argument
346 | return {key: pad_sequence_to_length(val, desired_num_tokens[key])
347 | for key, val in tokens.items()}
348 |
349 | @overrides
350 | def get_keys(self, index_name: str) -> List[str]:
351 | """
352 | We need to override this because the indexer generates multiple keys.
353 | """
354 | # pylint: disable=no-self-use
355 | return [index_name, f"{index_name}-offsets", f"{index_name}-type-ids", "mask"]
356 |
357 |
358 | class PretrainedBertIndexer(WordpieceIndexer):
359 | # pylint: disable=line-too-long
360 | """
361 | A ``TokenIndexer`` corresponding to a pretrained BERT model.
362 |
363 | Parameters
364 | ----------
365 | pretrained_model: ``str``
366 | Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
367 | or the path to the .txt file with its vocabulary.
368 |
369 | If the name is a key in the list of pretrained models at
370 | https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py#L33
371 | the corresponding path will be used; otherwise it will be interpreted as a path or URL.
372 | use_starting_offsets: bool, optional (default: False)
373 | By default, the "offsets" created by the token indexer correspond to the
374 | last wordpiece in each word. If ``use_starting_offsets`` is specified,
375 | they will instead correspond to the first wordpiece in each word.
376 | do_lowercase: ``bool``, optional (default = True)
377 | Whether to lowercase the tokens before converting to wordpiece ids.
378 | never_lowercase: ``List[str]``, optional
379 | Tokens that should never be lowercased. Default is
380 | ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'].
381 | max_pieces: int, optional (default: 512)
382 | The BERT embedder uses positional embeddings and so has a corresponding
383 | maximum length for its input ids. Any inputs longer than this will
384 | either be truncated (default), or be split apart and batched using a
385 | sliding window.
386 | truncate_long_sequences : ``bool``, optional (default=``True``)
387 | By default, long sequences will be truncated to the maximum sequence
388 | length. Otherwise, they will be split apart and batched using a
389 | sliding window.
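max_pieces_per_token : ``int``, optional (default: 5)
Forwarded to the underlying ``WordpieceIndexer``.
is_test : ``bool``, optional (default: False)
Forwarded to the underlying ``WordpieceIndexer``.
special_tokens_fix : ``int``, optional (default: 0)
If nonzero, ``START_TOKEN`` is added to the tokenizer vocabulary and no [CLS]/[SEP]
start/end pieces are used.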
390 | """
391 |
392 | def __init__(self,
393 | pretrained_model: str,
394 | use_starting_offsets: bool = False,
395 | do_lowercase: bool = True,
396 | never_lowercase: List[str] = None,
397 | max_pieces: int = 512,
398 | max_pieces_per_token: int = 5,
399 | is_test: bool = False,
400 | truncate_long_sequences: bool = True,
401 | special_tokens_fix: int = 0) -> None:
402 | if pretrained_model.endswith("-cased") and do_lowercase:
403 | logger.warning("Your BERT model appears to be cased, "
404 | "but your indexer is lowercasing tokens.")
405 | elif pretrained_model.endswith("-uncased") and not do_lowercase:
406 | logger.warning("Your BERT model appears to be uncased, "
407 | "but your indexer is not lowercasing tokens.")
408 |
409 | bert_tokenizer = AutoTokenizer.from_pretrained(
410 | pretrained_model, do_lower_case=do_lowercase, do_basic_tokenize=False)
411 |
412 | # to adjust all tokenizers
413 | if hasattr(bert_tokenizer, 'encoder'):
414 | bert_tokenizer.vocab = bert_tokenizer.encoder
415 | if hasattr(bert_tokenizer, 'sp_model'):
416 | bert_tokenizer.vocab = defaultdict(lambda: 1)
417 | for i in range(bert_tokenizer.sp_model.get_piece_size()):
418 | bert_tokenizer.vocab[bert_tokenizer.sp_model.id_to_piece(i)] = i
419 |
420 | if special_tokens_fix:
421 | bert_tokenizer.add_tokens([START_TOKEN])
422 | bert_tokenizer.vocab[START_TOKEN] = len(bert_tokenizer) - 1
423 |
424 | if "roberta" in pretrained_model:
425 | bpe_ranks = bert_tokenizer.bpe_ranks
426 | byte_encoder = bert_tokenizer.byte_encoder
427 | else:
428 | bpe_ranks = {}
429 | byte_encoder = None
430 |
431 | super().__init__(vocab=bert_tokenizer.vocab,
432 | bpe_ranks=bpe_ranks,
433 | byte_encoder=byte_encoder,
434 | wordpiece_tokenizer=bert_tokenizer.tokenize,
435 | namespace="bert",
436 | use_starting_offsets=use_starting_offsets,
437 | max_pieces=max_pieces,
438 | max_pieces_per_token=max_pieces_per_token,
439 | is_test=is_test,
440 | do_lowercase=do_lowercase,
441 | never_lowercase=never_lowercase,
442 | start_tokens=["[CLS]"] if not special_tokens_fix else [],
443 | end_tokens=["[SEP]"] if not special_tokens_fix else [],
444 | truncate_long_sequences=truncate_long_sequences)
445 |
--------------------------------------------------------------------------------
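A minimal usage sketch of the indexer above (not part of the repository files), assuming the
standard AllenNLP 0.8 ``tokens_to_indices(tokens, vocab, index_name)`` entry point and the
module path ``gector.wordpiece_indexer``; the model name and values are placeholders:

from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from gector.wordpiece_indexer import PretrainedBertIndexer

# Build an indexer for a cased BERT model; special_tokens_fix=0 keeps [CLS]/[SEP].
indexer = PretrainedBertIndexer(pretrained_model="bert-base-cased",
                                do_lowercase=False,
                                max_pieces=512,
                                special_tokens_fix=0)

tokens = [Token(t) for t in "I went to the store".split()]
indexed = indexer.tokens_to_indices(tokens, Vocabulary(), "bert")

# indexed["bert"]         -> wordpiece ids, wrapped in [CLS] ... [SEP]
# indexed["bert-offsets"] -> one offset per original token (last wordpiece by default)
# indexed["mask"]         -> one 1 per original token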
/gector/trainer.py:
--------------------------------------------------------------------------------
1 | """Tweaked version of corresponding AllenNLP file"""
2 | import datetime
3 | import logging
4 | import math
5 | import os
6 | import time
7 | import traceback
8 | from typing import Dict, Optional, List, Tuple, Union, Iterable, Any
9 |
10 | import torch
11 | import torch.optim.lr_scheduler
12 | from allennlp.common import Params
13 | from allennlp.common.checks import ConfigurationError, parse_cuda_device
14 | from allennlp.common.tqdm import Tqdm
15 | from allennlp.common.util import dump_metrics, gpu_memory_mb, peak_memory_mb, lazy_groups_of
16 | from allennlp.data.instance import Instance
17 | from allennlp.data.iterators.data_iterator import DataIterator, TensorDict
18 | from allennlp.models.model import Model
19 | from allennlp.nn import util as nn_util
20 | from allennlp.training import util as training_util
21 | from allennlp.training.checkpointer import Checkpointer
22 | from allennlp.training.learning_rate_schedulers import LearningRateScheduler
23 | from allennlp.training.metric_tracker import MetricTracker
24 | from allennlp.training.momentum_schedulers import MomentumScheduler
25 | from allennlp.training.moving_average import MovingAverage
26 | from allennlp.training.optimizers import Optimizer
27 | from allennlp.training.tensorboard_writer import TensorboardWriter
28 | from allennlp.training.trainer_base import TrainerBase
29 |
30 | logger = logging.getLogger(__name__)
31 |
32 |
33 | class Trainer(TrainerBase):
34 | def __init__(
35 | self,
36 | model: Model,
37 | optimizer: torch.optim.Optimizer,
38 | scheduler: torch.optim.lr_scheduler,
39 | iterator: DataIterator,
40 | train_dataset: Iterable[Instance],
41 | validation_dataset: Optional[Iterable[Instance]] = None,
42 | patience: Optional[int] = None,
43 | validation_metric: str = "-loss",
44 | validation_iterator: DataIterator = None,
45 | shuffle: bool = True,
46 | num_epochs: int = 20,
47 | accumulated_batch_count: int = 1,
48 | serialization_dir: Optional[str] = None,
49 | num_serialized_models_to_keep: int = 20,
50 | keep_serialized_model_every_num_seconds: int = None,
51 | checkpointer: Checkpointer = None,
52 | model_save_interval: float = None,
53 | cuda_device: Union[int, List] = -1,
54 | grad_norm: Optional[float] = None,
55 | grad_clipping: Optional[float] = None,
56 | learning_rate_scheduler: Optional[LearningRateScheduler] = None,
57 | momentum_scheduler: Optional[MomentumScheduler] = None,
58 | summary_interval: int = 100,
59 | histogram_interval: int = None,
60 | should_log_parameter_statistics: bool = True,
61 | should_log_learning_rate: bool = False,
62 | log_batch_size_period: Optional[int] = None,
63 | moving_average: Optional[MovingAverage] = None,
64 | cold_step_count: int = 0,
65 | cold_lr: float = 1e-3,
66 | cuda_verbose_step=None,
67 | ) -> None:
68 | """
69 | A trainer for doing supervised learning. It just takes a labeled dataset
70 | and a ``DataIterator``, and uses the supplied ``Optimizer`` to learn the weights
71 | for your model over some fixed number of epochs. You can also pass in a validation
72 | dataset and enable early stopping. There are many other bells and whistles as well.
73 |
74 | Parameters
75 | ----------
76 | model : ``Model``, required.
77 | An AllenNLP model to be optimized. Pytorch Modules can also be optimized if
78 | their ``forward`` method returns a dictionary with a "loss" key, containing a
79 | scalar tensor representing the loss function to be optimized.
80 |
81 | If you are training your model using GPUs, your model should already be
82 | on the correct device. (If you use `Trainer.from_params` this will be
83 | handled for you.)
84 | optimizer : ``torch.optim.Optimizer``, required.
85 | An instance of a Pytorch Optimizer, instantiated with the parameters of the
86 | model to be optimized.
87 | iterator : ``DataIterator``, required.
88 | A method for iterating over a ``Dataset``, yielding padded indexed batches.
89 | train_dataset : ``Dataset``, required.
90 | A ``Dataset`` to train on. The dataset should have already been indexed.
91 | validation_dataset : ``Dataset``, optional, (default = None).
92 | A ``Dataset`` to evaluate on. The dataset should have already been indexed.
93 | patience : Optional[int] > 0, optional (default=None)
94 | Number of epochs to be patient before early stopping: the training is stopped
95 | after ``patience`` epochs with no improvement. If given, it must be ``> 0``.
96 | If None, early stopping is disabled.
97 | validation_metric : str, optional (default="-loss")
98 | Validation metric to measure for whether to stop training using patience
99 | and whether to serialize an ``is_best`` model each epoch. The metric name
100 | must be prepended with either "+" or "-", which specifies whether the metric
101 | is an increasing or decreasing function.
102 | validation_iterator : ``DataIterator``, optional (default=None)
103 | An iterator to use for the validation set. If ``None``, then
104 | use the training `iterator`.
105 | shuffle: ``bool``, optional (default=True)
106 | Whether to shuffle the instances in the iterator or not.
107 | num_epochs : int, optional (default = 20)
108 | Number of training epochs.
109 | serialization_dir : str, optional (default=None)
110 | Path to directory for saving and loading model files. Models will not be saved if
111 | this parameter is not passed.
112 | num_serialized_models_to_keep : ``int``, optional (default=20)
113 | Number of previous model checkpoints to retain. Default is to keep 20 checkpoints.
114 | A value of None or -1 means all checkpoints will be kept.
115 | keep_serialized_model_every_num_seconds : ``int``, optional (default=None)
116 | If num_serialized_models_to_keep is not None, then occasionally it's useful to
117 | save models at a given interval in addition to the last num_serialized_models_to_keep.
118 | To do so, specify keep_serialized_model_every_num_seconds as the number of seconds
119 | between permanently saved checkpoints. Note that this option is only used if
120 | num_serialized_models_to_keep is not None, otherwise all checkpoints are kept.
121 | checkpointer : ``Checkpointer``, optional (default=None)
122 | An instance of class Checkpointer to use instead of the default. If a checkpointer is specified,
123 | the arguments num_serialized_models_to_keep and keep_serialized_model_every_num_seconds should
124 | not be specified. The caller is responsible for initializing the checkpointer so that it is
125 | consistent with serialization_dir.
126 | model_save_interval : ``float``, optional (default=None)
127 | If provided, then serialize models every ``model_save_interval``
128 | seconds within single epochs. In all cases, models are also saved
129 | at the end of every epoch if ``serialization_dir`` is provided.
130 | cuda_device : ``Union[int, List[int]]``, optional (default = -1)
131 | An integer or list of integers specifying the CUDA device(s) to use. If -1, the CPU is used.
132 | grad_norm : ``float``, optional, (default = None).
133 | If provided, gradient norms will be rescaled to have a maximum of this value.
134 | grad_clipping : ``float``, optional (default = ``None``).
135 | If provided, gradients will be clipped `during the backward pass` to have an (absolute)
136 | maximum of this value. If you are getting ``NaNs`` in your gradients during training
137 | that are not solved by using ``grad_norm``, you may need this.
138 | learning_rate_scheduler : ``LearningRateScheduler``, optional (default = None)
139 | If specified, the learning rate will be decayed with respect to
140 | this schedule at the end of each epoch (or batch, if the scheduler implements
141 | the ``step_batch`` method). If you use :class:`torch.optim.lr_scheduler.ReduceLROnPlateau`,
142 | this will use the ``validation_metric`` provided to determine if learning has plateaued.
143 | To support updating the learning rate on every batch, this can optionally implement
144 | ``step_batch(batch_num_total)`` which updates the learning rate given the batch number.
145 | momentum_scheduler : ``MomentumScheduler``, optional (default = None)
146 | If specified, the momentum will be updated at the end of each batch or epoch
147 | according to the schedule.
148 | summary_interval: ``int``, optional, (default = 100)
149 | Number of batches between logging scalars to tensorboard
150 | histogram_interval : ``int``, optional, (default = ``None``)
151 | If not None, then log histograms to tensorboard every ``histogram_interval`` batches.
152 | When this parameter is specified, the following additional logging is enabled:
153 | * Histograms of model parameters
154 | * The ratio of parameter update norm to parameter norm
155 | * Histogram of layer activations
156 | We log histograms of the parameters returned by
157 | ``model.get_parameters_for_histogram_tensorboard_logging``.
158 | The layer activations are logged for any modules in the ``Model`` that have
159 | the attribute ``should_log_activations`` set to ``True``. Logging
160 | histograms requires a number of GPU-CPU copies during training and is typically
161 | slow, so we recommend logging histograms relatively infrequently.
162 | Note: only Modules that return tensors, tuples of tensors or dicts
163 | with tensors as values currently support activation logging.
164 | should_log_parameter_statistics : ``bool``, optional, (default = True)
165 | Whether to send parameter statistics (mean and standard deviation
166 | of parameters and gradients) to tensorboard.
167 | should_log_learning_rate : ``bool``, optional, (default = False)
168 | Whether to send parameter specific learning rate to tensorboard.
169 | log_batch_size_period : ``int``, optional, (default = ``None``)
170 | If defined, how often to log the average batch size.
171 | moving_average: ``MovingAverage``, optional, (default = None)
172 | If provided, we will maintain moving averages for all parameters. During training, we
173 | employ a shadow variable for each parameter, which maintains the moving average. During
174 | evaluation, we back up the original parameters and assign the moving averages to the
175 | corresponding parameters. Note that when saving a checkpoint, we save the moving averages
176 | of the parameters; this is necessary so that the saved model performs as well as the
177 | validated model if it is loaded later, but it may cause problems if you restart training from a checkpoint.
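scheduler : ``torch.optim.lr_scheduler``, required.
A PyTorch learning-rate scheduler (e.g. ``ReduceLROnPlateau``); ``scheduler.step`` is
called with the validation loss at the end of every epoch.
accumulated_batch_count : ``int``, optional (default = 1)
Number of batches over which gradients are accumulated before ``optimizer.step()`` is
called; each batch loss is divided by the size of its accumulation group.
cold_step_count : ``int``, optional (default = 0)
Number of initial epochs during which the BERT embedder weights are frozen and the
optimizer runs with ``cold_lr``; afterwards the original learning rate is restored.
cold_lr : ``float``, optional (default = 1e-3)
Learning rate used during the cold epochs.
cuda_verbose_step : ``int``, optional (default = None)
If set, CUDA memory statistics are printed every ``cuda_verbose_step`` batches.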
178 | """
179 | super().__init__(serialization_dir, cuda_device)
180 |
181 | # I am not calling move_to_gpu here, because if the model is
182 | # not already on the GPU then the optimizer is going to be wrong.
183 | self.model = model
184 |
185 | self.iterator = iterator
186 | self._validation_iterator = validation_iterator
187 | self.shuffle = shuffle
188 | self.optimizer = optimizer
189 | self.scheduler = scheduler
190 | self.train_data = train_dataset
191 | self._validation_data = validation_dataset
192 | self.accumulated_batch_count = accumulated_batch_count
193 | self.cold_step_count = cold_step_count
194 | self.cold_lr = cold_lr
195 | self.cuda_verbose_step = cuda_verbose_step
196 |
197 | if patience is None: # no early stopping
198 | if validation_dataset:
199 | logger.warning(
200 | "You provided a validation dataset but patience was set to None, "
201 | "meaning that early stopping is disabled"
202 | )
203 | elif (not isinstance(patience, int)) or patience <= 0:
204 | raise ConfigurationError(
205 | '{} is an invalid value for "patience": it must be a positive integer '
206 | "or None (if you want to disable early stopping)".format(patience)
207 | )
208 |
209 | # For tracking is_best_so_far and should_stop_early
210 | self._metric_tracker = MetricTracker(patience, validation_metric)
211 | # Get rid of + or -
212 | self._validation_metric = validation_metric[1:]
213 |
214 | self._num_epochs = num_epochs
215 |
216 | if checkpointer is not None:
217 | # We can't easily check if these parameters were passed in, so check against their default values.
218 | # We don't check against serialization_dir since it is also used by the parent class.
219 | if num_serialized_models_to_keep != 20 \
220 | or keep_serialized_model_every_num_seconds is not None:
221 | raise ConfigurationError(
222 | "When passing a custom Checkpointer, you may not also pass in separate checkpointer "
223 | "args 'num_serialized_models_to_keep' or 'keep_serialized_model_every_num_seconds'."
224 | )
225 | self._checkpointer = checkpointer
226 | else:
227 | self._checkpointer = Checkpointer(
228 | serialization_dir,
229 | keep_serialized_model_every_num_seconds,
230 | num_serialized_models_to_keep,
231 | )
232 |
233 | self._model_save_interval = model_save_interval
234 |
235 | self._grad_norm = grad_norm
236 | self._grad_clipping = grad_clipping
237 |
238 | self._learning_rate_scheduler = learning_rate_scheduler
239 | self._momentum_scheduler = momentum_scheduler
240 | self._moving_average = moving_average
241 |
242 | # We keep the total batch number as an instance variable because it
243 | # is used inside a closure for the hook which logs activations in
244 | # ``_enable_activation_logging``.
245 | self._batch_num_total = 0
246 |
247 | self._tensorboard = TensorboardWriter(
248 | get_batch_num_total=lambda: self._batch_num_total,
249 | serialization_dir=serialization_dir,
250 | summary_interval=summary_interval,
251 | histogram_interval=histogram_interval,
252 | should_log_parameter_statistics=should_log_parameter_statistics,
253 | should_log_learning_rate=should_log_learning_rate,
254 | )
255 |
256 | self._log_batch_size_period = log_batch_size_period
257 |
258 | self._last_log = 0.0 # time of last logging
259 |
260 | # Enable activation logging.
261 | if histogram_interval is not None:
262 | self._tensorboard.enable_activation_logging(self.model)
263 |
264 | def rescale_gradients(self) -> Optional[float]:
265 | return training_util.rescale_gradients(self.model, self._grad_norm)
266 |
267 | def batch_loss(self, batch_group: List[TensorDict], for_training: bool) -> torch.Tensor:
268 | """
269 | Does a forward pass on the given batches and returns the ``loss`` value in the result.
270 | If ``for_training`` is `True` also applies regularization penalty.
271 | """
272 | if self._multiple_gpu:
273 | output_dict = training_util.data_parallel(batch_group, self.model, self._cuda_devices)
274 | else:
275 | assert len(batch_group) == 1
276 | batch = batch_group[0]
277 | batch = nn_util.move_to_device(batch, self._cuda_devices[0])
278 | output_dict = self.model(**batch)
279 |
280 | try:
281 | loss = output_dict["loss"]
282 | if for_training:
283 | loss += self.model.get_regularization_penalty()
284 | except KeyError:
285 | if for_training:
286 | raise RuntimeError(
287 | "The model you are trying to optimize does not contain a"
288 | " 'loss' key in the output of model.forward(inputs)."
289 | )
290 | loss = None
291 |
292 | return loss
293 |
294 | def _train_epoch(self, epoch: int) -> Dict[str, float]:
295 | """
296 | Trains one epoch and returns metrics.
297 | """
298 | logger.info("Epoch %d/%d", epoch, self._num_epochs - 1)
299 | peak_cpu_usage = peak_memory_mb()
300 | logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}")
301 | gpu_usage = []
302 | for gpu, memory in gpu_memory_mb().items():
303 | gpu_usage.append((gpu, memory))
304 | logger.info(f"GPU {gpu} memory usage MB: {memory}")
305 |
306 | train_loss = 0.0
307 | # Set the model to "train" mode.
308 | self.model.train()
309 |
310 | num_gpus = len(self._cuda_devices)
311 |
312 | # Get tqdm for the training batches
313 | raw_train_generator = self.iterator(self.train_data, num_epochs=1, shuffle=self.shuffle)
314 | train_generator = lazy_groups_of(raw_train_generator, num_gpus)
315 | num_training_batches = math.ceil(self.iterator.get_num_batches(self.train_data) / num_gpus)
316 | residue = num_training_batches % self.accumulated_batch_count
317 | self._last_log = time.time()
318 | last_save_time = time.time()
319 |
320 | batches_this_epoch = 0
321 | if self._batch_num_total is None:
322 | self._batch_num_total = 0
323 |
324 | histogram_parameters = set(self.model.get_parameters_for_histogram_tensorboard_logging())
325 |
326 | logger.info("Training")
327 | train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_training_batches)
328 | cumulative_batch_size = 0
329 | self.optimizer.zero_grad()
330 | for batch_group in train_generator_tqdm:
331 | batches_this_epoch += 1
332 | self._batch_num_total += 1
333 | batch_num_total = self._batch_num_total
334 |
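# Size of the current gradient-accumulation group (the final group may be smaller,
# which is what `residue` accounts for); the batch loss is divided by this below so
# that the accumulated gradients approximate a single larger batch.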
335 | iter_len = self.accumulated_batch_count \
336 | if batches_this_epoch <= (num_training_batches - residue) else residue
337 |
338 | if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
339 | print(f'Before forward pass - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
340 | print(f'Before forward pass - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')
341 | try:
342 | loss = self.batch_loss(batch_group, for_training=True) / iter_len
343 | except RuntimeError as e:
344 | print(e)
345 | for x in batch_group:
346 | all_words = [len(y['words']) for y in x['metadata']]
347 | print(f"Total sents: {len(all_words)}. "
348 | f"Min {min(all_words)}. Max {max(all_words)}")
349 | for elem in ['labels', 'd_tags']:
350 | tt = x[elem]
351 | print(
352 | f"{elem} shape {list(tt.shape)} and min {tt.min().item()} and {tt.max().item()}")
353 | for elem in ["bert", "mask", "bert-offsets"]:
354 | tt = x['tokens'][elem]
355 | print(
356 | f"{elem} shape {list(tt.shape)} and min {tt.min().item()} and {tt.max().item()}")
357 | raise e
358 |
359 | if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
360 | print(f'After forward pass - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
361 | print(f'After forward pass - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')
362 |
363 | if torch.isnan(loss):
364 | raise ValueError("nan loss encountered")
365 |
366 | loss.backward()
367 |
368 | if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
369 | print(f'After backprop - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
370 | print(f'After backprop - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')
371 |
372 | train_loss += loss.item() * iter_len
373 |
374 | del batch_group, loss
375 | torch.cuda.empty_cache()
376 |
377 | if self.cuda_verbose_step is not None and batch_num_total % self.cuda_verbose_step == 0:
378 | print(f'After collecting garbage - Cuda memory allocated: {torch.cuda.memory_allocated() / 1e9}')
379 | print(f'After collecting garbage - Cuda memory cached: {torch.cuda.memory_cached() / 1e9}')
380 |
381 | batch_grad_norm = self.rescale_gradients()
382 |
383 | # This does nothing if batch_num_total is None or you are using a
384 | # scheduler which doesn't update per batch.
385 | if self._learning_rate_scheduler:
386 | self._learning_rate_scheduler.step_batch(batch_num_total)
387 | if self._momentum_scheduler:
388 | self._momentum_scheduler.step_batch(batch_num_total)
389 |
390 | if self._tensorboard.should_log_histograms_this_batch():
391 | # get the magnitude of parameter updates for logging
392 | # We need a copy of current parameters to compute magnitude of updates,
393 | # and copy them to CPU so large models won't go OOM on the GPU.
394 | param_updates = {
395 | name: param.detach().cpu().clone()
396 | for name, param in self.model.named_parameters()
397 | }
398 | if batches_this_epoch % self.accumulated_batch_count == 0 or \
399 | batches_this_epoch == num_training_batches:
400 | self.optimizer.step()
401 | self.optimizer.zero_grad()
402 | for name, param in self.model.named_parameters():
403 | param_updates[name].sub_(param.detach().cpu())
404 | update_norm = torch.norm(param_updates[name].view(-1))
405 | param_norm = torch.norm(param.view(-1)).cpu()
406 | self._tensorboard.add_train_scalar(
407 | "gradient_update/" + name, update_norm / (param_norm + 1e-7)
408 | )
409 | else:
410 | if batches_this_epoch % self.accumulated_batch_count == 0 or \
411 | batches_this_epoch == num_training_batches:
412 | self.optimizer.step()
413 | self.optimizer.zero_grad()
414 |
415 | # Update moving averages
416 | if self._moving_average is not None:
417 | self._moving_average.apply(batch_num_total)
418 |
419 | # Update the description with the latest metrics
420 | metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch)
421 | description = training_util.description_from_metrics(metrics)
422 |
423 | train_generator_tqdm.set_description(description, refresh=False)
424 |
425 | # Log parameter values to Tensorboard
426 | if self._tensorboard.should_log_this_batch():
427 | self._tensorboard.log_parameter_and_gradient_statistics(self.model, batch_grad_norm)
428 | self._tensorboard.log_learning_rates(self.model, self.optimizer)
429 |
430 | self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"])
431 | self._tensorboard.log_metrics({"epoch_metrics/" + k: v for k, v in metrics.items()})
432 |
433 | if self._tensorboard.should_log_histograms_this_batch():
434 | self._tensorboard.log_histograms(self.model, histogram_parameters)
435 |
436 | if self._log_batch_size_period:
437 | cur_batch = sum([training_util.get_batch_size(batch) for batch in batch_group])
438 | cumulative_batch_size += cur_batch
439 | if (batches_this_epoch - 1) % self._log_batch_size_period == 0:
440 | average = cumulative_batch_size / batches_this_epoch
441 | logger.info(f"current batch size: {cur_batch} mean batch size: {average}")
442 | self._tensorboard.add_train_scalar("current_batch_size", cur_batch)
443 | self._tensorboard.add_train_scalar("mean_batch_size", average)
444 |
445 | # Save model if needed.
446 | if self._model_save_interval is not None and (
447 | time.time() - last_save_time > self._model_save_interval
448 | ):
449 | last_save_time = time.time()
450 | self._save_checkpoint(
451 | "{0}.{1}".format(epoch, training_util.time_to_str(int(last_save_time)))
452 | )
453 |
454 | metrics = training_util.get_metrics(self.model, train_loss, batches_this_epoch, reset=True)
455 | metrics["cpu_memory_MB"] = peak_cpu_usage
456 | for (gpu_num, memory) in gpu_usage:
457 | metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory
458 | return metrics
459 |
460 | def _validation_loss(self) -> Tuple[float, int]:
461 | """
462 | Computes the validation loss. Returns it and the number of batches.
463 | """
464 | logger.info("Validating")
465 |
466 | self.model.eval()
467 |
468 | # Replace parameter values with the shadow values from the moving averages.
469 | if self._moving_average is not None:
470 | self._moving_average.assign_average_value()
471 |
472 | if self._validation_iterator is not None:
473 | val_iterator = self._validation_iterator
474 | else:
475 | val_iterator = self.iterator
476 |
477 | num_gpus = len(self._cuda_devices)
478 |
479 | raw_val_generator = val_iterator(self._validation_data, num_epochs=1, shuffle=False)
480 | val_generator = lazy_groups_of(raw_val_generator, num_gpus)
481 | num_validation_batches = math.ceil(
482 | val_iterator.get_num_batches(self._validation_data) / num_gpus
483 | )
484 | val_generator_tqdm = Tqdm.tqdm(val_generator, total=num_validation_batches)
485 | batches_this_epoch = 0
486 | val_loss = 0
487 | for batch_group in val_generator_tqdm:
488 |
489 | loss = self.batch_loss(batch_group, for_training=False)
490 | if loss is not None:
491 | # You shouldn't necessarily have to compute a loss for validation, so we allow for
492 | # `loss` to be None. We need to be careful, though - `batches_this_epoch` is
493 | # currently only used as the divisor for the loss function, so we can safely only
494 | # count those batches for which we actually have a loss. If this variable ever
495 | # gets used for something else, we might need to change things around a bit.
496 | batches_this_epoch += 1
497 | val_loss += loss.detach().cpu().numpy()
498 |
499 | # Update the description with the latest metrics
500 | val_metrics = training_util.get_metrics(self.model, val_loss, batches_this_epoch)
501 | description = training_util.description_from_metrics(val_metrics)
502 | val_generator_tqdm.set_description(description, refresh=False)
503 |
504 | # Now restore the original parameter values.
505 | if self._moving_average is not None:
506 | self._moving_average.restore()
507 |
508 | return val_loss, batches_this_epoch
509 |
510 | def train(self) -> Dict[str, Any]:
511 | """
512 | Trains the supplied model with the supplied parameters.
513 | """
514 | try:
515 | epoch_counter = self._restore_checkpoint()
516 | except RuntimeError:
517 | traceback.print_exc()
518 | raise ConfigurationError(
519 | "Could not recover training from the checkpoint. Did you mean to output to "
520 | "a different serialization directory or delete the existing serialization "
521 | "directory?"
522 | )
523 |
524 | training_util.enable_gradient_clipping(self.model, self._grad_clipping)
525 |
526 | logger.info("Beginning training.")
527 |
528 | train_metrics: Dict[str, float] = {}
529 | val_metrics: Dict[str, float] = {}
530 | this_epoch_val_metric: Optional[float] = None
531 | metrics: Dict[str, Any] = {}
532 | epochs_trained = 0
533 | training_start_time = time.time()
534 |
535 | if self.cold_step_count > 0:
536 | base_lr = self.optimizer.param_groups[0]['lr']
537 | for param_group in self.optimizer.param_groups:
538 | param_group['lr'] = self.cold_lr
539 | self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=True)
540 |
541 | metrics["best_epoch"] = self._metric_tracker.best_epoch
542 | for key, value in self._metric_tracker.best_epoch_metrics.items():
543 | metrics["best_validation_" + key] = value
544 |
545 | for epoch in range(epoch_counter, self._num_epochs):
546 | if epoch == self.cold_step_count and epoch != 0:
547 | for param_group in self.optimizer.param_groups:
548 | param_group['lr'] = base_lr
549 | self.model.text_field_embedder._token_embedders['bert'].set_weights(freeze=False)
550 |
551 | epoch_start_time = time.time()
552 | train_metrics = self._train_epoch(epoch)
553 |
554 | # get peak of memory usage
555 | if "cpu_memory_MB" in train_metrics:
556 | metrics["peak_cpu_memory_MB"] = max(
557 | metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"]
558 | )
559 | for key, value in train_metrics.items():
560 | if key.startswith("gpu_"):
561 | metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value)
562 |
563 | # clear cache before validation
564 | torch.cuda.empty_cache()
565 | if self._validation_data is not None:
566 | with torch.no_grad():
567 | # We have a validation set, so compute all the metrics on it.
568 | val_loss, num_batches = self._validation_loss()
569 | val_metrics = training_util.get_metrics(
570 | self.model, val_loss, num_batches, reset=True
571 | )
572 |
573 | # Check validation metric for early stopping
574 | this_epoch_val_metric = val_metrics[self._validation_metric]
575 | self._metric_tracker.add_metric(this_epoch_val_metric)
576 |
577 | if self._metric_tracker.should_stop_early():
578 | logger.info("Ran out of patience. Stopping training.")
579 | break
580 |
581 | self._tensorboard.log_metrics(
582 | train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1
583 | ) # +1 because tensorboard doesn't like 0
584 |
585 | # Create overall metrics dict
586 | training_elapsed_time = time.time() - training_start_time
587 | metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time))
588 | metrics["training_start_epoch"] = epoch_counter
589 | metrics["training_epochs"] = epochs_trained
590 | metrics["epoch"] = epoch
591 |
592 | for key, value in train_metrics.items():
593 | metrics["training_" + key] = value
594 | for key, value in val_metrics.items():
595 | metrics["validation_" + key] = value
596 |
597 | # if self.cold_step_count <= epoch:
598 | self.scheduler.step(metrics['validation_loss'])
599 |
600 | if self._metric_tracker.is_best_so_far():
601 | # Update all the best_ metrics.
602 | # (Otherwise they just stay the same as they were.)
603 | metrics["best_epoch"] = epoch
604 | for key, value in val_metrics.items():
605 | metrics["best_validation_" + key] = value
606 |
607 | self._metric_tracker.best_epoch_metrics = val_metrics
608 |
609 | if self._serialization_dir:
610 | dump_metrics(
611 | os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics
612 | )
613 |
614 | # The Scheduler API is agnostic to whether your schedule requires a validation metric -
615 | # if it doesn't, the validation metric passed here is ignored.
616 | if self._learning_rate_scheduler:
617 | self._learning_rate_scheduler.step(this_epoch_val_metric, epoch)
618 | if self._momentum_scheduler:
619 | self._momentum_scheduler.step(this_epoch_val_metric, epoch)
620 |
621 | self._save_checkpoint(epoch)
622 |
623 | epoch_elapsed_time = time.time() - epoch_start_time
624 | logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time))
625 |
626 | if epoch < self._num_epochs - 1:
627 | training_elapsed_time = time.time() - training_start_time
628 | estimated_time_remaining = training_elapsed_time * (
629 | (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1
630 | )
631 | formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining)))
632 | logger.info("Estimated training time remaining: %s", formatted_time)
633 |
634 | epochs_trained += 1
635 |
636 | # make sure pending events are flushed to disk and files are closed properly
637 | # self._tensorboard.close()
638 |
639 | # Load the best model state before returning
640 | best_model_state = self._checkpointer.best_model_state()
641 | if best_model_state:
642 | self.model.load_state_dict(best_model_state)
643 |
644 | return metrics
645 |
646 | def _save_checkpoint(self, epoch: Union[int, str]) -> None:
647 | """
648 | Saves a checkpoint of the model to self._serialization_dir.
649 | Is a no-op if self._serialization_dir is None.
650 |
651 | Parameters
652 | ----------
653 | epoch : Union[int, str], required.
654 | The epoch of training. If the checkpoint is saved in the middle
655 | of an epoch, the parameter is a string with the epoch and timestamp.
656 | """
657 | # If moving averages are used for parameters, we save
658 | # the moving average values into checkpoint, instead of the current values.
659 | if self._moving_average is not None:
660 | self._moving_average.assign_average_value()
661 |
662 | # These are the training states we need to persist.
663 | training_states = {
664 | "metric_tracker": self._metric_tracker.state_dict(),
665 | "optimizer": self.optimizer.state_dict(),
666 | "batch_num_total": self._batch_num_total,
667 | }
668 |
669 | # If we have a learning rate or momentum scheduler, we should persist them too.
670 | if self._learning_rate_scheduler is not None:
671 | training_states["learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict()
672 | if self._momentum_scheduler is not None:
673 | training_states["momentum_scheduler"] = self._momentum_scheduler.state_dict()
674 |
675 | self._checkpointer.save_checkpoint(
676 | model_state=self.model.state_dict(),
677 | epoch=epoch,
678 | training_states=training_states,
679 | is_best_so_far=self._metric_tracker.is_best_so_far(),
680 | )
681 |
682 | # Restore the original values for parameters so that training will not be affected.
683 | if self._moving_average is not None:
684 | self._moving_average.restore()
685 |
686 | def _restore_checkpoint(self) -> int:
687 | """
688 | Restores the model and training state from the last saved checkpoint.
689 | This includes an epoch count and optimizer state, which is serialized separately
690 | from model parameters. This function should only be used to continue training -
691 | if you wish to load a model for inference/load parts of a model into a new
692 | computation graph, you should use the native Pytorch functions:
693 | ``model.load_state_dict(torch.load("/path/to/model/weights.th"))``
694 |
695 | If ``self._serialization_dir`` does not exist or does not contain any checkpointed weights,
696 | this function will do nothing and return 0.
697 |
698 | Returns
699 | -------
700 | epoch: int
701 | The epoch at which to resume training, which should be one after the epoch
702 | in the saved training state.
703 | """
704 | model_state, training_state = self._checkpointer.restore_checkpoint()
705 |
706 | if not training_state:
707 | # No checkpoint to restore, start at 0
708 | return 0
709 |
710 | self.model.load_state_dict(model_state)
711 | self.optimizer.load_state_dict(training_state["optimizer"])
712 | if self._learning_rate_scheduler is not None \
713 | and "learning_rate_scheduler" in training_state:
714 | self._learning_rate_scheduler.load_state_dict(training_state["learning_rate_scheduler"])
715 | if self._momentum_scheduler is not None and "momentum_scheduler" in training_state:
716 | self._momentum_scheduler.load_state_dict(training_state["momentum_scheduler"])
717 | training_util.move_optimizer_to_cuda(self.optimizer)
718 |
719 | # Currently the ``training_state`` contains a serialized ``MetricTracker``.
720 | if "metric_tracker" in training_state:
721 | self._metric_tracker.load_state_dict(training_state["metric_tracker"])
722 | # It used to be the case that we tracked ``val_metric_per_epoch``.
723 | elif "val_metric_per_epoch" in training_state:
724 | self._metric_tracker.clear()
725 | self._metric_tracker.add_metrics(training_state["val_metric_per_epoch"])
726 | # And before that we didn't track anything.
727 | else:
728 | self._metric_tracker.clear()
729 |
730 | if isinstance(training_state["epoch"], int):
731 | epoch_to_return = training_state["epoch"] + 1
732 | else:
733 | epoch_to_return = int(training_state["epoch"].split(".")[0]) + 1
734 |
735 | # For older checkpoints with batch_num_total missing, default to old behavior where
736 | # it is unchanged.
737 | batch_num_total = training_state.get("batch_num_total")
738 | if batch_num_total is not None:
739 | self._batch_num_total = batch_num_total
740 |
741 | return epoch_to_return
742 |
743 | # Requires custom from_params.
744 | @classmethod
745 | def from_params( # type: ignore
746 | cls,
747 | model: Model,
748 | serialization_dir: str,
749 | iterator: DataIterator,
750 | train_data: Iterable[Instance],
751 | validation_data: Optional[Iterable[Instance]],
752 | params: Params,
753 | validation_iterator: DataIterator = None,
754 | ) -> "Trainer":
755 |
756 | patience = params.pop_int("patience", None)
757 | validation_metric = params.pop("validation_metric", "-loss")
758 | shuffle = params.pop_bool("shuffle", True)
759 | num_epochs = params.pop_int("num_epochs", 20)
760 | cuda_device = parse_cuda_device(params.pop("cuda_device", -1))
761 | grad_norm = params.pop_float("grad_norm", None)
762 | grad_clipping = params.pop_float("grad_clipping", None)
763 | lr_scheduler_params = params.pop("learning_rate_scheduler", None)
764 | momentum_scheduler_params = params.pop("momentum_scheduler", None)
765 |
766 | if isinstance(cuda_device, list):
767 | model_device = cuda_device[0]
768 | else:
769 | model_device = cuda_device
770 | if model_device >= 0:
771 | # Moving model to GPU here so that the optimizer state gets constructed on
772 | # the right device.
773 | model = model.cuda(model_device)
774 |
775 | parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad]
776 | optimizer = Optimizer.from_params(parameters, params.pop("optimizer"))
777 | if "moving_average" in params:
778 | moving_average = MovingAverage.from_params(
779 | params.pop("moving_average"), parameters=parameters
780 | )
781 | else:
782 | moving_average = None
783 |
784 | if lr_scheduler_params:
785 | lr_scheduler = LearningRateScheduler.from_params(optimizer, lr_scheduler_params)
786 | else:
787 | lr_scheduler = None
788 | if momentum_scheduler_params:
789 | momentum_scheduler = MomentumScheduler.from_params(optimizer, momentum_scheduler_params)
790 | else:
791 | momentum_scheduler = None
792 |
793 | if "checkpointer" in params:
794 | if "keep_serialized_model_every_num_seconds" in params \
795 | or "num_serialized_models_to_keep" in params:
796 | raise ConfigurationError(
797 | "Checkpointer may be initialized either from the 'checkpointer' key or from the "
798 | "keys 'num_serialized_models_to_keep' and 'keep_serialized_model_every_num_seconds'"
799 | " but the passed config uses both methods."
800 | )
801 | checkpointer = Checkpointer.from_params(params.pop("checkpointer"))
802 | else:
803 | num_serialized_models_to_keep = params.pop_int("num_serialized_models_to_keep", 20)
804 | keep_serialized_model_every_num_seconds = params.pop_int(
805 | "keep_serialized_model_every_num_seconds", None
806 | )
807 | checkpointer = Checkpointer(
808 | serialization_dir=serialization_dir,
809 | num_serialized_models_to_keep=num_serialized_models_to_keep,
810 | keep_serialized_model_every_num_seconds=keep_serialized_model_every_num_seconds,
811 | )
812 | model_save_interval = params.pop_float("model_save_interval", None)
813 | summary_interval = params.pop_int("summary_interval", 100)
814 | histogram_interval = params.pop_int("histogram_interval", None)
815 | should_log_parameter_statistics = params.pop_bool("should_log_parameter_statistics", True)
816 | should_log_learning_rate = params.pop_bool("should_log_learning_rate", False)
817 | log_batch_size_period = params.pop_int("log_batch_size_period", None)
818 |
819 | params.assert_empty(cls.__name__)
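# Note: this from_params mirrors the stock AllenNLP Trainer and does not set up the
# `scheduler`, `accumulated_batch_count`, or cold-start arguments added to __init__ above;
# the positional call below therefore assumes the original AllenNLP argument order.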
820 | return cls(
821 | model,
822 | optimizer,
823 | iterator,
824 | train_data,
825 | validation_data,
826 | patience=patience,
827 | validation_metric=validation_metric,
828 | validation_iterator=validation_iterator,
829 | shuffle=shuffle,
830 | num_epochs=num_epochs,
831 | serialization_dir=serialization_dir,
832 | cuda_device=cuda_device,
833 | grad_norm=grad_norm,
834 | grad_clipping=grad_clipping,
835 | learning_rate_scheduler=lr_scheduler,
836 | momentum_scheduler=momentum_scheduler,
837 | checkpointer=checkpointer,
838 | model_save_interval=model_save_interval,
839 | summary_interval=summary_interval,
840 | histogram_interval=histogram_interval,
841 | should_log_parameter_statistics=should_log_parameter_statistics,
842 | should_log_learning_rate=should_log_learning_rate,
843 | log_batch_size_period=log_batch_size_period,
844 | moving_average=moving_average,
845 | )
846 |
--------------------------------------------------------------------------------
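A minimal end-to-end sketch (not part of the repository files) of constructing this tweaked
Trainer directly, since ``from_params`` above does not cover the added arguments. ``model``,
``train_data`` and ``dev_data`` are assumed to have been built with the repo's datareader and
model classes; all hyperparameter values are placeholders:

import torch
from allennlp.data.iterators import BucketIterator
from gector.trainer import Trainer  # module path as in this repository

# model, train_data, dev_data are assumed to exist already (placeholders).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# scheduler.step(...) is called with the validation loss at the end of every epoch,
# so a plateau-based scheduler is a natural fit.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=10)

iterator = BucketIterator(batch_size=32, sorting_keys=[("tokens", "num_tokens")])
iterator.index_with(model.vocab)

trainer = Trainer(model=model,
                  optimizer=optimizer,
                  scheduler=scheduler,
                  iterator=iterator,
                  train_dataset=train_data,
                  validation_dataset=dev_data,
                  patience=3,
                  num_epochs=20,
                  serialization_dir="model_dir",
                  cuda_device=0 if torch.cuda.is_available() else -1,
                  accumulated_batch_count=4,
                  # cold_step_count > 0 assumes the GECToR model, whose BERT embedder
                  # exposes set_weights(freeze=...).
                  cold_step_count=2,
                  cold_lr=1e-3,
                  cuda_verbose_step=None)

metrics = trainer.train()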