├── README.md
├── bert-base-chinese
│   ├── config.json
│   └── vocab.txt
├── data
│   ├── inference
│   │   ├── raw_data
│   │   │   └── raw_data.txt
│   │   └── tree_data
│   │       └── tree_data.txt
│   └── train
│       ├── raw_data
│       │   └── raw_data.txt
│       └── tree_data
│           ├── tree_test.txt
│           ├── tree_train.txt
│           └── tree_validate.txt
├── models
│   └── pretrained_model
│       └── README.md
└── src
    ├── benepar
    │   ├── __init__.py
    │   ├── char_lstm.py
    │   ├── decode_chart.py
    │   ├── integrations
    │   │   ├── __init__.py
    │   │   ├── downloader.py
    │   │   ├── nltk_plugin.py
    │   │   ├── spacy_extensions.py
    │   │   └── spacy_plugin.py
    │   ├── nkutil.py
    │   ├── parse_base.py
    │   ├── parse_chart.py
    │   ├── partitioned_transformer.py
    │   ├── ptb_unescape.py
    │   ├── retokenization.py
    │   ├── spacy_plugin.py
    │   └── subbatching.py
    ├── count_fscore.py
    ├── evaluate.py
    ├── export.py
    ├── inference_seq2tree.py
    ├── learning_rates.py
    ├── main.py
    ├── seq_with_label.py
    ├── train_raw2tree.py
    ├── transliterate.py
    └── treebanks.py

/README.md:
--------------------------------------------------------------------------------
1 | # SpanPSP
2 | This repository contains code accompanying the paper **"A CHARACTER-LEVEL SPAN-BASED MODEL FOR MANDARIN PROSODIC STRUCTURE PREDICTION"** published at ICASSP 2022.
3 | 
4 | ## [TTS Samples](https://thuhcsi.github.io/SpanPSP/) | [Paper](https://arxiv.org/abs/2203.16922)
5 | 
6 | ## Environment
7 | * Python 3.7 or higher.
8 | * PyTorch 1.6.0, or any compatible version.
9 | * NLTK 3.2, torch-struct 0.4, and transformers 4.3.0, or compatible versions.
10 | * pytokenizations 0.7.2 or compatible.
11 | 
12 | ## Repository structure
13 | ```
14 | SpanPSP
15 | ├──bert-base-chinese
16 | |  ├──config.json
17 | |  ├──pytorch_model.bin
18 | |  └──vocab.txt
19 | ├──data
20 | |  ├──train
21 | |  |  ├──raw_data
22 | |  |  |  └──raw_data.txt
23 | |  |  └──tree_data
24 | |  |     ├──tree_train.txt
25 | |  |     ├──tree_validate.txt
26 | |  |     └──tree_test.txt
27 | |  └──inference
28 | |     ├──raw_data
29 | |     |  └──raw_data.txt
30 | |     └──tree_data
31 | |        └──tree_data.txt
32 | ├──models
33 | |  ├──pretrained_model
34 | |  |  └──pretrained_SpanPSP_Databaker.pt
35 | |  └──yours
36 | ├──src
37 | |  ├──benepar
38 | |  |  └── ...
39 | |  ├──count_fscore.py
40 | |  ├──evaluate.py
41 | |  ├──export.py
42 | |  ├──inference_seq2tree.py
43 | |  ├──learning_rates.py
44 | |  ├──main.py
45 | |  ├──seq_with_label.py
46 | |  ├──train_raw2tree.py
47 | |  ├──transliterate.py
48 | |  └──treebanks.py
49 | └──README.md
50 | ```
51 | 
52 | ## Download pretrained models
53 | You can download the pre-trained models from the links below and put them in the right place as shown in the repository structure above.
54 | * ### bert-base-chinese
55 | > Link: https://huggingface.co/bert-base-chinese
56 | * ### SpanPSP_Databaker, SpanPSP_PeopleDaily
57 | > Link: https://pan.baidu.com/s/1bwwFbyP1WoEr3fLbbGeXpQ
58 | 
59 | > Password: 9r2h
60 | 
61 | 
62 | ## Training and testing with your dataset
63 | ### Data preprocessing
64 | First, prepare your dataset in the following format and put it (__*raw_data.txt*__) in the right place as shown in the repository structure above.
65 | > 猴子#2用#1尾巴#2荡秋千#3。
66 | 
67 | Then use the following command to convert the data in the raw file above from sequence format to tree format, and divide it into training, validation, and test sets at a ratio of 8:1:1.
68 | ```
69 | $ python src/train_raw2tree.py
70 | ```
71 | After that, you will get __*tree_train.txt*__, __*tree_validate.txt*__, and __*tree_test.txt*__.
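
To make the conversion concrete, below is a minimal, hypothetical re-implementation of what `src/train_raw2tree.py` does for a single sentence. It is illustrative only: the function name and punctuation set are invented, it assumes one intonational phrase (`#3`) per line with sentence-final punctuation, and the real script additionally shuffles the data and performs the 8:1:1 split.

```
import re

PUNCT = "。，、？！；："  # assumed punctuation inventory

def raw_to_tree(line: str) -> str:
    """Turn '猴子#2用#1尾巴#2荡秋千#3。' into the character-level
    bracketed format shown in data/train/tree_data/tree_train.txt."""
    line = line.strip()
    tail = ""
    while line and line[-1] in PUNCT:              # sentence-final punctuation
        tail = f"({line[-1]} {line[-1]})" + tail   # becomes its own leaf
        line = line[:-1]
    line = re.sub(r"#\d$", "", line)               # drop the trailing #3 marker
    pieces = re.split(r"#(\d)", line)              # [seg, lvl, seg, lvl, ..., seg]
    segs = pieces[0::2]
    lvls = [int(x) for x in pieces[1::2]] + [3]    # sentence end closes everything
    words = ["".join(f"(n {ch})" for ch in seg) for seg in segs]

    phrases, cur = [], []                          # group #1 words into #2 phrases
    for word, lvl in zip(words, lvls):
        cur.append(f"(#1 {word})")
        if lvl >= 2:                               # a #2/#3 break closes the phrase
            phrases.append("(#2 " + "".join(cur) + ")")
            cur = []
    return f"(TOP (S (#3 {' '.join(phrases)}) {tail}))"

print(raw_to_tree("猴子#2用#1尾巴#2荡秋千#3。"))
# (TOP (S (#3 (#2 (#1 (n 猴)(n 子))) (#2 (#1 (n 用))(#1 (n 尾)(n 巴))) (#2 (#1 (n 荡)(n 秋)(n 千)))) (。 。)))
```

The printed tree matches the entries in __*tree_train.txt*__: each `#1` prosodic word groups `(n 字)` character leaves, `#2` prosodic phrases group `#1` units, and the whole sentence sits under a single `#3` node with punctuation as a sibling leaf.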
72 | ### Training
73 | Train your model using:
74 | ```
75 | $ python src/main.py train --train-path [your_training_data_path] --dev-path [your_dev_data_path] --model-path-base [your_saving_model_path]
76 | ```
77 | For example:
78 | ```
79 | $ python src/main.py train --train-path data/train/tree_data/tree_train.txt --dev-path data/train/tree_data/tree_validate.txt --model-path-base models/my_model
80 | ```
81 | ### Test
82 | Test your model using:
83 | ```
84 | $ python src/main.py test --model-path [your_trained_model_path] --test-path [your_test_data_path]
85 | ```
86 | For example:
87 | ```
88 | $ python src/main.py test --model-path models/my_model.pt --test-path data/train/tree_data/tree_test.txt
89 | ```
90 | ## Inference
91 | ### Data preprocessing
92 | First, prepare your dataset in the following format and put it (__*raw_data.txt*__) in the right place as shown in the repository structure.
93 | > 猴子用尾巴荡秋千。
94 | 
95 | Then use the following command to convert the dataset from sequence format to tree format:
96 | ```
97 | $ python src/inference_seq2tree.py
98 | ```
99 | After that, you will get __*tree_data.txt*__.
100 | ### Inference
101 | Run inference on your data using:
102 | ```
103 | $ python src/main.py inference --model-path [your_pretrained_model_path] --test-path [your_test_data_path] --output-path [your_output_data_path]
104 | ```
105 | For example:
106 | ```
107 | $ python src/main.py inference --model-path models/pretrained_model/pretrained_SpanPSP_Databaker.pt --test-path data/inference/tree_data/tree_data.txt --output-path data/inference/output_data.txt
108 | ```
109 | 
--------------------------------------------------------------------------------
/bert-base-chinese/config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "architectures": [
 3 |     "BertForMaskedLM"
 4 |   ],
 5 |   "attention_probs_dropout_prob": 0.1,
 6 |   "directionality": "bidi",
 7 |   "hidden_act": "gelu",
 8 |   "hidden_dropout_prob": 0.1,
 9 |   "hidden_size": 768,
10 |   "initializer_range": 0.02,
11 |   "intermediate_size": 3072,
12 |   "layer_norm_eps": 1e-12,
13 |   "max_position_embeddings": 512,
14 |   "model_type": "bert",
15 |   "num_attention_heads": 12,
16 |   "num_hidden_layers": 12,
17 |   "pad_token_id": 0,
18 |   "pooler_fc_size": 768,
19 |   "pooler_num_attention_heads": 12,
20 |   "pooler_num_fc_layers": 3,
21 |   "pooler_size_per_head": 128,
22 |   "pooler_type": "first_token_transform",
23 |   "type_vocab_size": 2,
24 |   "vocab_size": 21128
25 | }
26 | 
--------------------------------------------------------------------------------
/data/inference/raw_data/raw_data.txt:
--------------------------------------------------------------------------------
1 | 猴子用尾巴荡秋千。
2 | 猴子用尾巴荡秋千。
3 | 
--------------------------------------------------------------------------------
/data/inference/tree_data/tree_data.txt:
--------------------------------------------------------------------------------
1 | (TOP (S (n 猴)(n 子)(n 用)(n 尾)(n 巴)(n 荡)(n 秋)(n 千)(。 。)))
2 | (TOP (S (n 猴)(n 子)(n 用)(n 尾)(n 巴)(n 荡)(n 秋)(n 千)(。 。)))
3 | 
--------------------------------------------------------------------------------
/data/train/raw_data/raw_data.txt:
--------------------------------------------------------------------------------
1 | 猴子#2用#1尾巴#2荡秋千#3。
2 | 猴子#2用#1尾巴#2荡秋千#3。
3 | 
--------------------------------------------------------------------------------
/data/train/tree_data/tree_test.txt:
--------------------------------------------------------------------------------
1 | (TOP (S (#3 (#2 (#1 (n 猴)(n 
子))) (#2 (#1 (n 用))(#1 (n 尾)(n 巴))) (#2 (#1 (n 荡)(n 秋)(n 千)))) (。 。))) 2 | (TOP (S (#3 (#2 (#1 (n 猴)(n 子))) (#2 (#1 (n 用))(#1 (n 尾)(n 巴))) (#2 (#1 (n 荡)(n 秋)(n 千)))) (。 。))) 3 | -------------------------------------------------------------------------------- /data/train/tree_data/tree_train.txt: -------------------------------------------------------------------------------- 1 | (TOP (S (#3 (#2 (#1 (n 猴)(n 子))) (#2 (#1 (n 用))(#1 (n 尾)(n 巴))) (#2 (#1 (n 荡)(n 秋)(n 千)))) (。 。))) 2 | (TOP (S (#3 (#2 (#1 (n 猴)(n 子))) (#2 (#1 (n 用))(#1 (n 尾)(n 巴))) (#2 (#1 (n 荡)(n 秋)(n 千)))) (。 。))) 3 | -------------------------------------------------------------------------------- /data/train/tree_data/tree_validate.txt: -------------------------------------------------------------------------------- 1 | (TOP (S (#3 (#2 (#1 (n 猴)(n 子))) (#2 (#1 (n 用))(#1 (n 尾)(n 巴))) (#2 (#1 (n 荡)(n 秋)(n 千)))) (。 。))) 2 | (TOP (S (#3 (#2 (#1 (n 猴)(n 子))) (#2 (#1 (n 用))(#1 (n 尾)(n 巴))) (#2 (#1 (n 荡)(n 秋)(n 千)))) (。 。))) 3 | -------------------------------------------------------------------------------- /models/pretrained_model/README.md: -------------------------------------------------------------------------------- 1 | # Pretrained model 2 | 3 | You can download the pre-trained model from the link below and put it in this directory. 4 | 5 | **Link**: https://pan.baidu.com/s/1zgXHgRnUY_J2IDSEEq2J0w 6 | 7 | **Password**: w7d8 8 | -------------------------------------------------------------------------------- /src/benepar/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | benepar: Berkeley Neural Parser 3 | """ 4 | 5 | # This file and all code in integrations/ relate to the version of the parser 6 | # released via PyPI. If you only need to run research experiments, it is safe 7 | # to delete the integrations/ folder and replace this __init__.py with an 8 | # empty file. 
9 | 10 | __all__ = [ 11 | "Parser", 12 | "InputSentence", 13 | "download", 14 | "BeneparComponent", 15 | "NonConstituentException", 16 | ] 17 | 18 | from .integrations.downloader import download 19 | from .integrations.nltk_plugin import Parser, InputSentence 20 | from .integrations.spacy_plugin import BeneparComponent, NonConstituentException 21 | -------------------------------------------------------------------------------- /src/benepar/char_lstm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class CharacterLSTM(nn.Module): 8 | def __init__(self, num_embeddings, d_embedding, d_out, char_dropout=0.0, **kwargs): 9 | super().__init__() 10 | 11 | self.d_embedding = d_embedding 12 | self.d_out = d_out 13 | 14 | self.lstm = nn.LSTM( 15 | self.d_embedding, self.d_out // 2, num_layers=1, bidirectional=True 16 | ) 17 | 18 | self.emb = nn.Embedding(num_embeddings, self.d_embedding, **kwargs) 19 | self.char_dropout = nn.Dropout(char_dropout) 20 | 21 | def forward(self, chars_packed, valid_token_mask): 22 | inp_embs = nn.utils.rnn.PackedSequence( 23 | self.char_dropout(self.emb(chars_packed.data)), 24 | batch_sizes=chars_packed.batch_sizes, 25 | sorted_indices=chars_packed.sorted_indices, 26 | unsorted_indices=chars_packed.unsorted_indices, 27 | ) 28 | 29 | _, (lstm_out, _) = self.lstm(inp_embs) 30 | lstm_out = torch.cat([lstm_out[0], lstm_out[1]], -1) 31 | 32 | # Switch to a representation where there are dummy vectors for invalid 33 | # tokens generated by padding. 34 | res = lstm_out.new_zeros( 35 | (valid_token_mask.shape[0], valid_token_mask.shape[1], lstm_out.shape[-1]) 36 | ) 37 | res[valid_token_mask] = lstm_out 38 | return res 39 | 40 | 41 | class RetokenizerForCharLSTM: 42 | # Assumes that these control characters are not present in treebank text 43 | CHAR_UNK = "\0" 44 | CHAR_ID_UNK = 0 45 | CHAR_START_SENTENCE = "\1" 46 | CHAR_START_WORD = "\2" 47 | CHAR_STOP_WORD = "\3" 48 | CHAR_STOP_SENTENCE = "\4" 49 | 50 | def __init__(self, char_vocab): 51 | self.char_vocab = char_vocab 52 | 53 | @classmethod 54 | def build_vocab(cls, sentences): 55 | char_set = set() 56 | for sentence in sentences: 57 | if isinstance(sentence, tuple): 58 | sentence = sentence[0] 59 | for word in sentence: 60 | char_set |= set(word) 61 | 62 | # If codepoints are small (e.g. Latin alphabet), index by codepoint 63 | # directly 64 | highest_codepoint = max(ord(char) for char in char_set) 65 | if highest_codepoint < 512: 66 | if highest_codepoint < 256: 67 | highest_codepoint = 256 68 | else: 69 | highest_codepoint = 512 70 | 71 | char_vocab = {} 72 | # This also takes care of constants like CHAR_UNK, etc. 
73 | for codepoint in range(highest_codepoint): 74 | char_vocab[chr(codepoint)] = codepoint 75 | return char_vocab 76 | else: 77 | char_vocab = {} 78 | char_vocab[cls.CHAR_UNK] = 0 79 | char_vocab[cls.CHAR_START_SENTENCE] = 1 80 | char_vocab[cls.CHAR_START_WORD] = 2 81 | char_vocab[cls.CHAR_STOP_WORD] = 3 82 | char_vocab[cls.CHAR_STOP_SENTENCE] = 4 83 | for id_, char in enumerate(sorted(char_set), start=5): 84 | char_vocab[char] = id_ 85 | return char_vocab 86 | 87 | def __call__(self, words, space_after="ignored", return_tensors=None): 88 | if return_tensors != "np": 89 | raise NotImplementedError("Only return_tensors='np' is supported.") 90 | 91 | res = {} 92 | 93 | # Sentence-level start/stop tokens are encoded as 3 pseudo-chars 94 | # Within each word, account for 2 start/stop characters 95 | max_word_len = max(3, max(len(word) for word in words)) + 2 96 | char_ids = np.zeros((len(words) + 2, max_word_len), dtype=int) 97 | word_lens = np.zeros(len(words) + 2, dtype=int) 98 | 99 | char_ids[0, :5] = [ 100 | self.char_vocab[self.CHAR_START_WORD], 101 | self.char_vocab[self.CHAR_START_SENTENCE], 102 | self.char_vocab[self.CHAR_START_SENTENCE], 103 | self.char_vocab[self.CHAR_START_SENTENCE], 104 | self.char_vocab[self.CHAR_STOP_WORD], 105 | ] 106 | word_lens[0] = 5 107 | for i, word in enumerate(words, start=1): 108 | char_ids[i, 0] = self.char_vocab[self.CHAR_START_WORD] 109 | for j, char in enumerate(word, start=1): 110 | char_ids[i, j] = self.char_vocab.get(char, self.CHAR_ID_UNK) 111 | char_ids[i, j + 1] = self.char_vocab[self.CHAR_STOP_WORD] 112 | word_lens[i] = j + 2 113 | char_ids[i + 1, :5] = [ 114 | self.char_vocab[self.CHAR_START_WORD], 115 | self.char_vocab[self.CHAR_STOP_SENTENCE], 116 | self.char_vocab[self.CHAR_STOP_SENTENCE], 117 | self.char_vocab[self.CHAR_STOP_SENTENCE], 118 | self.char_vocab[self.CHAR_STOP_WORD], 119 | ] 120 | word_lens[i + 1] = 5 121 | 122 | res["char_ids"] = char_ids 123 | res["word_lens"] = word_lens 124 | res["valid_token_mask"] = np.ones_like(word_lens, dtype=bool) 125 | 126 | return res 127 | 128 | def pad(self, examples, return_tensors=None): 129 | if return_tensors != "pt": 130 | raise NotImplementedError("Only return_tensors='pt' is supported.") 131 | max_word_len = max(example["char_ids"].shape[-1] for example in examples) 132 | char_ids = torch.cat( 133 | [ 134 | F.pad( 135 | torch.tensor(example["char_ids"]), 136 | (0, max_word_len - example["char_ids"].shape[-1]), 137 | ) 138 | for example in examples 139 | ] 140 | ) 141 | word_lens = torch.cat( 142 | [torch.tensor(example["word_lens"]) for example in examples] 143 | ) 144 | valid_token_mask = nn.utils.rnn.pad_sequence( 145 | [torch.tensor(example["valid_token_mask"]) for example in examples], 146 | batch_first=True, 147 | padding_value=False, 148 | ) 149 | 150 | char_ids = nn.utils.rnn.pack_padded_sequence( 151 | char_ids, word_lens, batch_first=True, enforce_sorted=False 152 | ) 153 | return { 154 | "char_ids": char_ids, 155 | "valid_token_mask": valid_token_mask, 156 | } 157 | -------------------------------------------------------------------------------- /src/benepar/decode_chart.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | import numpy as np 3 | from numpy.core.fromnumeric import shape 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch_struct 8 | 9 | from .parse_base import CompressedParserOutput 10 | 11 | 12 | def pad_charts(charts, padding_value=-100): 13 | """Pad a list of 
variable-length charts with `padding_value`.""" 14 | batch_size = len(charts) 15 | max_len = max(chart.shape[0] for chart in charts) 16 | padded_charts = torch.full( 17 | (batch_size, max_len, max_len), 18 | padding_value, 19 | dtype=charts[0].dtype, 20 | device=charts[0].device, 21 | ) 22 | for i, chart in enumerate(charts): 23 | chart_size = chart.shape[0] 24 | padded_charts[i, :chart_size, :chart_size] = chart 25 | return padded_charts 26 | 27 | 28 | 29 | def collapse_unary_strip_pos(tree, strip_top=True): 30 | """Collapse unary chains and strip part of speech tags.""" 31 | 32 | def strip_pos(tree): 33 | if len(tree) == 1 and isinstance(tree[0], str): 34 | return tree[0] 35 | else: 36 | return nltk.tree.Tree(tree.label(), [strip_pos(child) for child in tree]) 37 | 38 | collapsed_tree = strip_pos(tree) 39 | collapsed_tree.collapse_unary(collapsePOS=True, joinChar="::") 40 | if collapsed_tree.label() in ("TOP", "ROOT", "S1", "VROOT"): 41 | if strip_top: 42 | if len(collapsed_tree) == 1: 43 | collapsed_tree = collapsed_tree[0] 44 | else: 45 | collapsed_tree.set_label("") 46 | elif len(collapsed_tree) == 1: 47 | collapsed_tree[0].set_label( 48 | collapsed_tree.label() + "::" + collapsed_tree[0].label()) 49 | collapsed_tree = collapsed_tree[0] 50 | return collapsed_tree 51 | 52 | 53 | def _get_labeled_spans(tree, spans_out, start): 54 | if isinstance(tree, str): 55 | return start + 1 56 | 57 | assert len(tree) > 1 or isinstance( 58 | tree[0], str 59 | ), "Must call collapse_unary_strip_pos first" 60 | end = start 61 | for child in tree: 62 | end = _get_labeled_spans(child, spans_out, end) 63 | spans_out.append((start, end - 1, tree.label())) 64 | return end 65 | 66 | 67 | def get_labeled_spans(tree): 68 | """Converts a tree into a list of labeled spans. 69 | 70 | Args: 71 | tree: an nltk.tree.Tree object 72 | 73 | Returns: 74 | A list of (span_start, span_end, span_label) tuples. The start and end 75 | indices indicate the first and last words of the span (a closed 76 | interval). Unary chains are collapsed, so e.g. a (S (VP ...)) will 77 | result in a single span labeled "S+VP". 78 | """ 79 | tree = collapse_unary_strip_pos(tree) 80 | spans_out = [] 81 | _get_labeled_spans(tree, spans_out, start=0) 82 | return spans_out 83 | 84 | 85 | def uncollapse_unary(tree, ensure_top=False): 86 | """Un-collapse unary chains.""" 87 | if isinstance(tree, str): 88 | return tree 89 | else: 90 | labels = tree.label().split("::") 91 | if ensure_top and labels[0] != "TOP": 92 | labels = ["TOP"] + labels 93 | children = [] 94 | for child in tree: 95 | child = uncollapse_unary(child) 96 | children.append(child) 97 | for label in labels[::-1]: 98 | children = [nltk.tree.Tree(label, children)] 99 | return children[0] 100 | 101 | 102 | class ChartDecoder: 103 | """A chart decoder for parsing formulated as span classification.""" 104 | 105 | def __init__(self, label_vocab, force_root_constituent=True): 106 | """Constructs a new ChartDecoder object. 107 | Args: 108 | label_vocab: A mapping from span labels to integer indices. 
109 | """ 110 | self.label_vocab = label_vocab 111 | self.label_from_index = {i: label for label, i in label_vocab.items()} 112 | self.force_root_constituent = force_root_constituent 113 | 114 | @staticmethod 115 | def build_vocab(trees): 116 | label_set = set() 117 | for tree in trees: 118 | for _, _, label in get_labeled_spans(tree): 119 | if label: 120 | label_set.add(label) 121 | label_set = [""] + sorted(label_set) 122 | return {label: i for i, label in enumerate(label_set)} 123 | 124 | @staticmethod 125 | def infer_force_root_constituent(trees): 126 | for tree in trees: 127 | for _, _, label in get_labeled_spans(tree): 128 | if not label: 129 | return False 130 | return True 131 | 132 | def chart_from_tree(self, tree): 133 | spans = get_labeled_spans(tree) 134 | num_words = len(tree.leaves()) 135 | chart = np.full((num_words, num_words), -100, dtype=int) 136 | chart = np.tril(chart, -1) 137 | # Now all invalid entries are filled with -100, and valid entries with 0 138 | for start, end, label in spans: 139 | # Previously unseen unary chains can occur in the dev/test sets. 140 | # For now, we ignore them and don't mark the corresponding chart 141 | # entry as a constituent. 142 | if label in self.label_vocab: 143 | chart[start, end] = self.label_vocab[label] 144 | return chart 145 | 146 | def charts_from_pytorch_scores_batched(self, scores, lengths): 147 | """Runs CKY to recover span labels from scores (e.g. logits). 148 | 149 | This method uses pytorch-struct to speed up decoding compared to the 150 | pure-Python implementation of CKY used by tree_from_scores(). 151 | 152 | Args: 153 | scores: a pytorch tensor of shape (batch size, max length, 154 | max length, label vocab size). 155 | lengths: a pytorch tensor of shape (batch size,) 156 | 157 | Returns: 158 | A list of numpy arrays, each of shape (sentence length, sentence 159 | length). 160 | """ 161 | scores = scores.detach() 162 | scores = scores - scores[..., :1] 163 | if self.force_root_constituent: 164 | scores[torch.arange(scores.shape[0]), 0, lengths - 1, 0] -= 1e9 165 | dist = torch_struct.TreeCRF(scores, lengths=lengths) 166 | amax = dist.argmax 167 | amax[..., 0] += 1e-9 168 | padded_charts = amax.argmax(-1) 169 | padded_charts = padded_charts.detach().cpu().numpy() 170 | return [ 171 | chart[:length, :length] for chart, length in zip(padded_charts, lengths) 172 | ] 173 | 174 | def compressed_output_from_chart(self, chart): 175 | chart_with_filled_diagonal = chart.copy() 176 | np.fill_diagonal(chart_with_filled_diagonal, 1) 177 | chart_with_filled_diagonal[0, -1] = 1 178 | starts, inclusive_ends = np.where(chart_with_filled_diagonal) 179 | preorder_sort = np.lexsort((-inclusive_ends, starts)) 180 | starts = starts[preorder_sort] 181 | inclusive_ends = inclusive_ends[preorder_sort] 182 | labels = chart[starts, inclusive_ends] 183 | ends = inclusive_ends + 1 184 | return CompressedParserOutput(starts=starts, ends=ends, labels=labels) 185 | 186 | def tree_from_chart(self, chart, leaves): 187 | compressed_output = self.compressed_output_from_chart(chart) 188 | return compressed_output.to_tree(leaves, self.label_from_index) 189 | 190 | def tree_from_scores(self, scores, leaves): 191 | """Runs CKY to decode a tree from scores (e.g. logits). 192 | 193 | If speed is important, consider using charts_from_pytorch_scores_batched 194 | followed by compressed_output_from_chart or tree_from_chart instead. 195 | 196 | Args: 197 | scores: a chart of scores (or logits) of shape 198 | (sentence length, sentence length, label vocab size). 
The first 199 | two dimensions may be padded to a longer length, but all padded 200 | values will be ignored. 201 | leaves: the leaf nodes to use in the constructed tree. These 202 | may be of type str or nltk.Tree, or (word, tag) tuples that 203 | will be used to construct the leaf node objects. 204 | 205 | Returns: 206 | An nltk.Tree object. 207 | """ 208 | leaves = [ 209 | nltk.Tree(node[1], [node[0]]) if isinstance(node, tuple) else node 210 | for node in leaves 211 | ] 212 | 213 | chart = {} 214 | scores = scores - scores[:, :, 0, None] 215 | for length in range(1, len(leaves) + 1): 216 | for left in range(0, len(leaves) + 1 - length): 217 | right = left + length 218 | 219 | label_scores = scores[left, right - 1] 220 | label_scores = label_scores - label_scores[0] 221 | 222 | argmax_label_index = int( 223 | label_scores.argmax() 224 | if length < len(leaves) or not self.force_root_constituent 225 | else label_scores[1:].argmax() + 1 226 | ) 227 | argmax_label = self.label_from_index[argmax_label_index] 228 | label = argmax_label 229 | label_score = label_scores[argmax_label_index] 230 | 231 | if length == 1: 232 | tree = leaves[left] 233 | if label: 234 | tree = nltk.tree.Tree(label, [tree]) 235 | chart[left, right] = [tree], label_score 236 | continue 237 | 238 | best_split = max( 239 | range(left + 1, right), 240 | key=lambda split: (chart[left, split][1] + chart[split, right][1]), 241 | ) 242 | 243 | left_trees, left_score = chart[left, best_split] 244 | right_trees, right_score = chart[best_split, right] 245 | 246 | children = left_trees + right_trees 247 | if label: 248 | children = [nltk.tree.Tree(label, children)] 249 | 250 | chart[left, right] = (children, label_score + left_score + right_score) 251 | 252 | children, score = chart[0, len(leaves)] 253 | tree = nltk.tree.Tree("TOP", children) 254 | tree = uncollapse_unary(tree) 255 | return tree 256 | 257 | 258 | class SpanClassificationMarginLoss(nn.Module): 259 | def __init__(self, force_root_constituent=True, reduction="mean"): 260 | super().__init__() 261 | self.force_root_constituent = force_root_constituent 262 | if reduction not in ("none", "mean", "sum"): 263 | raise ValueError(f"Invalid value for reduction: {reduction}") 264 | self.reduction = reduction 265 | 266 | 267 | def forward(self, logits, labels): 268 | gold_event = F.one_hot(F.relu(labels), num_classes=logits.shape[-1]) 269 | 270 | logits = logits - logits[..., :1] 271 | lengths = (labels[:, 0, :] != -100).sum(-1) 272 | augment = (1 - gold_event).to(torch.float) 273 | 274 | if self.force_root_constituent: 275 | augment[torch.arange(augment.shape[0]), 0, lengths - 1, 0] -= 1e9 276 | dist = torch_struct.TreeCRF(logits + augment, lengths=lengths) 277 | 278 | pred_score = dist.max 279 | gold_score = (logits * gold_event).sum((1, 2, 3)) 280 | 281 | margin_losses = F.relu(pred_score - gold_score) 282 | 283 | if self.reduction == "none": 284 | return margin_losses 285 | elif self.reduction == "mean": 286 | return margin_losses.mean() 287 | elif self.reduction == "sum": 288 | return margin_losses.sum() 289 | else: 290 | assert False, f"Unexpected reduction: {self.reduction}" 291 | -------------------------------------------------------------------------------- /src/benepar/integrations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thuhcsi/SpanPSP/d116e0d4d9d855c7da6e617ef39f8fc96a9cb599/src/benepar/integrations/__init__.py 
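
Note on the decode_chart module above: `ChartDecoder.tree_from_scores` can be exercised on a toy score chart to see how per-span logits become a prosodic tree. The snippet below is a hypothetical usage example, not part of the repository — it assumes `src/` is on `sys.path`, and the label inventory and scores are invented for a 3-character fragment.

```
import numpy as np
from benepar.decode_chart import ChartDecoder  # assumes src/ is on sys.path

# Invented label inventory; index 0 must be the empty "not a constituent"
# label, since the decoder treats column 0 as the score baseline.
label_vocab = {"": 0, "#1": 1, "#2": 2, "#3": 3}
decoder = ChartDecoder(label_vocab, force_root_constituent=True)

# Three character leaves, given as (word, tag) tuples.
leaves = [("猴", "n"), ("子", "n"), ("用", "n")]

# scores[i, j] holds per-label logits for the span from word i through word j.
scores = np.zeros((3, 3, len(label_vocab)))
scores[0, 1, label_vocab["#1"]] = 5.0  # favor a #1 prosodic word over 猴子
scores[0, 2, label_vocab["#3"]] = 5.0  # favor labeling the full span #3

tree = decoder.tree_from_scores(scores, leaves)
print(tree)  # (TOP (#3 (#1 (n 猴) (n 子)) (n 用)))
```

With `force_root_constituent=True`, the CKY pass must assign the full span a non-empty label (here `#3`), which mirrors how every training tree roots its prosodic structure in a single top-level constituent.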
-------------------------------------------------------------------------------- /src/benepar/integrations/downloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | BENEPAR_SERVER_INDEX = "https://kitaev.com/benepar/index.xml" 4 | 5 | _downloader = None 6 | def get_downloader(): 7 | global _downloader 8 | if _downloader is None: 9 | import nltk.downloader 10 | _downloader = nltk.downloader.Downloader(server_index_url=BENEPAR_SERVER_INDEX) 11 | return _downloader 12 | 13 | def download(*args, **kwargs): 14 | return get_downloader().download(*args, **kwargs) 15 | 16 | def locate_model(name): 17 | if os.path.exists(name): 18 | return name 19 | elif "/" not in name and "." not in name: 20 | import nltk.data 21 | try: 22 | nltk_loc = nltk.data.find(f"models/{name}") 23 | return nltk_loc.path 24 | except LookupError as e: 25 | arg = e.args[0].replace("nltk.download", "benepar.download") 26 | 27 | raise LookupError(arg) 28 | 29 | raise LookupError("Can't find {}".format(name)) 30 | 31 | def load_trained_model(model_name_or_path): 32 | model_path = locate_model(model_name_or_path) 33 | from ..parse_chart import ChartParser 34 | parser = ChartParser.from_trained(model_path) 35 | return parser 36 | -------------------------------------------------------------------------------- /src/benepar/integrations/nltk_plugin.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import itertools 3 | from typing import List, Optional, Tuple 4 | 5 | import nltk 6 | import torch 7 | 8 | from .downloader import load_trained_model 9 | from ..parse_base import BaseParser, BaseInputExample 10 | from ..ptb_unescape import ptb_unescape, guess_space_after 11 | 12 | 13 | TOKENIZER_LOOKUP = { 14 | "en": "english", 15 | "de": "german", 16 | "fr": "french", 17 | "pl": "polish", 18 | "sv": "swedish", 19 | } 20 | 21 | LANGUAGE_GUESS = { 22 | "ar": ("X", "XP", "WHADVP", "WHNP", "WHPP"), 23 | "zh": ("VSB", "VRD", "VPT", "VNV"), 24 | "en": ("WHNP", "WHADJP", "SINV", "SQ"), 25 | "de": ("AA", "AP", "CCP", "CH", "CNP", "VZ"), 26 | "fr": ("P+", "P+D+", "PRO+", "PROREL+"), 27 | "he": ("PREDP", "SYN_REL", "SYN_yyDOT"), 28 | "pl": ("formaczas", "znakkonca"), 29 | "sv": ("PSEUDO", "AVP", "XP"), 30 | } 31 | 32 | 33 | def guess_language(label_vocab): 34 | """Guess parser language based on its syntactic label inventory. 35 | 36 | The parser training scripts are designed to accept arbitrary input tree 37 | files with minimal language-specific behavior, but at inference time we may 38 | need to know the language identity in order to invoke other pipeline 39 | elements, such as tokenizers. 40 | """ 41 | for language, required_labels in LANGUAGE_GUESS.items(): 42 | if all(label in label_vocab for label in required_labels): 43 | return language 44 | return None 45 | 46 | 47 | @dataclasses.dataclass 48 | class InputSentence(BaseInputExample): 49 | """Parser input for a single sentence. 50 | 51 | At least one of `words` and `escaped_words` is required for each input 52 | sentence. The remaining fields are optional: the parser will attempt to 53 | derive the value for any missing fields using the fields that are provided. 54 | 55 | `words` and `space_after` together form a reversible tokenization of the 56 | input text: they represent, respectively, the Unicode text for each word and 57 | an indicator for whether the word is followed by whitespace. These are used 58 | as inputs by the parser. 
59 | 60 | `tags` is a list of part-of-speech tags, if available prior to running the 61 | parser. The parser does not actually use these tags as input, but it will 62 | pass them through to its output. If `tags` is None, the parser will perform 63 | its own part of speech tagging (if the parser was not trained to also do 64 | tagging, "UNK" part-of-speech tags will be used in the output instead). 65 | 66 | `escaped_words` are the representations of each leaf to use in the output 67 | tree. If `words` is provided, `escaped_words` will not be used by the neural 68 | network portion of the parser, and will only be incorporated when 69 | constructing the output tree. Therefore, `escaped_words` may be used to 70 | accommodate any dataset-specific text encoding, such as transliteration. 71 | 72 | Here is an example of the differences between these fields for English PTB: 73 | (raw text): "Fly safely." 74 | words: " Fly safely . " 75 | space_after: False True False False False 76 | tags: `` VB RB . '' 77 | escaped_words: `` Fly safely . '' 78 | """ 79 | 80 | words: Optional[List[str]] = None 81 | space_after: Optional[List[bool]] = None 82 | tags: Optional[List[str]] = None 83 | escaped_words: Optional[List[str]] = None 84 | 85 | @property 86 | def tree(self): 87 | return None 88 | 89 | def leaves(self): 90 | return self.escaped_words 91 | 92 | def pos(self): 93 | if self.tags is not None: 94 | return list(zip(self.escaped_words, self.tags)) 95 | else: 96 | return [(word, "UNK") for word in self.escaped_words] 97 | 98 | 99 | class Parser: 100 | """Berkeley Neural Parser (benepar), integrated with NLTK. 101 | 102 | Use this class to apply the Berkeley Neural Parser to pre-tokenized datasets 103 | and treebanks, or when integrating the parser into an NLP pipeline that 104 | already performs tokenization, sentence splitting, and (optionally) 105 | part-of-speech tagging. For parsing starting with raw text, it is strongly 106 | encouraged that you use spaCy and benepar.BeneparComponent instead. 107 | 108 | Sample usage: 109 | >>> parser = benepar.Parser("benepar_en3") 110 | >>> input_sentence = benepar.InputSentence( 111 | words=['"', 'Fly', 'safely', '.', '"'], 112 | space_after=[False, True, False, False, False], 113 | tags=['``', 'VB', 'RB', '.', "''"], 114 | escaped_words=['``', 'Fly', 'safely', '.', "''"], 115 | ) 116 | >>> parser.parse(input_sentence) 117 | 118 | Not all fields of benepar.InputSentence are required, but at least one of 119 | `words` and `escaped_words` must not be None. The parser will attempt to 120 | guess the value for missing fields. For example, 121 | >>> input_sentence = benepar.InputSentence( 122 | words=['"', 'Fly', 'safely', '.', '"'], 123 | ) 124 | >>> parser.parse(input_sentence) 125 | 126 | Although this class is primarily designed for use with data that has already 127 | been tokenized, to help with interactive use and debugging it also accepts 128 | simple text string inputs. However, using this class to parse from raw text 129 | is STRONGLY DISCOURAGED for any application where parsing accuracy matters. 130 | When parsing from raw text, use spaCy and benepar.BeneparComponent instead. 131 | The reason is that parser models do not ship with a tokenizer or sentence 132 | splitter, and some models may not include a part-of-speech tagger either. A 133 | toolkit must be used to fill in these pipeline components, and spaCy 134 | outperforms NLTK in all of these areas (sometimes by a large margin). 
135 | >>> parser.parse('"Fly safely."') # For debugging/interactive use only. 136 | """ 137 | 138 | def __init__(self, name, batch_size=64, language_code=None): 139 | """Load a trained parser model. 140 | 141 | Args: 142 | name (str): Model name, or path to pytorch saved model 143 | batch_size (int): Maximum number of sentences to process per batch 144 | language_code (str, optional): language code for the parser (e.g. 145 | 'en', 'he', 'zh', etc). Our official trained models will set 146 | this automatically, so this argument is only needed if training 147 | on new languages or treebanks. 148 | """ 149 | self._parser = load_trained_model(name) 150 | if torch.cuda.is_available(): 151 | self._parser.cuda() 152 | if language_code is not None: 153 | self._language_code = language_code 154 | else: 155 | self._language_code = guess_language(self._parser.config["label_vocab"]) 156 | self._tokenizer_lang = TOKENIZER_LOOKUP.get(self._language_code, None) 157 | 158 | self.batch_size = batch_size 159 | 160 | def parse(self, sentence): 161 | """Parse a single sentence 162 | 163 | Args: 164 | sentence (InputSentence or List[str] or str): Sentence to parse. 165 | If the input is of List[str], it is assumed to be a sequence of 166 | words and will behave the same as only setting the `words` field 167 | of InputSentence. If the input is of type str, the sentence will 168 | be tokenized using the default NLTK tokenizer (not recommended: 169 | if parsing from raw text, use spaCy and benepar.BeneparComponent 170 | instead). 171 | 172 | Returns: 173 | nltk.Tree 174 | """ 175 | return list(self.parse_sents([sentence]))[0] 176 | 177 | def parse_sents(self, sents): 178 | """Parse multiple sentences in batches. 179 | 180 | Args: 181 | sents (Iterable[InputSentence]): An iterable of sentences to be 182 | parsed. `sents` may also be a string, in which case it will be 183 | segmented into sentences using the default NLTK sentence 184 | splitter (not recommended: if parsing from raw text, use spaCy 185 | and benepar.BeneparComponent instead). Otherwise, each element 186 | of `sents` will be treated as a sentence. The elements of 187 | `sents` may also be List[str] or str: see Parser.parse() for 188 | documentation regarding these cases. 189 | 190 | Yields: 191 | nltk.Tree objects, one per input sentence. 192 | """ 193 | if isinstance(sents, str): 194 | if self._tokenizer_lang is None: 195 | raise ValueError( 196 | "No tokenizer available for this language. " 197 | "Please split into individual sentences and tokens " 198 | "before calling the parser." 199 | ) 200 | sents = nltk.sent_tokenize(sents, self._tokenizer_lang) 201 | 202 | end_sentinel = object() 203 | for batch_sents in itertools.zip_longest( 204 | *([iter(sents)] * self.batch_size), fillvalue=end_sentinel 205 | ): 206 | batch_inputs = [] 207 | for sent in batch_sents: 208 | if sent is end_sentinel: 209 | break 210 | elif isinstance(sent, str): 211 | if self._tokenizer_lang is None: 212 | raise ValueError( 213 | "No word tokenizer available for this language. " 214 | "Please tokenize before calling the parser." 
215 | ) 216 | escaped_words = nltk.word_tokenize(sent, self._tokenizer_lang) 217 | sent = InputSentence(escaped_words=escaped_words) 218 | elif isinstance(sent, (list, tuple)): 219 | sent = InputSentence(words=sent) 220 | elif not isinstance(sent, InputSentence): 221 | raise ValueError( 222 | "Sentences must be one of: InputSentence, list, tuple, or str" 223 | ) 224 | batch_inputs.append(self._with_missing_fields_filled(sent)) 225 | 226 | for inp, output in zip( 227 | batch_inputs, self._parser.parse(batch_inputs, return_compressed=True) 228 | ): 229 | # If pos tags are provided as input, ignore any tags predicted 230 | # by the parser. 231 | if inp.tags is not None: 232 | output = output.without_predicted_tags() 233 | yield output.to_tree( 234 | inp.pos(), 235 | self._parser.decoder.label_from_index, 236 | self._parser.tag_from_index, 237 | ) 238 | 239 | def _with_missing_fields_filled(self, sent): 240 | if not isinstance(sent, InputSentence): 241 | raise ValueError("Input is not an instance of InputSentence") 242 | if sent.words is None and sent.escaped_words is None: 243 | raise ValueError("At least one of words or escaped_words is required") 244 | elif sent.words is None: 245 | sent = dataclasses.replace(sent, words=ptb_unescape(sent.escaped_words)) 246 | elif sent.escaped_words is None: 247 | escaped_words = [ 248 | word.replace("(", "-LRB-") 249 | .replace(")", "-RRB-") 250 | .replace("{", "-LCB-") 251 | .replace("}", "-RCB-") 252 | .replace("[", "-LSB-") 253 | .replace("]", "-RSB-") 254 | for word in sent.words 255 | ] 256 | sent = dataclasses.replace(sent, escaped_words=escaped_words) 257 | else: 258 | if len(sent.words) != len(sent.escaped_words): 259 | raise ValueError( 260 | f"Length of words ({len(sent.words)}) does not match " 261 | f"escaped_words ({len(sent.escaped_words)})" 262 | ) 263 | 264 | if sent.space_after is None: 265 | if self._language_code == "zh": 266 | space_after = [False for _ in sent.words] 267 | elif self._language_code in ("ar", "he"): 268 | space_after = [True for _ in sent.words] 269 | else: 270 | space_after = guess_space_after(sent.words) 271 | sent = dataclasses.replace(sent, space_after=space_after) 272 | elif len(sent.words) != len(sent.space_after): 273 | raise ValueError( 274 | f"Length of words ({len(sent.words)}) does not match " 275 | f"space_after ({len(sent.space_after)})" 276 | ) 277 | 278 | assert len(sent.words) == len(sent.escaped_words) == len(sent.space_after) 279 | return sent 280 | -------------------------------------------------------------------------------- /src/benepar/integrations/spacy_extensions.py: -------------------------------------------------------------------------------- 1 | NOT_PARSED_SENTINEL = object() 2 | 3 | 4 | class NonConstituentException(Exception): 5 | pass 6 | 7 | 8 | class ConstituentData: 9 | def __init__(self, starts, ends, labels, loc_to_constituent, label_vocab): 10 | self.starts = starts 11 | self.ends = ends 12 | self.labels = labels 13 | self.loc_to_constituent = loc_to_constituent 14 | self.label_vocab = label_vocab 15 | 16 | 17 | def get_constituent(span): 18 | constituent_data = span.doc._._constituent_data 19 | if constituent_data is NOT_PARSED_SENTINEL: 20 | raise Exception( 21 | "No constituency parse is available for this document." 22 | " Consider adding a BeneparComponent to the pipeline." 
23 | ) 24 | 25 | search_start = constituent_data.loc_to_constituent[span.start] 26 | if span.start + 1 < len(constituent_data.loc_to_constituent): 27 | search_end = constituent_data.loc_to_constituent[span.start + 1] 28 | else: 29 | search_end = len(constituent_data.ends) 30 | found_position = None 31 | for position in range(search_start, search_end): 32 | if constituent_data.ends[position] <= span.end: 33 | if constituent_data.ends[position] == span.end: 34 | found_position = position 35 | break 36 | 37 | if found_position is None: 38 | raise NonConstituentException("Span is not a constituent: {}".format(span)) 39 | return constituent_data, found_position 40 | 41 | 42 | def get_labels(span): 43 | constituent_data, position = get_constituent(span) 44 | label_num = constituent_data.labels[position] 45 | return constituent_data.label_vocab[label_num] 46 | 47 | 48 | def parse_string(span): 49 | constituent_data, position = get_constituent(span) 50 | label_vocab = constituent_data.label_vocab 51 | doc = span.doc 52 | 53 | idx = position - 1 54 | 55 | def make_str(): 56 | nonlocal idx 57 | idx += 1 58 | i, j, label_idx = ( 59 | constituent_data.starts[idx], 60 | constituent_data.ends[idx], 61 | constituent_data.labels[idx], 62 | ) 63 | label = label_vocab[label_idx] 64 | if (i + 1) >= j: 65 | token = doc[i] 66 | s = ( 67 | "(" 68 | + u"{} {}".format(token.tag_, token.text) 69 | .replace("(", "-LRB-") 70 | .replace(")", "-RRB-") 71 | .replace("{", "-LCB-") 72 | .replace("}", "-RCB-") 73 | .replace("[", "-LSB-") 74 | .replace("]", "-RSB-") 75 | + ")" 76 | ) 77 | else: 78 | children = [] 79 | while ( 80 | (idx + 1) < len(constituent_data.starts) 81 | and i <= constituent_data.starts[idx + 1] 82 | and constituent_data.ends[idx + 1] <= j 83 | ): 84 | children.append(make_str()) 85 | 86 | s = u" ".join(children) 87 | 88 | for sublabel in reversed(label): 89 | s = u"({} {})".format(sublabel, s) 90 | return s 91 | 92 | return make_str() 93 | 94 | 95 | def get_subconstituents(span): 96 | constituent_data, position = get_constituent(span) 97 | label_vocab = constituent_data.label_vocab 98 | doc = span.doc 99 | 100 | while position < len(constituent_data.starts): 101 | start = constituent_data.starts[position] 102 | end = constituent_data.ends[position] 103 | 104 | if span.end <= start or span.end < end: 105 | break 106 | 107 | yield doc[start:end] 108 | position += 1 109 | 110 | 111 | def get_child_spans(span): 112 | constituent_data, position = get_constituent(span) 113 | label_vocab = constituent_data.label_vocab 114 | doc = span.doc 115 | 116 | child_start_expected = span.start 117 | position += 1 118 | while position < len(constituent_data.starts): 119 | start = constituent_data.starts[position] 120 | end = constituent_data.ends[position] 121 | 122 | if span.end <= start or span.end < end: 123 | break 124 | 125 | if start == child_start_expected: 126 | yield doc[start:end] 127 | child_start_expected = end 128 | 129 | position += 1 130 | 131 | 132 | def get_parent_span(span): 133 | constituent_data, position = get_constituent(span) 134 | label_vocab = constituent_data.label_vocab 135 | doc = span.doc 136 | sent = span.sent 137 | 138 | position -= 1 139 | while position >= 0: 140 | start = constituent_data.starts[position] 141 | end = constituent_data.ends[position] 142 | 143 | if start <= span.start and span.end <= end: 144 | return doc[start:end] 145 | if end < span.sent.start: 146 | break 147 | position -= 1 148 | 149 | return None 150 | 151 | 152 | def install_spacy_extensions(): 153 | from 
spacy.tokens import Doc, Span, Token 154 | 155 | # None is not allowed as a default extension value! 156 | Doc.set_extension("_constituent_data", default=NOT_PARSED_SENTINEL) 157 | 158 | Span.set_extension("labels", getter=get_labels) 159 | Span.set_extension("parse_string", getter=parse_string) 160 | Span.set_extension("constituents", getter=get_subconstituents) 161 | Span.set_extension("parent", getter=get_parent_span) 162 | Span.set_extension("children", getter=get_child_spans) 163 | 164 | Token.set_extension( 165 | "labels", getter=lambda token: get_labels(token.doc[token.i : token.i + 1]) 166 | ) 167 | Token.set_extension( 168 | "parse_string", 169 | getter=lambda token: parse_string(token.doc[token.i : token.i + 1]), 170 | ) 171 | Token.set_extension( 172 | "parent", getter=lambda token: get_parent_span(token.doc[token.i : token.i + 1]) 173 | ) 174 | 175 | 176 | try: 177 | install_spacy_extensions() 178 | except ImportError: 179 | pass 180 | -------------------------------------------------------------------------------- /src/benepar/integrations/spacy_plugin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .downloader import load_trained_model 4 | from ..parse_base import BaseParser, BaseInputExample 5 | from .spacy_extensions import ConstituentData, NonConstituentException 6 | 7 | import torch 8 | 9 | 10 | class PartialConstituentData: 11 | def __init__(self): 12 | self.starts = [np.array([], dtype=int)] 13 | self.ends = [np.array([], dtype=int)] 14 | self.labels = [np.array([], dtype=int)] 15 | 16 | def finalize(self, doc, label_vocab): 17 | self.starts = np.hstack(self.starts) 18 | self.ends = np.hstack(self.ends) 19 | self.labels = np.hstack(self.labels) 20 | 21 | # TODO(nikita): Python for loops aren't very fast 22 | loc_to_constituent = np.full(len(doc), -1, dtype=int) 23 | prev = None 24 | for position in range(self.starts.shape[0]): 25 | if self.starts[position] != prev: 26 | prev = self.starts[position] 27 | loc_to_constituent[self.starts[position]] = position 28 | 29 | return ConstituentData( 30 | self.starts, self.ends, self.labels, loc_to_constituent, label_vocab 31 | ) 32 | 33 | 34 | class SentenceWrapper(BaseInputExample): 35 | TEXT_NORMALIZATION_MAPPING = { 36 | "`": "'", 37 | "«": '"', 38 | "»": '"', 39 | "‘": "'", 40 | "’": "'", 41 | "“": '"', 42 | "”": '"', 43 | "„": '"', 44 | "‹": "'", 45 | "›": "'", 46 | "—": "--", # em dash 47 | } 48 | 49 | def __init__(self, spacy_sent): 50 | self.sent = spacy_sent 51 | 52 | @property 53 | def words(self): 54 | return [ 55 | self.TEXT_NORMALIZATION_MAPPING.get(token.text, token.text) 56 | for token in self.sent 57 | ] 58 | 59 | @property 60 | def space_after(self): 61 | return [bool(token.whitespace_) for token in self.sent] 62 | 63 | @property 64 | def tree(self): 65 | return None 66 | 67 | def leaves(self): 68 | return self.words 69 | 70 | def pos(self): 71 | return [(word, "UNK") for word in self.words] 72 | 73 | 74 | class BeneparComponent: 75 | """ 76 | Berkeley Neural Parser (benepar) component for spaCy. 
77 | 78 | Sample usage: 79 | >>> nlp = spacy.load('en_core_web_md') 80 | >>> if spacy.__version__.startswith('2'): 81 | nlp.add_pipe(BeneparComponent("benepar_en3")) 82 | else: 83 | nlp.add_pipe("benepar", config={"model": "benepar_en3"}) 84 | >>> doc = nlp("The quick brown fox jumps over the lazy dog.") 85 | >>> sent = list(doc.sents)[0] 86 | >>> print(sent._.parse_string) 87 | 88 | This component is only responsible for constituency parsing and (for some 89 | trained models) part-of-speech tagging. It should be preceded in the 90 | pipeline by other components that can, at minimum, perform tokenization and 91 | sentence segmentation. 92 | """ 93 | 94 | name = "benepar" 95 | 96 | def __init__( 97 | self, 98 | name, 99 | subbatch_max_tokens=500, 100 | disable_tagger=False, 101 | batch_size="ignored", 102 | ): 103 | """Load a trained parser model. 104 | 105 | Args: 106 | name (str): Model name, or path to pytorch saved model 107 | subbatch_max_tokens (int): Maximum number of tokens to process in 108 | each batch 109 | disable_tagger (bool, default False): Unless disabled, the parser 110 | will set predicted part-of-speech tags for the document, 111 | overwriting any existing tags provided by spaCy models or 112 | previous pipeline steps. This option has no effect for parser 113 | models that do not have a part-of-speech tagger built in. 114 | batch_size: deprecated and ignored; use subbatch_max_tokens instead 115 | """ 116 | self._parser = load_trained_model(name) 117 | if torch.cuda.is_available(): 118 | self._parser.cuda() 119 | 120 | self.subbatch_max_tokens = subbatch_max_tokens 121 | self.disable_tagger = disable_tagger 122 | 123 | self._label_vocab = self._parser.config["label_vocab"] 124 | label_vocab_size = max(self._label_vocab.values()) + 1 125 | self._label_from_index = [()] * label_vocab_size 126 | for label, i in self._label_vocab.items(): 127 | if label: 128 | self._label_from_index[i] = tuple(label.split("::")) 129 | else: 130 | self._label_from_index[i] = () 131 | self._label_from_index = tuple(self._label_from_index) 132 | 133 | if not self.disable_tagger: 134 | tag_vocab = self._parser.config["tag_vocab"] 135 | tag_vocab_size = max(tag_vocab.values()) + 1 136 | self._tag_from_index = [()] * tag_vocab_size 137 | for tag, i in tag_vocab.items(): 138 | self._tag_from_index[i] = tag 139 | self._tag_from_index = tuple(self._tag_from_index) 140 | else: 141 | self._tag_from_index = None 142 | 143 | def __call__(self, doc): 144 | """Update the input document with predicted constituency parses.""" 145 | # TODO(https://github.com/nikitakit/self-attentive-parser/issues/16): handle 146 | # tokens that consist entirely of whitespace. 
147 | constituent_data = PartialConstituentData() 148 | wrapped_sents = [SentenceWrapper(sent) for sent in doc.sents] 149 | for sent, parse in zip( 150 | doc.sents, 151 | self._parser.parse( 152 | wrapped_sents, 153 | return_compressed=True, 154 | subbatch_max_tokens=self.subbatch_max_tokens, 155 | ), 156 | ): 157 | constituent_data.starts.append(parse.starts + sent.start) 158 | constituent_data.ends.append(parse.ends + sent.start) 159 | constituent_data.labels.append(parse.labels) 160 | 161 | if parse.tags is not None and not self.disable_tagger: 162 | for i, tag_id in enumerate(parse.tags): 163 | sent[i].tag_ = self._tag_from_index[tag_id] 164 | 165 | doc._._constituent_data = constituent_data.finalize(doc, self._label_from_index) 166 | return doc 167 | 168 | 169 | def create_benepar_component( 170 | nlp, 171 | name, 172 | model: str, 173 | subbatch_max_tokens: int, 174 | disable_tagger: bool, 175 | ): 176 | return BeneparComponent( 177 | model, 178 | subbatch_max_tokens=subbatch_max_tokens, 179 | disable_tagger=disable_tagger, 180 | ) 181 | 182 | 183 | def register_benepar_component_factory(): 184 | # Starting with spaCy 3.0, nlp.add_pipe no longer directly accepts 185 | # BeneparComponent instances. We must instead register a component factory. 186 | import spacy 187 | 188 | if spacy.__version__.startswith("2"): 189 | return 190 | 191 | from spacy.language import Language 192 | 193 | Language.factory( 194 | "benepar", 195 | default_config={ 196 | "subbatch_max_tokens": 500, 197 | "disable_tagger": False, 198 | }, 199 | func=create_benepar_component, 200 | ) 201 | 202 | 203 | try: 204 | register_benepar_component_factory() 205 | except ImportError: 206 | pass 207 | -------------------------------------------------------------------------------- /src/benepar/nkutil.py: -------------------------------------------------------------------------------- 1 | class HParams: 2 | _skip_keys = ["populate_arguments", "set_from_args", "print", "to_dict"] 3 | 4 | def __init__(self, **kwargs): 5 | for k, v in kwargs.items(): 6 | setattr(self, k, v) 7 | 8 | def __getitem__(self, item): 9 | return getattr(self, item) 10 | 11 | def __setitem__(self, item, value): 12 | if not hasattr(self, item): 13 | raise KeyError(f"Hyperparameter {item} has not been declared yet") 14 | setattr(self, item, value) 15 | 16 | def to_dict(self): 17 | res = {} 18 | for k in dir(self): 19 | if k.startswith("_") or k in self._skip_keys: 20 | continue 21 | res[k] = self[k] 22 | return res 23 | 24 | def populate_arguments(self, parser): 25 | for k in dir(self): 26 | if k.startswith("_") or k in self._skip_keys: 27 | continue 28 | v = self[k] 29 | k = k.replace("_", "-") 30 | if type(v) in (int, float, str): 31 | parser.add_argument(f"--{k}", type=type(v), default=v) 32 | elif isinstance(v, bool): 33 | if not v: 34 | parser.add_argument(f"--{k}", action="store_true") 35 | else: 36 | parser.add_argument(f"--no-{k}", action="store_false") 37 | 38 | def set_from_args(self, args): 39 | for k in dir(self): 40 | if k.startswith("_") or k in self._skip_keys: 41 | continue 42 | if hasattr(args, k): 43 | self[k] = getattr(args, k) 44 | elif hasattr(args, f"no_{k}"): 45 | self[k] = getattr(args, f"no_{k}") 46 | 47 | def print(self): 48 | for k in dir(self): 49 | if k.startswith("_") or k in self._skip_keys: 50 | continue 51 | print(k, repr(self[k])) 52 | -------------------------------------------------------------------------------- /src/benepar/parse_base.py: 
-------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | import dataclasses 3 | from typing import Any, Iterable, List, Optional, Tuple, Union 4 | 5 | import nltk 6 | import numpy as np 7 | 8 | 9 | class BaseInputExample(ABC): 10 | """Parser input for a single sentence (abstract interface).""" 11 | 12 | # Subclasses must define the following attributes or properties. 13 | # `words` is a list of unicode representations for each word in the sentence 14 | # and `space_after` is a list of booleans that indicate whether there is 15 | # whitespace after a word. Together, these should form a reversible 16 | # tokenization of raw text input. `tree` is an optional gold parse tree. 17 | words: List[str] 18 | space_after: List[bool] 19 | tree: Optional[nltk.Tree] 20 | 21 | @abstractmethod 22 | def leaves(self) -> Optional[List[str]]: 23 | """Returns leaves to use in the parse tree. 24 | 25 | While `words` must be raw unicode text, these should be whatever is 26 | standard for the treebank. For example, '(' in words might correspond to 27 | '-LRB-' in leaves, and leaves might include other transformations such 28 | as transliteration. 29 | """ 30 | pass 31 | 32 | @abstractmethod 33 | def pos(self) -> Optional[List[Tuple[str, str]]]: 34 | """Returns a list of (leaf, part-of-speech tag) tuples.""" 35 | pass 36 | 37 | 38 | @dataclasses.dataclass 39 | class CompressedParserOutput: 40 | """Parser output, encoded as a collection of numpy arrays. 41 | 42 | By default, a parser will return nltk.Tree objects. These have much nicer 43 | APIs than the CompressedParserOutput class, and the code involved is simpler 44 | and more readable. As a trade-off, code dealing with nltk.Tree objects is 45 | slower: the nltk.Tree type itself has some overhead, and algorithms dealing 46 | with it are implemented in pure Python as opposed to C or even CUDA. The 47 | CompressedParserOutput type is an alternative that has some optimizations 48 | for the sole purpose of speeding up inference. 49 | 50 | If trying a new parser type for research purposes, it's safe to ignore this 51 | class and the return_compressed argument to parse(). If the parser works 52 | well and is being released, the return_compressed argument can then be added 53 | with a dedicated fast implementation, or simply by using the from_tree 54 | method defined below. 55 | """ 56 | 57 | # A parse tree is represented as a set of constituents. In the case of 58 | # non-binary trees, only the labeled non-terminal nodes are included: there 59 | # are no dummy nodes inserted for binarization purposes. However, single 60 | # words are always included in the set of constituents, and they may have a 61 | # null label if there is no phrasal category above the part-of-speech tag. 62 | # All constituents are sorted according to pre-order traversal, and each has 63 | # an associated start (the index of the first word in the constituent), end 64 | # (1 + the index of the last word in the constituent), and label (index 65 | # associated with an external label_vocab dictionary.) These are then stored 66 | # in three numpy arrays: 67 | starts: Iterable[int] # Must be a numpy array 68 | ends: Iterable[int] # Must be a numpy array 69 | labels: Iterable[int] # Must be a numpy array 70 | 71 | # Part of speech tag ids as output by the parser (may be None if the parser 72 | # does not do POS tagging). These indices are associated with an external 73 | # tag_vocab dictionary. 
74 | tags: Optional[Iterable[int]] = None # Must be None or a numpy array 75 | 76 | def without_predicted_tags(self): 77 | return dataclasses.replace(self, tags=None) 78 | 79 | def with_tags(self, tags): 80 | return dataclasses.replace(self, tags=tags) 81 | 82 | @classmethod 83 | def from_tree( 84 | cls, tree: nltk.Tree, label_vocab: dict, tag_vocab: Optional[dict] = None 85 | ) -> "CompressedParserOutput": 86 | num_words = len(tree.leaves()) 87 | starts = np.empty(2 * num_words, dtype=int) 88 | ends = np.empty(2 * num_words, dtype=int) 89 | labels = np.empty(2 * num_words, dtype=int) 90 | 91 | def helper(tree, start, write_idx): 92 | nonlocal starts, ends, labels 93 | label = [] 94 | while len(tree) == 1 and not isinstance(tree[0], str): 95 | if tree.label() != "TOP": 96 | label.append(tree.label()) 97 | tree = tree[0] 98 | 99 | if len(tree) == 1 and isinstance(tree[0], str): 100 | starts[write_idx] = start 101 | ends[write_idx] = start + 1 102 | labels[write_idx] = label_vocab["::".join(label)] 103 | return start + 1, write_idx + 1 104 | 105 | label.append(tree.label()) 106 | starts[write_idx] = start 107 | labels[write_idx] = label_vocab["::".join(label)] 108 | 109 | end = start 110 | new_write_idx = write_idx + 1 111 | for child in tree: 112 | end, new_write_idx = helper(child, end, new_write_idx) 113 | 114 | ends[write_idx] = end 115 | return end, new_write_idx 116 | 117 | _, num_constituents = helper(tree, 0, 0) 118 | starts = starts[:num_constituents] 119 | ends = ends[:num_constituents] 120 | labels = labels[:num_constituents] 121 | 122 | if tag_vocab is None: 123 | tags = None 124 | else: 125 | tags = np.array([tag_vocab[tag] for _, tag in tree.pos()], dtype=int) 126 | 127 | return cls(starts=starts, ends=ends, labels=labels, tags=tags) 128 | 129 | def to_tree(self, leaves, label_from_index: dict, tag_from_index: dict = None): 130 | if self.tags is not None: 131 | if tag_from_index is None: 132 | raise ValueError( 133 | "tags_from_index is required to convert predicted pos tags" 134 | ) 135 | predicted_tags = [tag_from_index[i] for i in self.tags] 136 | assert len(leaves) == len(predicted_tags) 137 | leaves = [ 138 | nltk.Tree(tag, [leaf[0] if isinstance(leaf, tuple) else leaf]) 139 | for tag, leaf in zip(predicted_tags, leaves) 140 | ] 141 | else: 142 | leaves = [ 143 | nltk.Tree(leaf[1], [leaf[0]]) 144 | if isinstance(leaf, tuple) 145 | else (nltk.Tree("UNK", [leaf]) if isinstance(leaf, str) else leaf) 146 | for leaf in leaves 147 | ] 148 | 149 | idx = -1 150 | 151 | def helper(): 152 | nonlocal idx 153 | idx += 1 154 | i, j, label = ( 155 | self.starts[idx], 156 | self.ends[idx], 157 | label_from_index[self.labels[idx]], 158 | ) 159 | if (i + 1) >= j: 160 | children = [leaves[i]] 161 | else: 162 | children = [] 163 | while ( 164 | (idx + 1) < len(self.starts) 165 | and i <= self.starts[idx + 1] 166 | and self.ends[idx + 1] <= j 167 | ): 168 | children.extend(helper()) 169 | 170 | if label: 171 | for sublabel in reversed(label.split("::")): 172 | children = [nltk.Tree(sublabel, children)] 173 | 174 | return children 175 | 176 | children = helper() 177 | return nltk.Tree("TOP", children) 178 | 179 | 180 | class BaseParser(ABC): 181 | """Parser (abstract interface)""" 182 | 183 | @classmethod 184 | @abstractmethod 185 | def from_trained( 186 | cls, model_name: str, config: dict = None, state_dict: dict = None 187 | ) -> "BaseParser": 188 | """Load a trained parser.""" 189 | pass 190 | 191 | @abstractmethod 192 | def parallelize(self, *args, **kwargs): 193 | """Spread out 
pre-trained model layers across GPUs.""" 194 | pass 195 | 196 | @abstractmethod 197 | def parse( 198 | self, 199 | examples: Iterable[BaseInputExample], 200 | return_compressed: bool = False, 201 | return_scores: bool = False, 202 | subbatch_max_tokens: Optional[int] = None, 203 | ) -> Union[Iterable[nltk.Tree], Iterable[Any]]: 204 | """Parse sentences.""" 205 | pass 206 | 207 | @abstractmethod 208 | def encode_and_collate_subbatches( 209 | self, examples: List[BaseInputExample], subbatch_max_tokens: int 210 | ) -> List[dict]: 211 | """Split batch into sub-batches and convert to tensor features""" 212 | pass 213 | 214 | @abstractmethod 215 | def compute_loss(self, batch: dict): 216 | pass 217 | -------------------------------------------------------------------------------- /src/benepar/parse_chart.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from transformers import AutoConfig, AutoModel 10 | 11 | from . import char_lstm 12 | from . import decode_chart 13 | from . import nkutil 14 | from .partitioned_transformer import ( 15 | ConcatPositionalEncoding, 16 | FeatureDropout, 17 | PartitionedTransformerEncoder, 18 | PartitionedTransformerEncoderLayer, 19 | ) 20 | from . import parse_base 21 | from . import retokenization 22 | from . import subbatching 23 | 24 | 25 | class ChartParser(nn.Module, parse_base.BaseParser): 26 | def __init__( 27 | self, 28 | tag_vocab, 29 | label_vocab, 30 | char_vocab, 31 | hparams, 32 | pretrained_model_path=None, 33 | ): 34 | super().__init__() 35 | self.config = locals() 36 | self.config.pop("self") 37 | self.config.pop("__class__") 38 | self.config.pop("pretrained_model_path") 39 | self.config["hparams"] = hparams.to_dict() 40 | 41 | self.tag_vocab = tag_vocab 42 | self.label_vocab = label_vocab 43 | self.char_vocab = char_vocab 44 | 45 | self.d_model = hparams.d_model 46 | 47 | self.char_encoder = None 48 | self.pretrained_model = None 49 | if hparams.use_chars_lstm: 50 | assert ( 51 | not hparams.use_pretrained 52 | ), "use_chars_lstm and use_pretrained are mutually exclusive" 53 | self.retokenizer = char_lstm.RetokenizerForCharLSTM(self.char_vocab) 54 | self.char_encoder = char_lstm.CharacterLSTM( 55 | max(self.char_vocab.values()) + 1, 56 | hparams.d_char_emb, 57 | hparams.d_model // 2, # Half-size to leave room for 58 | # partitioned positional encoding 59 | char_dropout=hparams.char_lstm_input_dropout, 60 | ) 61 | elif hparams.use_pretrained: 62 | if pretrained_model_path is None: 63 | self.retokenizer = retokenization.Retokenizer( 64 | hparams.pretrained_model, retain_start_stop=True 65 | ) 66 | self.pretrained_model = AutoModel.from_pretrained( 67 | hparams.pretrained_model 68 | ) 69 | # 70 | # self.pretrained_model.eval() 71 | # for name, param in self.pretrained_model.named_parameters(): 72 | # if (name != 'fc.weight') and (name != 'fc.bias'): 73 | # param.requires_grad = False 74 | # print(name,param.requires_grad) 75 | else: 76 | self.retokenizer = retokenization.Retokenizer( 77 | pretrained_model_path, retain_start_stop=True 78 | ) 79 | self.pretrained_model = AutoModel.from_config( 80 | AutoConfig.from_pretrained(pretrained_model_path) 81 | ) 82 | 83 | d_pretrained = self.pretrained_model.config.hidden_size 84 | 85 | if hparams.use_encoder: 86 | self.project_pretrained = nn.Linear( 87 | d_pretrained, hparams.d_model // 2, bias=False 88 | ) 89 | else: 90 | 
self.project_pretrained = nn.Linear( 91 | d_pretrained, hparams.d_model, bias=False 92 | ) 93 | 94 | if hparams.use_encoder: 95 | self.morpho_emb_dropout = FeatureDropout(hparams.morpho_emb_dropout) 96 | self.add_timing = ConcatPositionalEncoding( 97 | d_model=hparams.d_model, 98 | max_len=hparams.encoder_max_len, 99 | ) 100 | encoder_layer = PartitionedTransformerEncoderLayer( 101 | hparams.d_model, 102 | n_head=hparams.num_heads, 103 | d_qkv=hparams.d_kv, 104 | d_ff=hparams.d_ff, 105 | ff_dropout=hparams.relu_dropout, 106 | residual_dropout=hparams.residual_dropout, 107 | attention_dropout=hparams.attention_dropout, 108 | ) 109 | self.encoder = PartitionedTransformerEncoder( 110 | encoder_layer, hparams.num_layers 111 | ) 112 | else: 113 | self.morpho_emb_dropout = None 114 | self.add_timing = None 115 | self.encoder = None 116 | 117 | self.f_label = nn.Sequential( 118 | nn.Linear(hparams.d_model, hparams.d_label_hidden), 119 | nn.LayerNorm(hparams.d_label_hidden), 120 | nn.ReLU(), 121 | nn.Linear(hparams.d_label_hidden, max(label_vocab.values())), 122 | ) 123 | 124 | if hparams.predict_tags: 125 | self.f_tag = nn.Sequential( 126 | nn.Linear(hparams.d_model, hparams.d_tag_hidden), 127 | nn.LayerNorm(hparams.d_tag_hidden), 128 | nn.ReLU(), 129 | nn.Linear(hparams.d_tag_hidden, max(tag_vocab.values()) + 1), 130 | ) 131 | self.tag_loss_scale = hparams.tag_loss_scale 132 | self.tag_from_index = {i: label for label, i in tag_vocab.items()} 133 | else: 134 | self.f_tag = None 135 | self.tag_from_index = None 136 | 137 | self.decoder = decode_chart.ChartDecoder( 138 | label_vocab=self.label_vocab, 139 | force_root_constituent=hparams.force_root_constituent, 140 | ) 141 | self.criterion = decode_chart.SpanClassificationMarginLoss( 142 | reduction="sum", force_root_constituent=hparams.force_root_constituent 143 | ) 144 | 145 | self.parallelized_devices = None 146 | 147 | @property 148 | def device(self): 149 | if self.parallelized_devices is not None: 150 | return self.parallelized_devices[0] 151 | else: 152 | return next(self.f_label.parameters()).device 153 | 154 | @property 155 | def output_device(self): 156 | if self.parallelized_devices is not None: 157 | return self.parallelized_devices[1] 158 | else: 159 | return next(self.f_label.parameters()).device 160 | 161 | def parallelize(self, *args, **kwargs): 162 | self.parallelized_devices = (torch.device("cuda", 0), torch.device("cuda", 1)) 163 | for child in self.children(): 164 | if child != self.pretrained_model: 165 | child.to(self.output_device) 166 | self.pretrained_model.parallelize(*args, **kwargs) 167 | 168 | @classmethod 169 | def from_trained(cls, model_path): 170 | if os.path.isdir(model_path): 171 | # Multi-file format used when exporting models for release. 172 | # Unlike the checkpoints saved during training, these files include 173 | # all tokenizer parameters and a copy of the pre-trained model 174 | # config (rather than downloading these on-demand). 175 | config = AutoConfig.from_pretrained(model_path).benepar 176 | state_dict = torch.load( 177 | os.path.join(model_path, "benepar_model.bin"), map_location="cpu" 178 | ) 179 | config["pretrained_model_path"] = model_path 180 | else: 181 | # Single-file format used for saving checkpoints during training. 
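            # Such a checkpoint is a single pickled dict bundling the parser
            # config together with the weights; it is read back below as
            # data["config"] and data["state_dict"].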
182 | data = torch.load(model_path, map_location="cpu") 183 | config = data["config"] 184 | state_dict = data["state_dict"] 185 | 186 | hparams = config["hparams"] 187 | 188 | if "force_root_constituent" not in hparams: 189 | hparams["force_root_constituent"] = True 190 | 191 | config["hparams"] = nkutil.HParams(**hparams) 192 | parser = cls(**config) 193 | parser.load_state_dict(state_dict) 194 | return parser 195 | 196 | def encode(self, example): 197 | if self.char_encoder is not None: 198 | encoded = self.retokenizer(example.words, return_tensors="np") 199 | else: 200 | encoded = self.retokenizer(example.words, example.space_after) 201 | 202 | if example.tree is not None: 203 | encoded["span_labels"] = torch.tensor( 204 | self.decoder.chart_from_tree(example.tree) 205 | ) 206 | if self.f_tag is not None: 207 | encoded["tag_labels"] = torch.tensor( 208 | [-100] + [self.tag_vocab[tag] for _, tag in example.pos()] + [-100] 209 | ) 210 | return encoded 211 | 212 | def pad_encoded(self, encoded_batch): 213 | batch = self.retokenizer.pad( 214 | [ 215 | { 216 | k: v 217 | for k, v in example.items() 218 | if (k != "span_labels" and k != "tag_labels") 219 | } 220 | for example in encoded_batch 221 | ], 222 | return_tensors="pt", 223 | ) 224 | if encoded_batch and "span_labels" in encoded_batch[0]: 225 | batch["span_labels"] = decode_chart.pad_charts( 226 | [example["span_labels"] for example in encoded_batch] 227 | ) 228 | if encoded_batch and "tag_labels" in encoded_batch[0]: 229 | batch["tag_labels"] = nn.utils.rnn.pad_sequence( 230 | [example["tag_labels"] for example in encoded_batch], 231 | batch_first=True, 232 | padding_value=-100, 233 | ) 234 | return batch 235 | 236 | def _get_lens(self, encoded_batch): 237 | if self.pretrained_model is not None: 238 | return [len(encoded["input_ids"]) for encoded in encoded_batch] 239 | return [len(encoded["valid_token_mask"]) for encoded in encoded_batch] 240 | 241 | def encode_and_collate_subbatches(self, examples, subbatch_max_tokens): 242 | batch_size = len(examples) 243 | batch_num_tokens = sum(len(x.words) for x in examples) 244 | encoded = [self.encode(example) for example in examples] 245 | 246 | res = [] 247 | for ids, subbatch_encoded in subbatching.split( 248 | encoded, costs=self._get_lens(encoded), max_cost=subbatch_max_tokens 249 | ): 250 | subbatch = self.pad_encoded(subbatch_encoded) 251 | subbatch["batch_size"] = batch_size 252 | subbatch["batch_num_tokens"] = batch_num_tokens 253 | res.append((len(ids), subbatch)) 254 | return res 255 | 256 | def forward(self, batch): 257 | valid_token_mask = batch["valid_token_mask"].to(self.output_device) 258 | 259 | if ( 260 | self.encoder is not None 261 | and valid_token_mask.shape[1] > self.add_timing.timing_table.shape[0] 262 | ): 263 | raise ValueError( 264 | "Sentence of length {} exceeds the maximum supported length of " 265 | "{}".format( 266 | valid_token_mask.shape[1] - 2, 267 | self.add_timing.timing_table.shape[0] - 2, 268 | ) 269 | ) 270 | 271 | if self.char_encoder is not None: 272 | assert isinstance(self.char_encoder, char_lstm.CharacterLSTM) 273 | char_ids = batch["char_ids"].to(self.device) 274 | extra_content_annotations = self.char_encoder(char_ids, valid_token_mask) 275 | elif self.pretrained_model is not None: 276 | input_ids = batch["input_ids"].to(self.device) 277 | words_from_tokens = batch["words_from_tokens"].to(self.output_device) 278 | pretrained_attention_mask = batch["attention_mask"].to(self.device) 279 | 280 | extra_kwargs = {} 281 | if "token_type_ids" in batch: 
282 | extra_kwargs["token_type_ids"] = batch["token_type_ids"].to(self.device) 283 | if "decoder_input_ids" in batch: 284 | extra_kwargs["decoder_input_ids"] = batch["decoder_input_ids"].to( 285 | self.device 286 | ) 287 | extra_kwargs["decoder_attention_mask"] = batch[ 288 | "decoder_attention_mask" 289 | ].to(self.device) 290 | 291 | pretrained_out = self.pretrained_model( 292 | input_ids, attention_mask=pretrained_attention_mask, **extra_kwargs 293 | ) 294 | features = pretrained_out.last_hidden_state.to(self.output_device) 295 | features = features[ 296 | torch.arange(features.shape[0])[:, None], 297 | # Note that words_from_tokens uses index -100 for invalid positions 298 | F.relu(words_from_tokens), 299 | ] 300 | features.masked_fill_(~valid_token_mask[:, :, None], 0) 301 | if self.encoder is not None: 302 | extra_content_annotations = self.project_pretrained(features) 303 | 304 | if self.encoder is not None: 305 | encoder_in = self.add_timing( 306 | self.morpho_emb_dropout(extra_content_annotations) 307 | ) 308 | 309 | annotations = self.encoder(encoder_in, valid_token_mask) 310 | # Rearrange the annotations to ensure that the transition to 311 | # fenceposts captures an even split between position and content. 312 | # TODO(nikita): try alternatives, such as omitting position entirely 313 | annotations = torch.cat( 314 | [ 315 | annotations[..., 0::2], 316 | annotations[..., 1::2], 317 | ], 318 | -1, 319 | ) 320 | else: 321 | assert self.pretrained_model is not None 322 | annotations = self.project_pretrained(features) 323 | 324 | if self.f_tag is not None: 325 | tag_scores = self.f_tag(annotations) 326 | else: 327 | tag_scores = None 328 | 329 | fencepost_annotations = torch.cat( 330 | [ 331 | annotations[:, :-1, : self.d_model // 2], 332 | annotations[:, 1:, self.d_model // 2 :], 333 | ], 334 | -1, 335 | ) 336 | 337 | # Note that the bias added to the final layer norm is useless because 338 | # this subtraction gets rid of it 339 | span_features = ( 340 | torch.unsqueeze(fencepost_annotations, 1) 341 | - torch.unsqueeze(fencepost_annotations, 2) 342 | )[:, :-1, 1:] 343 | span_scores = self.f_label(span_features) 344 | span_scores = torch.cat( 345 | [span_scores.new_zeros(span_scores.shape[:-1] + (1,)), span_scores], -1 346 | ) 347 | return span_scores, tag_scores 348 | 349 | def compute_loss(self, batch): 350 | span_scores, tag_scores = self.forward(batch) 351 | span_labels = batch["span_labels"].to(span_scores.device) 352 | span_loss = self.criterion(span_scores, span_labels) 353 | # Divide by the total batch size, not by the subbatch size 354 | span_loss = span_loss / batch["batch_size"] 355 | if tag_scores is None: 356 | return span_loss 357 | else: 358 | tag_labels = batch["tag_labels"].to(tag_scores.device) 359 | tag_loss = self.tag_loss_scale * F.cross_entropy( 360 | tag_scores.reshape((-1, tag_scores.shape[-1])), 361 | tag_labels.reshape((-1,)), 362 | reduction="sum", 363 | ignore_index=-100, 364 | ) 365 | tag_loss = tag_loss / batch["batch_num_tokens"] 366 | return span_loss + tag_loss 367 | 368 | def _parse_encoded( 369 | self, examples, encoded, return_compressed=False, return_scores=False 370 | ): 371 | with torch.no_grad(): 372 | batch = self.pad_encoded(encoded) 373 | span_scores, tag_scores = self.forward(batch) 374 | if return_scores: 375 | span_scores_np = span_scores.cpu().numpy() 376 | else: 377 | # Start/stop tokens don't count, so subtract 2 378 | lengths = batch["valid_token_mask"].sum(-1) - 2 379 | charts_np = self.decoder.charts_from_pytorch_scores_batched( 
380 | span_scores, lengths.to(span_scores.device) 381 | ) 382 | if tag_scores is not None: 383 | tag_ids_np = tag_scores.argmax(-1).cpu().numpy() 384 | else: 385 | tag_ids_np = None 386 | 387 | for i in range(len(encoded)): 388 | example_len = len(examples[i].words) 389 | if return_scores: 390 | yield span_scores_np[i, :example_len, :example_len] 391 | elif return_compressed: 392 | output = self.decoder.compressed_output_from_chart(charts_np[i]) 393 | if tag_ids_np is not None: 394 | output = output.with_tags(tag_ids_np[i, 1 : example_len + 1]) 395 | yield output 396 | else: 397 | if tag_scores is None: 398 | leaves = examples[i].pos() 399 | else: 400 | predicted_tags = [ 401 | self.tag_from_index[i] 402 | for i in tag_ids_np[i, 1 : example_len + 1] 403 | ] 404 | leaves = [ 405 | (word, predicted_tag) 406 | for predicted_tag, (word, gold_tag) in zip( 407 | predicted_tags, examples[i].pos() 408 | ) 409 | ] 410 | yield self.decoder.tree_from_chart(charts_np[i], leaves=leaves) 411 | 412 | def parse( 413 | self, 414 | examples, 415 | return_compressed=False, 416 | return_scores=False, 417 | subbatch_max_tokens=None, 418 | ): 419 | training = self.training 420 | self.eval() 421 | encoded = [self.encode(example) for example in examples] 422 | if subbatch_max_tokens is not None: 423 | res = subbatching.map( 424 | self._parse_encoded, 425 | examples, 426 | encoded, 427 | costs=self._get_lens(encoded), 428 | max_cost=subbatch_max_tokens, 429 | return_compressed=return_compressed, 430 | return_scores=return_scores, 431 | ) 432 | else: 433 | res = self._parse_encoded( 434 | examples, 435 | encoded, 436 | return_compressed=return_compressed, 437 | return_scores=return_scores, 438 | ) 439 | res = list(res) 440 | self.train(training) 441 | return res 442 | -------------------------------------------------------------------------------- /src/benepar/partitioned_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transformer with partitioned content and position features. 3 | """ 4 | 5 | import copy 6 | import math 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | 13 | class FeatureDropoutFunction(torch.autograd.function.InplaceFunction): 14 | @staticmethod 15 | def forward(ctx, input, p=0.5, train=False, inplace=False): 16 | if p < 0 or p > 1: 17 | raise ValueError( 18 | "dropout probability has to be between 0 and 1, but got {}".format(p) 19 | ) 20 | 21 | ctx.p = p 22 | ctx.train = train 23 | ctx.inplace = inplace 24 | 25 | if ctx.inplace: 26 | ctx.mark_dirty(input) 27 | output = input 28 | else: 29 | output = input.clone() 30 | 31 | if ctx.p > 0 and ctx.train: 32 | ctx.noise = torch.empty( 33 | (input.size(0), input.size(-1)), 34 | dtype=input.dtype, 35 | layout=input.layout, 36 | device=input.device, 37 | ) 38 | if ctx.p == 1: 39 | ctx.noise.fill_(0) 40 | else: 41 | ctx.noise.bernoulli_(1 - ctx.p).div_(1 - ctx.p) 42 | ctx.noise = ctx.noise[:, None, :] 43 | output.mul_(ctx.noise) 44 | 45 | return output 46 | 47 | @staticmethod 48 | def backward(ctx, grad_output): 49 | if ctx.p > 0 and ctx.train: 50 | return grad_output.mul(ctx.noise), None, None, None 51 | else: 52 | return grad_output, None, None, None 53 | 54 | 55 | class FeatureDropout(nn.Dropout): 56 | """ 57 | Feature-level dropout: takes an input of size len x num_features and drops 58 | each feature with probabibility p. A feature is dropped across the full 59 | portion of the input that corresponds to a single batch element. 
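
    For example (illustrative), with p=0.5 and an input of shape
    (batch, seq_len, num_features), the sampled dropout mask has shape
    (batch, 1, num_features): a dropped feature is zeroed at every position
    of that batch element, and surviving features are scaled by 1/(1-p) as
    in standard inverted dropout.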
60 | """ 61 | 62 | def forward(self, x): 63 | if isinstance(x, tuple): 64 | x_c, x_p = x 65 | x_c = FeatureDropoutFunction.apply(x_c, self.p, self.training, self.inplace) 66 | x_p = FeatureDropoutFunction.apply(x_p, self.p, self.training, self.inplace) 67 | return x_c, x_p 68 | else: 69 | return FeatureDropoutFunction.apply(x, self.p, self.training, self.inplace) 70 | 71 | 72 | class PartitionedReLU(nn.ReLU): 73 | def forward(self, x): 74 | if isinstance(x, tuple): 75 | x_c, x_p = x 76 | else: 77 | x_c, x_p = torch.chunk(x, 2, dim=-1) 78 | return super().forward(x_c), super().forward(x_p) 79 | 80 | 81 | class PartitionedLinear(nn.Module): 82 | def __init__(self, in_features, out_features, bias=True): 83 | super().__init__() 84 | self.linear_c = nn.Linear(in_features // 2, out_features // 2, bias) 85 | self.linear_p = nn.Linear(in_features // 2, out_features // 2, bias) 86 | 87 | def forward(self, x): 88 | if isinstance(x, tuple): 89 | x_c, x_p = x 90 | else: 91 | x_c, x_p = torch.chunk(x, 2, dim=-1) 92 | 93 | out_c = self.linear_c(x_c) 94 | out_p = self.linear_p(x_p) 95 | return out_c, out_p 96 | 97 | 98 | class PartitionedMultiHeadAttention(nn.Module): 99 | def __init__( 100 | self, d_model, n_head, d_qkv, attention_dropout=0.1, initializer_range=0.02 101 | ): 102 | super().__init__() 103 | 104 | self.w_qkv_c = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2)) 105 | self.w_qkv_p = nn.Parameter(torch.Tensor(n_head, d_model // 2, 3, d_qkv // 2)) 106 | self.w_o_c = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2)) 107 | self.w_o_p = nn.Parameter(torch.Tensor(n_head, d_qkv // 2, d_model // 2)) 108 | 109 | bound = math.sqrt(3.0) * initializer_range 110 | for param in [self.w_qkv_c, self.w_qkv_p, self.w_o_c, self.w_o_p]: 111 | nn.init.uniform_(param, -bound, bound) 112 | self.scaling_factor = 1 / d_qkv ** 0.5 113 | 114 | self.dropout = nn.Dropout(attention_dropout) 115 | 116 | def forward(self, x, mask=None): 117 | if isinstance(x, tuple): 118 | x_c, x_p = x 119 | else: 120 | x_c, x_p = torch.chunk(x, 2, dim=-1) 121 | qkv_c = torch.einsum("btf,hfca->bhtca", x_c, self.w_qkv_c) 122 | qkv_p = torch.einsum("btf,hfca->bhtca", x_p, self.w_qkv_p) 123 | q_c, k_c, v_c = [c.squeeze(dim=3) for c in torch.chunk(qkv_c, 3, dim=3)] 124 | q_p, k_p, v_p = [c.squeeze(dim=3) for c in torch.chunk(qkv_p, 3, dim=3)] 125 | q = torch.cat([q_c, q_p], dim=-1) * self.scaling_factor 126 | k = torch.cat([k_c, k_p], dim=-1) 127 | v = torch.cat([v_c, v_p], dim=-1) 128 | dots = torch.einsum("bhqa,bhka->bhqk", q, k) 129 | if mask is not None: 130 | dots.data.masked_fill_(~mask[:, None, None, :], -float("inf")) 131 | probs = F.softmax(dots, dim=-1) 132 | probs = self.dropout(probs) 133 | o = torch.einsum("bhqk,bhka->bhqa", probs, v) 134 | o_c, o_p = torch.chunk(o, 2, dim=-1) 135 | out_c = torch.einsum("bhta,haf->btf", o_c, self.w_o_c) 136 | out_p = torch.einsum("bhta,haf->btf", o_p, self.w_o_p) 137 | return out_c, out_p 138 | 139 | 140 | class PartitionedTransformerEncoderLayer(nn.Module): 141 | def __init__( 142 | self, 143 | d_model, 144 | n_head, 145 | d_qkv, 146 | d_ff, 147 | ff_dropout=0.1, 148 | residual_dropout=0.1, 149 | attention_dropout=0.1, 150 | activation=PartitionedReLU(), 151 | ): 152 | super().__init__() 153 | self.self_attn = PartitionedMultiHeadAttention( 154 | d_model, n_head, d_qkv, attention_dropout=attention_dropout 155 | ) 156 | self.linear1 = PartitionedLinear(d_model, d_ff) 157 | self.ff_dropout = FeatureDropout(ff_dropout) 158 | self.linear2 = PartitionedLinear(d_ff, d_model) 
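        # Descriptive note: the content and position streams keep separate
        # parameters in the modules above; forward() concatenates the two
        # halves when forming residuals for the LayerNorms declared below.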
159 | 160 | self.norm_attn = nn.LayerNorm(d_model) 161 | self.norm_ff = nn.LayerNorm(d_model) 162 | self.residual_dropout_attn = FeatureDropout(residual_dropout) 163 | self.residual_dropout_ff = FeatureDropout(residual_dropout) 164 | 165 | self.activation = activation 166 | 167 | def forward(self, x, mask=None): 168 | residual = self.self_attn(x, mask=mask) 169 | residual = torch.cat(residual, dim=-1) 170 | residual = self.residual_dropout_attn(residual) 171 | x = self.norm_attn(x + residual) 172 | residual = self.linear2(self.ff_dropout(self.activation(self.linear1(x)))) 173 | residual = torch.cat(residual, dim=-1) 174 | residual = self.residual_dropout_ff(residual) 175 | x = self.norm_ff(x + residual) 176 | return x 177 | 178 | 179 | class PartitionedTransformerEncoder(nn.Module): 180 | def __init__(self, encoder_layer, n_layers): 181 | super().__init__() 182 | self.layers = nn.ModuleList( 183 | [copy.deepcopy(encoder_layer) for i in range(n_layers)] 184 | ) 185 | 186 | def forward(self, x, mask=None): 187 | for layer in self.layers: 188 | x = layer(x, mask=mask) 189 | return x 190 | 191 | 192 | class ConcatPositionalEncoding(nn.Module): 193 | def __init__(self, d_model=256, max_len=512): 194 | super().__init__() 195 | self.timing_table = nn.Parameter(torch.FloatTensor(max_len, d_model // 2)) 196 | nn.init.normal_(self.timing_table) 197 | self.norm = nn.LayerNorm(d_model) 198 | 199 | def forward(self, x): 200 | timing = self.timing_table[None, : x.shape[1], :] 201 | x, timing = torch.broadcast_tensors(x, timing) 202 | out = torch.cat([x, timing], dim=-1) 203 | out = self.norm(out) 204 | return out 205 | -------------------------------------------------------------------------------- /src/benepar/ptb_unescape.py: -------------------------------------------------------------------------------- 1 | PTB_UNESCAPE_MAPPING = { 2 | "«": '"', 3 | "»": '"', 4 | "‘": "'", 5 | "’": "'", 6 | "“": '"', 7 | "”": '"', 8 | "„": '"', 9 | "‹": "'", 10 | "›": "'", 11 | "\u2013": "--", # en dash 12 | "\u2014": "--", # em dash 13 | } 14 | 15 | NO_SPACE_BEFORE = {"-RRB-", "-RCB-", "-RSB-", "''"} | set("%.,!?:;") 16 | NO_SPACE_AFTER = {"-LRB-", "-LCB-", "-LSB-", "``", "`"} | set("$#") 17 | NO_SPACE_BEFORE_TOKENS_ENGLISH = {"'", "'s", "'ll", "'re", "'d", "'m", "'ve"} 18 | PTB_DASH_ESCAPED = {"-RRB-", "-RCB-", "-RSB-", "-LRB-", "-LCB-", "-LSB-", "--"} 19 | 20 | 21 | def ptb_unescape(words): 22 | cleaned_words = [] 23 | for word in words: 24 | word = PTB_UNESCAPE_MAPPING.get(word, word) 25 | # This un-escaping for / and * was not yet added for the 26 | # parser version in https://arxiv.org/abs/1812.11760v1 27 | # and related model releases (e.g. 
benepar_en2) 28 | word = word.replace("\\/", "/").replace("\\*", "*") 29 | # Mid-token punctuation occurs in biomedical text 30 | word = word.replace("-LSB-", "[").replace("-RSB-", "]") 31 | word = word.replace("-LRB-", "(").replace("-RRB-", ")") 32 | word = word.replace("-LCB-", "{").replace("-RCB-", "}") 33 | word = word.replace("``", '"').replace("`", "'").replace("''", '"') 34 | cleaned_words.append(word) 35 | return cleaned_words 36 | 37 | 38 | def guess_space_after_non_english(escaped_words): 39 | sp_after = [True for _ in escaped_words] 40 | for i, word in enumerate(escaped_words): 41 | if i > 0 and ( 42 | ( 43 | word.startswith("-") 44 | and not any(word.startswith(x) for x in PTB_DASH_ESCAPED) 45 | ) 46 | or any(word.startswith(x) for x in NO_SPACE_BEFORE) 47 | or word == "'" 48 | ): 49 | sp_after[i - 1] = False 50 | if ( 51 | word.endswith("-") and not any(word.endswith(x) for x in PTB_DASH_ESCAPED) 52 | ) or any(word.endswith(x) for x in NO_SPACE_AFTER): 53 | sp_after[i] = False 54 | 55 | return sp_after 56 | 57 | 58 | def guess_space_after(escaped_words, for_english=True): 59 | if not for_english: 60 | return guess_space_after_non_english(escaped_words) 61 | 62 | sp_after = [True for _ in escaped_words] 63 | for i, word in enumerate(escaped_words): 64 | if word.lower() == "n't" and i > 0: 65 | sp_after[i - 1] = False 66 | elif word.lower() == "not" and i > 0 and escaped_words[i - 1].lower() == "can": 67 | sp_after[i - 1] = False 68 | 69 | if i > 0 and ( 70 | ( 71 | word.startswith("-") 72 | and not any(word.startswith(x) for x in PTB_DASH_ESCAPED) 73 | ) 74 | or any(word.startswith(x) for x in NO_SPACE_BEFORE) 75 | or word.lower() in NO_SPACE_BEFORE_TOKENS_ENGLISH 76 | ): 77 | sp_after[i - 1] = False 78 | if ( 79 | word.endswith("-") and not any(word.endswith(x) for x in PTB_DASH_ESCAPED) 80 | ) or any(word.endswith(x) for x in NO_SPACE_AFTER): 81 | sp_after[i] = False 82 | 83 | return sp_after 84 | -------------------------------------------------------------------------------- /src/benepar/retokenization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Converts from linguistically motivated word-based tokenization to subword 3 | tokenization used by pre-trained models. 4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | import transformers 9 | 10 | 11 | def retokenize( 12 | tokenizer, 13 | words, 14 | space_after, 15 | return_attention_mask=True, 16 | return_offsets_mapping=False, 17 | return_tensors=None, 18 | **kwargs 19 | ): 20 | """Re-tokenize into subwords. 21 | 22 | Args: 23 | tokenizer: An instance of transformers.PreTrainedTokenizerFast 24 | words: List of words 25 | space_after: A list of the same length as `words`, indicating whether 26 | whitespace follows each word. 27 | **kwargs: all remaining arguments are passed on to tokenizer.__call__ 28 | 29 | Returns: 30 | The output of tokenizer.__call__, with one additional dictionary field: 31 | - **words_from_tokens** -- List of the same length as `words`, where 32 | each entry is the index of the *last* subword that overlaps the 33 | corresponding word. 
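
    For example (illustrative), words=["I", "do"] with
    space_after=[True, False] are joined into the string "I do" before
    subword tokenization; words_from_tokens is then recovered by matching
    each word's character offsets against the tokenizer's offset_mapping.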
34 | """ 35 | s = "".join([w + (" " if sp else "") for w, sp in zip(words, space_after)]) 36 | word_offset_starts = np.cumsum( 37 | [0] + [len(w) + (1 if sp else 0) for w, sp in zip(words, space_after)] 38 | )[:-1] 39 | word_offset_ends = word_offset_starts + np.asarray([len(w) for w in words]) 40 | 41 | tokenized = tokenizer( 42 | s, 43 | return_attention_mask=return_attention_mask, 44 | return_offsets_mapping=True, 45 | return_tensors=return_tensors, 46 | **kwargs 47 | ) 48 | if return_offsets_mapping: 49 | token_offset_mapping = tokenized["offset_mapping"] 50 | else: 51 | token_offset_mapping = tokenized.pop("offset_mapping") 52 | if return_tensors is not None: 53 | token_offset_mapping = np.asarray(token_offset_mapping)[0].tolist() 54 | 55 | offset_mapping_iter = iter( 56 | [ 57 | (i, (start, end)) 58 | for (i, (start, end)) in enumerate(token_offset_mapping) 59 | if start != end 60 | ] 61 | ) 62 | token_idx, (token_start, token_end) = next(offset_mapping_iter) 63 | words_from_tokens = [-100] * len(words) 64 | for word_idx, (word_start, word_end) in enumerate( 65 | zip(word_offset_starts, word_offset_ends) 66 | ): 67 | while token_end <= word_start: 68 | token_idx, (token_start, token_end) = next(offset_mapping_iter) 69 | if token_end > word_end: 70 | words_from_tokens[word_idx] = token_idx 71 | while token_end <= word_end: 72 | words_from_tokens[word_idx] = token_idx 73 | try: 74 | token_idx, (token_start, token_end) = next(offset_mapping_iter) 75 | except StopIteration: 76 | assert word_idx == len(words) - 1 77 | break 78 | if return_tensors == "np": 79 | words_from_tokens = np.asarray(words_from_tokens, dtype=int) 80 | elif return_tensors == "pt": 81 | words_from_tokens = torch.tensor(words_from_tokens, dtype=torch.long) 82 | elif return_tensors == "tf": 83 | raise NotImplementedError("Returning tf tensors is not implemented") 84 | tokenized["words_from_tokens"] = words_from_tokens 85 | return tokenized 86 | 87 | 88 | class Retokenizer: 89 | def __init__(self, pretrained_model_name_or_path, retain_start_stop=False): 90 | self.tokenizer = transformers.AutoTokenizer.from_pretrained( 91 | pretrained_model_name_or_path, fast=True 92 | ) 93 | if not self.tokenizer.is_fast: 94 | raise NotImplementedError( 95 | "Converting from treebank tokenization to tokenization used by a " 96 | "pre-trained model requires a 'fast' tokenizer, which appears to not " 97 | "be available for this pre-trained model type." 98 | ) 99 | self.retain_start_stop = retain_start_stop 100 | self.is_t5 = "T5Tokenizer" in str(type(self.tokenizer)) 101 | self.is_gpt2 = "GPT2Tokenizer" in str(type(self.tokenizer)) 102 | 103 | if self.is_gpt2: 104 | # The provided GPT-2 tokenizer does not specify a padding token by default 105 | self.tokenizer.pad_token = self.tokenizer.eos_token 106 | 107 | if self.retain_start_stop: 108 | # When retain_start_stop is set, the next layer after the pre-trained model 109 | # expects start and stop token embeddings. For BERT these can naturally be 110 | # the feature vectors for CLS and SEP, but pre-trained models differ in the 111 | # special tokens that they use. This code attempts to find special token 112 | # positions for each pre-trained model. 113 | dummy_ids = self.tokenizer.build_inputs_with_special_tokens([-100]) 114 | if self.is_t5: 115 | # For T5 we use the output from the decoder, which accepts inputs that 116 | # are shifted relative to the encoder. 
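                # (That shift is emulated here by prepending the pad token,
                # which moves every decoder position one step to the right.)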
117 | dummy_ids = [self.tokenizer.pad_token_id] + dummy_ids 118 | if self.is_gpt2: 119 | # For GPT-2, we append an eos token if special tokens are needed 120 | dummy_ids = dummy_ids + [self.tokenizer.eos_token_id] 121 | try: 122 | input_idx = dummy_ids.index(-100) 123 | except ValueError: 124 | raise NotImplementedError( 125 | "Could not automatically infer how to extract start/stop tokens " 126 | "from this pre-trained model" 127 | ) 128 | num_prefix_tokens = input_idx 129 | num_suffix_tokens = len(dummy_ids) - input_idx - 1 130 | self.start_token_idx = None 131 | self.stop_token_idx = None 132 | if num_prefix_tokens > 0: 133 | self.start_token_idx = num_prefix_tokens - 1 134 | if num_suffix_tokens > 0: 135 | self.stop_token_idx = -num_suffix_tokens 136 | if self.start_token_idx is None and num_suffix_tokens > 0: 137 | self.start_token_idx = -1 138 | if self.stop_token_idx is None and num_prefix_tokens > 0: 139 | self.stop_token_idx = 0 140 | if self.start_token_idx is None or self.stop_token_idx is None: 141 | assert num_prefix_tokens == 0 and num_suffix_tokens == 0 142 | raise NotImplementedError( 143 | "Could not automatically infer how to extract start/stop tokens " 144 | "from this pre-trained model because the associated tokenizer " 145 | "appears not to add any special start/stop/cls/sep/etc. tokens " 146 | "to the sequence." 147 | ) 148 | 149 | def __call__(self, words, space_after, **kwargs): 150 | example = retokenize(self.tokenizer, words, space_after, **kwargs) 151 | if self.is_t5: 152 | # decoder_input_ids (which are shifted wrt input_ids) will be created after 153 | # padding, but we adjust words_from_tokens now, in anticipation. 154 | if isinstance(example["words_from_tokens"], list): 155 | example["words_from_tokens"] = [ 156 | x + 1 for x in example["words_from_tokens"] 157 | ] 158 | else: 159 | example["words_from_tokens"] += 1 160 | if self.retain_start_stop: 161 | num_tokens = len(example["input_ids"]) 162 | if self.is_t5: 163 | num_tokens += 1 164 | if self.is_gpt2: 165 | num_tokens += 1 166 | if kwargs.get("return_tensors") == "pt": 167 | example["input_ids"] = torch.cat( 168 | example["input_ids"], 169 | torch.tensor([self.tokenizer.eos_token_id]), 170 | ) 171 | example["attention_mask"] = torch.cat( 172 | example["attention_mask"], torch.tensor([1]) 173 | ) 174 | else: 175 | example["input_ids"].append(self.tokenizer.eos_token_id) 176 | example["attention_mask"].append(1) 177 | if num_tokens > self.tokenizer.model_max_length: 178 | raise ValueError( 179 | f"Sentence of length {num_tokens} (in sub-word tokens) exceeds the " 180 | f"maximum supported length of {self.tokenizer.model_max_length}" 181 | ) 182 | start_token_idx = ( 183 | self.start_token_idx 184 | if self.start_token_idx >= 0 185 | else num_tokens + self.start_token_idx 186 | ) 187 | stop_token_idx = ( 188 | self.stop_token_idx 189 | if self.stop_token_idx >= 0 190 | else num_tokens + self.stop_token_idx 191 | ) 192 | if kwargs.get("return_tensors") == "pt": 193 | example["words_from_tokens"] = torch.cat( 194 | [ 195 | torch.tensor([start_token_idx]), 196 | example["words_from_tokens"], 197 | torch.tensor([stop_token_idx]), 198 | ] 199 | ) 200 | else: 201 | example["words_from_tokens"] = ( 202 | [start_token_idx] + example["words_from_tokens"] + [stop_token_idx] 203 | ) 204 | return example 205 | 206 | def pad(self, encoded_inputs, return_tensors=None, **kwargs): 207 | if return_tensors != "pt": 208 | raise NotImplementedError("Only return_tensors='pt' is supported.") 209 | res = self.tokenizer.pad( 210 
| [ 211 | {k: v for k, v in example.items() if k != "words_from_tokens"} 212 | for example in encoded_inputs 213 | ], 214 | return_tensors=return_tensors, 215 | **kwargs 216 | ) 217 | if self.tokenizer.padding_side == "right": 218 | res["words_from_tokens"] = torch.nn.utils.rnn.pad_sequence( 219 | [ 220 | torch.tensor(example["words_from_tokens"]) 221 | for example in encoded_inputs 222 | ], 223 | batch_first=True, 224 | padding_value=-100, 225 | ) 226 | else: 227 | # XLNet adds padding tokens on the left of the sequence, so 228 | # words_from_tokens must be adjusted to skip the added padding tokens. 229 | assert self.tokenizer.padding_side == "left" 230 | res["words_from_tokens"] = torch.nn.utils.rnn.pad_sequence( 231 | [ 232 | torch.tensor(example["words_from_tokens"]) 233 | + (res["input_ids"].shape[-1] - len(example["input_ids"])) 234 | for example in encoded_inputs 235 | ], 236 | batch_first=True, 237 | padding_value=-100, 238 | ) 239 | 240 | if self.is_t5: 241 | res["decoder_input_ids"] = torch.cat( 242 | [ 243 | torch.full_like( 244 | res["input_ids"][:, :1], self.tokenizer.pad_token_id 245 | ), 246 | res["input_ids"], 247 | ], 248 | 1, 249 | ) 250 | res["decoder_attention_mask"] = torch.cat( 251 | [ 252 | torch.ones_like(res["attention_mask"][:, :1]), 253 | res["attention_mask"], 254 | ], 255 | 1, 256 | ) 257 | res["valid_token_mask"] = res["words_from_tokens"] != -100 258 | return res 259 | -------------------------------------------------------------------------------- /src/benepar/spacy_plugin.py: -------------------------------------------------------------------------------- 1 | __all__ = ["BeneparComponent", "NonConstituentException"] 2 | 3 | import warnings 4 | 5 | from .integrations.spacy_plugin import BeneparComponent, NonConstituentException 6 | 7 | warnings.warn( 8 | "BeneparComponent and NonConstituentException have been moved to the benepar " 9 | "module. Use `from benepar import BeneparComponent, NonConstituentException` " 10 | "instead of benepar.spacy_plugin. The benepar.spacy_plugin namespace is deprecated " 11 | "and will be removed in a future version.", 12 | FutureWarning, 13 | ) 14 | -------------------------------------------------------------------------------- /src/benepar/subbatching.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for splitting batches of examples into smaller sub-batches. 3 | 4 | This is useful during training when the batch size is too large to fit on GPU, 5 | meaning that gradient accumulation across multiple sub-batches must be used. 6 | It is also useful for batching examples during evaluation. Unlike a naive 7 | approach, this code groups examples with similar lengths to reduce the amount 8 | of wasted computation due to padding. 9 | """ 10 | 11 | import numpy as np 12 | 13 | 14 | def split(*data, costs, max_cost): 15 | """Splits a batch of input items into sub-batches. 16 | 17 | Args: 18 | *data: One or more lists of input items, all of the same length 19 | costs: A list of costs for each item 20 | max_cost: Maximum total cost for each sub-batch 21 | 22 | Yields: 23 | (example_ids, *subbatch_data) tuples. 
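
    For example (illustrative), with costs=[3, 9, 5] and max_cost=10, items
    are visited in order of increasing cost and grouped greedily, yielding
    the sub-batch item ids [0, 2] followed by [1].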
24 | """ 25 | costs = np.asarray(costs, dtype=int) 26 | costs_argsort = np.argsort(costs).tolist() 27 | 28 | subbatch_size = 1 29 | while costs_argsort: 30 | if subbatch_size == len(costs_argsort) or ( 31 | subbatch_size * costs[costs_argsort[subbatch_size]] > max_cost 32 | ): 33 | subbatch_item_ids = costs_argsort[:subbatch_size] 34 | subbatch_data = [[items[i] for i in subbatch_item_ids] for items in data] 35 | yield (subbatch_item_ids,) + tuple(subbatch_data) 36 | costs_argsort = costs_argsort[subbatch_size:] 37 | subbatch_size = 1 38 | else: 39 | subbatch_size += 1 40 | 41 | 42 | def map(func, *data, costs, max_cost, **common_kwargs): 43 | """Maps a function over subbatches of input items. 44 | 45 | Args: 46 | func: Function to map over the data 47 | *data: One or more lists of input items, all of the same length. 48 | costs: A list of costs for each item 49 | max_cost: Maximum total cost for each sub-batch 50 | **common_kwargs: Keyword arguments to pass to all calls of func 51 | 52 | Returns: 53 | A list of outputs from calling func(*subbatch_data, **kwargs) for each 54 | subbatch, and then rearranging the outputs from func into the original 55 | item order. 56 | """ 57 | res = [None] * len(data[0]) 58 | for item_ids, *subbatch_items in split(*data, costs=costs, max_cost=max_cost): 59 | subbatch_out = func(*subbatch_items, **common_kwargs) 60 | for item_id, item_out in zip(item_ids, subbatch_out): 61 | res[item_id] = item_out 62 | return res 63 | -------------------------------------------------------------------------------- /src/count_fscore.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | 4 | def remove_top(a): 5 | a = a.replace('(TOP (S ', '') 6 | a = a[::-1].replace('))', '', 1)[::-1] 7 | return a 8 | 9 | def replace_n(a, i, j, num): 10 | a = a.replace(i, '*', num) 11 | a = a.replace('*', i, num-1) 12 | a = a.replace('*', j) 13 | return a 14 | 15 | 16 | def replace1(a): 17 | a = re.sub('\n', '', a) 18 | for i in range(len(a)): 19 | num_left = 0 20 | num_right = 0 21 | flag = 0 22 | for j in range(len(a)): 23 | if a[j] == '(' : 24 | num_left += 1 25 | if a[j+1] == 'S' and flag == 0: 26 | b = a[j+1] 27 | flag = 1 28 | if a[j+1] == '#' and flag == 0: 29 | b = a[j+1] + a[j+2] 30 | flag = 1 31 | # print('mmmmm',b) 32 | elif a[j] == ')' : 33 | num_right += 1 34 | if num_right == num_left and a[j-1] == ')': 35 | # print(num_right, b) 36 | # print('mmmmmmm:',a) 37 | a = replace_n(a, ')', b, num_left) 38 | a = a.replace('('+b, '', 1) 39 | break 40 | return a 41 | 42 | 43 | def add_seg(a): 44 | a = re.sub(u"[\u4e00-\u9fa5]+", '*', a) 45 | a = re.sub(r'[a-zA-Z]+', '*', a) 46 | a = a.replace('(* *)', 'W') 47 | a = re.sub(r'[^0-9A-Za-z]+', '', a) 48 | return a 49 | 50 | 51 | def format_conversion_tree2prosody(data_path): 52 | 53 | line_sen_list = [] 54 | 55 | with open(data_path, 'r', encoding='utf-8') as f: 56 | lines = f.readlines() 57 | for line in lines: 58 | sen_token_list = [] 59 | 60 | if line != '\n' and line != '': 61 | 62 | ss = line 63 | sss = remove_top(ss) 64 | sss = replace1(sss) 65 | sen_token_list.append(sss) 66 | 67 | ssss = add_seg(sss) 68 | sen_token_list.append('|'+ssss+'\n') 69 | line_sen_list.append(''.join(sen_token_list)) 70 | f.close() 71 | 72 | with open(data_path,'w+', encoding='utf-8') as o: 73 | o.write(''.join(line_sen_list)) 74 | o.close() 75 | 76 | 77 | ################################################################################################################ 78 | 79 | def 
replace(data_path): 80 | ''' 81 | sen_list output: ['102011202110210203121102113', '0102121010301110103', '011210212013', '121311210211311102113', '1103110213'] 82 | ''' 83 | sen_list = [] 84 | s00_list = [] 85 | s0_list = [] 86 | with open(data_path, 'r', encoding='utf-8') as t: 87 | lines = t.readlines() 88 | num0 = 0 89 | for line in lines: 90 | seg = line.split('|') 91 | 92 | s00 = seg[0] 93 | 94 | s0 = seg[1] 95 | 96 | s = re.sub('\n', '', s0) 97 | 98 | 99 | compileX = re.compile(r'\d+') 100 | num_result = compileX.findall(s) 101 | for i in num_result: 102 | s = re.sub(i, max(i), s, 1) 103 | 104 | s = re.sub('W', '0', s) 105 | s = re.sub('01', '1', s) 106 | s = re.sub('02', '2', s) 107 | s = re.sub('03', '3', s) 108 | 109 | sen_list.append(s) 110 | s00_list.append(s00) 111 | s0_list.append(s0) 112 | t.close() 113 | return sen_list, s00_list, s0_list 114 | 115 | 116 | 117 | def score(TP, FP, FN): 118 | if TP + FP == 0: 119 | precision = 0.01 120 | else: 121 | precision = TP / (TP + FP) 122 | if TP + FN == 0: 123 | recall = 0.01 124 | else: 125 | recall = TP / (TP + FN) 126 | f1score = 2 * precision * recall / (precision + recall) 127 | return precision, recall, f1score 128 | 129 | def count(gold_path, predicted_path): 130 | 131 | format_conversion_tree2prosody(gold_path) 132 | format_conversion_tree2prosody(predicted_path) 133 | 134 | test_sen_list ,test_s00_list, test_s0_list = replace(gold_path) 135 | predicted_sen_list , predicted_s00_list, predicted_s0_list = replace(predicted_path) 136 | 137 | # a12: test 1, predicted 2 138 | a00 = a01 = a02 = a03 = a10 = a11 = a12 = a13 = a20 = a21 = a22 = a23 = a30 = a31 = a32 = a33 = 0 139 | num = 0 140 | num_match_sen = 0 141 | for i in range(len(test_sen_list)): 142 | t = test_sen_list[i] 143 | p = predicted_sen_list[i] 144 | 145 | if t == p: 146 | num_match_sen += 1 147 | 148 | 149 | if len(t) != len(p): 150 | num += 1 151 | print(num, '\n', t, test_s00_list[i], test_s0_list[i], '\n', p ,predicted_s00_list[i], predicted_s0_list[i]) 152 | else: 153 | for j in range(len(t)): 154 | if t[j] == '0': 155 | if p[j] == '0': 156 | a00 += 1 157 | if p[j] == '1': 158 | a01 += 1 159 | if p[j] == '2': 160 | a02 += 1 161 | if p[j] == '3': 162 | a03 += 1 163 | if t[j] == '1': 164 | if p[j] == '0': 165 | a10 += 1 166 | if p[j] == '1': 167 | a11 += 1 168 | if p[j] == '2': 169 | a12 += 1 170 | if p[j] == '3': 171 | a13 += 1 172 | if t[j] == '2': 173 | if p[j] == '0': 174 | a20 += 1 175 | if p[j] == '1': 176 | a21 += 1 177 | if p[j] == '2': 178 | a22 += 1 179 | if p[j] == '3': 180 | a23 += 1 181 | if t[j] == '3': 182 | if p[j] == '0': 183 | a30 += 1 184 | if p[j] == '1': 185 | a31 += 1 186 | if p[j] == '2': 187 | a32 += 1 188 | if p[j] == '3': 189 | a33 += 1 190 | 191 | precision1, recall1, fscore1 = score(a11 + a12 + a13 + a21 + a22 + a23 + a31 + a32 + a33, a01 + a02 + a03 , a10 + a20 + a30) 192 | precision2, recall2, fscore2 = score(a22 + a23 + a32 + a33, a02 + a03 + a12 + a13, a20 + a21 + a30 + a31) 193 | precision3, recall3, fscore3 = score(a33, a03 + a13 + a23, a30 + a31 + a32) 194 | precision = float((precision1 + precision2 + precision3) *100 /3 ) 195 | recall = float((recall1 + recall2 + recall3) *100 /3 ) 196 | fscore = float((fscore1 + fscore2 + fscore3) *100 /3 ) 197 | 198 | 199 | completematch = float(100 * num_match_sen/len(test_sen_list)) 200 | 201 | print('PW:',precision1, recall1, fscore1) 202 | print('PPH:',precision2, recall2, fscore2) 203 | print('IPH:',precision3, recall3, fscore3) 204 | 205 | 206 | return recall, precision, fscore, 
completematch 207 | 208 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os.path 3 | import re 4 | import subprocess 5 | import tempfile 6 | 7 | import nltk 8 | 9 | 10 | class FScore(object): 11 | def __init__(self, recall, precision, fscore, complete_match, tagging_accuracy=100): 12 | self.recall = recall 13 | self.precision = precision 14 | self.fscore = fscore 15 | self.complete_match = complete_match 16 | self.tagging_accuracy = tagging_accuracy 17 | 18 | def __str__(self): 19 | return ( 20 | f"(Recall={self.recall:.2f}, " 21 | f"Precision={self.precision:.2f}, " 22 | f"FScore={self.fscore:.2f}, " 23 | f"CompleteMatch={self.complete_match:.2f}" 24 | ) + ( 25 | f", TaggingAccuracy={self.tagging_accuracy:.2f})" 26 | if self.tagging_accuracy < 100 27 | else ")" 28 | ) 29 | 30 | 31 | def evalb(evalb_dir, gold_trees, predicted_trees, ref_gold_path=None): 32 | 33 | 34 | assert len(gold_trees) == len(predicted_trees) 35 | for gold_tree, predicted_tree in zip(gold_trees, predicted_trees): 36 | assert isinstance(gold_tree, nltk.Tree) 37 | assert isinstance(predicted_tree, nltk.Tree) 38 | gold_leaves = list(gold_tree.leaves()) 39 | predicted_leaves = list(predicted_tree.leaves()) 40 | assert len(gold_leaves) == len(predicted_leaves) 41 | assert all( 42 | gold_word == predicted_word 43 | for gold_word, predicted_word in zip(gold_leaves, predicted_leaves) 44 | ) 45 | 46 | temp_dir = tempfile.TemporaryDirectory(prefix="evalb-") 47 | 48 | gold_path = os.path.join(temp_dir.name, "gold.txt") 49 | predicted_path = os.path.join(temp_dir.name, "predicted.txt") 50 | output_path = os.path.join(temp_dir.name, "output.txt") 51 | 52 | 53 | 54 | with open(gold_path, "w") as outfile: 55 | if ref_gold_path is None: 56 | for tree in gold_trees: 57 | outfile.write("{}\n".format(tree.pformat(margin=1e100))) 58 | else: 59 | 60 | with open(ref_gold_path) as goldfile: 61 | outfile.write(goldfile.read()) 62 | 63 | with open(predicted_path, "w") as outfile: 64 | for tree in predicted_trees: 65 | 66 | outfile.write("{}\n".format(tree.pformat(margin=1e100))) 67 | 68 | import count_fscore 69 | 70 | fscore = FScore(math.nan, math.nan, math.nan, math.nan) 71 | fscore.recall, fscore.precision, fscore.fscore, fscore.complete_match = count_fscore.count(gold_path, predicted_path) 72 | temp_dir.cleanup() 73 | 74 | return fscore 75 | -------------------------------------------------------------------------------- /src/export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import itertools 4 | import os.path 5 | import time 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import numpy as np 11 | 12 | import evaluate 13 | import treebanks 14 | 15 | from benepar import Parser, InputSentence 16 | from benepar.partitioned_transformer import PartitionedMultiHeadAttention 17 | 18 | import json 19 | 20 | 21 | def format_elapsed(start_time): 22 | elapsed_time = int(time.time() - start_time) 23 | minutes, seconds = divmod(elapsed_time, 60) 24 | hours, minutes = divmod(minutes, 60) 25 | days, hours = divmod(hours, 24) 26 | elapsed_string = "{}h{:02}m{:02}s".format(hours, minutes, seconds) 27 | if days > 0: 28 | elapsed_string = "{}d{}".format(days, elapsed_string) 29 | return elapsed_string 30 | 31 | 32 | def inputs_from_treebank(treebank, predict_tags): 33 | return [ 34 | InputSentence( 35 | 
words=example.words,
36 |             space_after=example.space_after,
37 |             tags=None if predict_tags else [tag for _, tag in example.pos()],
38 |             escaped_words=list(example.leaves()),
39 |         )
40 |         for example in treebank
41 |     ]
42 | 
43 | import sys  # needed for sys.stdout below; not among this module's top-level imports
44 | def run_test(args):
45 |     print("Loading test trees from {}...".format(args.test_path))
46 |     test_treebank = treebanks.load_trees(
47 |         args.test_path, args.test_path_text, args.text_processing
48 |     )
49 |     print("Loaded {:,} test examples.".format(len(test_treebank)))
50 | 
51 |     print("Loading model from {}...".format(args.model_path))
52 |     parser = Parser(args.model_path, batch_size=args.batch_size)
53 | 
54 |     print("Parsing test sentences...")
55 |     start_time = time.time()
56 | 
57 |     if args.output_path == "-":
58 |         output_file = sys.stdout
59 |     elif args.output_path:
60 |         output_file = open(args.output_path, "w")
61 |     else:
62 |         output_file = None
63 | 
64 |     test_predicted = []
65 |     for predicted_tree in parser.parse_sents(
66 |         inputs_from_treebank(test_treebank, predict_tags=args.predict_tags)
67 |     ):
68 |         test_predicted.append(predicted_tree)
69 |         if output_file is not None:
70 |             print(predicted_tree.pformat(margin=1e100), file=output_file)
71 | 
72 |     test_fscore = evaluate.evalb(args.evalb_dir, test_treebank.trees, test_predicted)
73 | 
74 |     print(
75 |         "test-fscore {} "
76 |         "test-elapsed {}".format(
77 |             test_fscore,
78 |             format_elapsed(start_time),
79 |         )
80 |     )
81 | 
82 | 
83 | def get_compressed_state_dict(model):
84 |     state_dict = model.state_dict()
85 |     for module_name, module in model.named_modules():
86 |         if not isinstance(
87 |             module, (nn.Linear, nn.Embedding, PartitionedMultiHeadAttention)
88 |         ):
89 |             continue
90 |         elif "token_type_embeddings" in module_name:
91 |             continue
92 |         elif "position_embeddings" in module_name:
93 |             continue
94 |         elif "f_tag" in module_name or "f_label" in module_name:
95 |             continue
96 |         elif "project_pretrained" in module_name:
97 |             continue
98 | 
99 |         if isinstance(module, PartitionedMultiHeadAttention):
100 |             weight_names = [
101 |                 module_name + "." 
+ param 102 | for param in ("w_qkv_c", "w_qkv_p", "w_o_c", "w_o_p") 103 | ] 104 | else: 105 | weight_names = [module_name + ".weight"] 106 | for weight_name in weight_names: 107 | weight = state_dict[weight_name] 108 | if weight.shape.numel() <= 2048: 109 | continue 110 | print(weight_name, ":", weight.shape.numel(), "parameters") 111 | 112 | if isinstance(module, nn.Embedding) or "word_embeddings" in module_name or "shared.weight" in weight_name: 113 | is_embedding = True 114 | else: 115 | is_embedding = False 116 | 117 | num_steps = 64 118 | use_histogram = True 119 | if "pooler.dense.weight" in weight_name: 120 | weight.data.zero_() 121 | continue 122 | elif "pretrained_model" in weight_name and not is_embedding: 123 | num_steps = 128 124 | if not model.retokenizer.is_t5: 125 | use_histogram = False 126 | elif isinstance(module, PartitionedMultiHeadAttention): 127 | num_steps = 128 128 | 129 | if use_histogram: 130 | observer = torch.quantization.HistogramObserver() 131 | observer.dst_nbins = num_steps 132 | observer(weight) 133 | scale, zero_point = observer.calculate_qparams() 134 | scale = scale.item() 135 | zero_point = zero_point.item() 136 | cluster_centers = ( 137 | scale * (np.arange(0, 256, 256 / num_steps) - zero_point)[:, None] 138 | ) 139 | cluster_centers = np.asarray(cluster_centers, dtype=np.float32) 140 | else: 141 | weight_np = weight.cpu().detach().numpy() 142 | min_val = weight_np.min() 143 | max_val = weight_np.max() 144 | bucket_width = (max_val - min_val) / num_steps 145 | cluster_centers = ( 146 | min_val 147 | + (np.arange(num_steps, dtype=np.float32) + 0.5) * bucket_width 148 | ) 149 | cluster_centers = cluster_centers.reshape((-1, 1)) 150 | 151 | codebook = torch.tensor( 152 | cluster_centers, dtype=weight.dtype, device=weight.device 153 | ) 154 | distances = weight.data.reshape((-1, 1)) - codebook.t() 155 | codes = torch.argmin(distances ** 2, dim=-1) 156 | weight_rounded = codebook[codes].reshape(weight.shape) 157 | weight.data.copy_(weight_rounded) 158 | 159 | return state_dict 160 | 161 | 162 | def run_export(args): 163 | if args.test_path is not None: 164 | print("Loading test trees from {}...".format(args.test_path)) 165 | test_treebank = treebanks.load_trees( 166 | args.test_path, args.test_path_text, args.text_processing 167 | ) 168 | print("Loaded {:,} test examples.".format(len(test_treebank))) 169 | else: 170 | test_treebank = None 171 | 172 | print("Loading model from {}...".format(args.model_path)) 173 | parser = Parser(args.model_path, batch_size=args.batch_size) 174 | model = parser._parser 175 | if model.pretrained_model is None: 176 | raise ValueError( 177 | "Exporting is only defined when using a pre-trained transformer " 178 | "encoder. For CharLSTM-based model, just distribute the pytorch " 179 | "checkpoint directly. You may manually delete the 'optimizer' " 180 | "field to reduce file size by discarding the optimizer state." 
181 | ) 182 | 183 | if test_treebank is not None: 184 | print("Parsing test sentences (predicting tags)...") 185 | start_time = time.time() 186 | test_inputs = inputs_from_treebank(test_treebank, predict_tags=True) 187 | test_predicted = list(parser.parse_sents(test_inputs)) 188 | test_fscore = evaluate.evalb(args.evalb_dir, test_treebank.trees, test_predicted) 189 | test_elapsed = format_elapsed(start_time) 190 | print("test-fscore {} test-elapsed {}".format(test_fscore, test_elapsed)) 191 | 192 | print("Parsing test sentences (not predicting tags)...") 193 | start_time = time.time() 194 | test_inputs = inputs_from_treebank(test_treebank, predict_tags=False) 195 | notags_test_predicted = list(parser.parse_sents(test_inputs)) 196 | notags_test_fscore = evaluate.evalb( 197 | args.evalb_dir, test_treebank.trees, notags_test_predicted 198 | ) 199 | notags_test_elapsed = format_elapsed(start_time) 200 | print( 201 | "test-fscore {} test-elapsed {}".format(notags_test_fscore, notags_test_elapsed) 202 | ) 203 | 204 | print("Exporting tokenizer...") 205 | model.retokenizer.tokenizer.save_pretrained(args.output_dir) 206 | 207 | print("Exporting config...") 208 | config = model.pretrained_model.config 209 | config.benepar = model.config 210 | config.save_pretrained(args.output_dir) 211 | 212 | if args.compress: 213 | print("Compressing weights...") 214 | state_dict = get_compressed_state_dict(model.cpu()) 215 | print("Saving weights...") 216 | else: 217 | print("Exporting weights...") 218 | state_dict = model.cpu().state_dict() 219 | torch.save(state_dict, os.path.join(args.output_dir, "benepar_model.bin")) 220 | 221 | del model, parser, state_dict 222 | 223 | print("Loading exported model from {}...".format(args.output_dir)) 224 | exported_parser = Parser(args.output_dir, batch_size=args.batch_size) 225 | 226 | if test_treebank is None: 227 | print() 228 | print("Export complete.") 229 | print("Did not verify model accuracy because no treebank was provided.") 230 | return 231 | 232 | print("Parsing test sentences (predicting tags)...") 233 | start_time = time.time() 234 | test_inputs = inputs_from_treebank(test_treebank, predict_tags=True) 235 | exported_predicted = list(exported_parser.parse_sents(test_inputs)) 236 | exported_fscore = evaluate.evalb( 237 | args.evalb_dir, test_treebank.trees, exported_predicted 238 | ) 239 | exported_elapsed = format_elapsed(start_time) 240 | print( 241 | "exported-fscore {} exported-elapsed {}".format( 242 | exported_fscore, exported_elapsed 243 | ) 244 | ) 245 | 246 | print("Parsing test sentences (not predicting tags)...") 247 | start_time = time.time() 248 | test_inputs = inputs_from_treebank(test_treebank, predict_tags=False) 249 | notags_exported_predicted = list(exported_parser.parse_sents(test_inputs)) 250 | notags_exported_fscore = evaluate.evalb( 251 | args.evalb_dir, test_treebank.trees, notags_exported_predicted 252 | ) 253 | notags_exported_elapsed = format_elapsed(start_time) 254 | print( 255 | "exported-fscore {} exported-elapsed {}".format( 256 | notags_exported_fscore, notags_exported_elapsed 257 | ) 258 | ) 259 | 260 | print() 261 | print("Export and verification complete.") 262 | fscore_delta = evaluate.FScore( 263 | recall=notags_exported_fscore.recall - notags_test_fscore.recall, 264 | precision=notags_exported_fscore.precision - notags_test_fscore.precision, 265 | fscore=notags_exported_fscore.fscore - notags_test_fscore.fscore, 266 | complete_match=( 267 | notags_exported_fscore.complete_match - notags_test_fscore.complete_match 268 | ), 
269 | tagging_accuracy=( 270 | exported_fscore.tagging_accuracy - test_fscore.tagging_accuracy 271 | ), 272 | ) 273 | print("delta-fscore {}".format(fscore_delta)) 274 | 275 | 276 | def main(): 277 | parser = argparse.ArgumentParser() 278 | subparsers = parser.add_subparsers() 279 | 280 | subparser = subparsers.add_parser("test") 281 | subparser.set_defaults(callback=run_test) 282 | subparser.add_argument("--model-path", type=str, required=True) 283 | subparser.add_argument("--evalb-dir", default="EVALB/") 284 | subparser.add_argument("--test-path", type=str, required=True) 285 | subparser.add_argument("--test-path-text", type=str) 286 | subparser.add_argument("--text-processing", default="default") 287 | subparser.add_argument("--predict-tags", action="store_true") 288 | subparser.add_argument("--output-path", default="") 289 | subparser.add_argument("--batch-size", type=int, default=8) 290 | 291 | subparser = subparsers.add_parser("export") 292 | subparser.set_defaults(callback=run_export) 293 | subparser.add_argument("--model-path", type=str, required=True) 294 | subparser.add_argument("--output-dir", type=str, required=True) 295 | subparser.add_argument("--evalb-dir", default="EVALB/") 296 | subparser.add_argument("--test-path", type=str, default=None) 297 | subparser.add_argument("--test-path-text", type=str) 298 | subparser.add_argument("--text-processing", default="default") 299 | subparser.add_argument("--compress", action="store_true") 300 | subparser.add_argument("--batch-size", type=int, default=8) 301 | 302 | args = parser.parse_args() 303 | args.callback(args) 304 | 305 | 306 | if __name__ == "__main__": 307 | main() 308 | -------------------------------------------------------------------------------- /src/inference_seq2tree.py: -------------------------------------------------------------------------------- 1 | ''' 2 | author: cxy 3 | date: 2022/04/05 4 | function: change the data format from raw to tree. 
5 | input: 猴子用尾巴荡秋千。
6 | output: (TOP (S (n 猴)(n 子)(n 用)(n 尾)(n 巴)(n 荡)(n 秋)(n 千)(。 。)))
7 | '''
8 | import re
9 | 
10 | punctuation_list = [',','。','、',';',':','?','!','“','”','‘','’','—','…','(',')','《','》']
11 | 
12 | def data_pre_processing(x):
13 |     x = re.sub('——','—', x)
14 |     x = re.sub('……', '…', x)
15 |     return x
16 | 
17 | def separate_each_character(x):
18 |     '''
19 |     input: 猴子用尾巴荡秋千。
20 |     output: (n 猴)(n 子)(n 用)(n 尾)(n 巴)(n 荡)(n 秋)(n 千)(。 。)
21 |     '''
22 |     x_list = []
23 |     for i in x:
24 |         if i in punctuation_list:
25 |             i = '(' + i + ' ' + i + ')'
26 |             x_list.append(i)
27 |         else:
28 |             i = '(' + 'n' + ' ' + i + ')'
29 |             x_list.append(i)
30 |     x = ''.join(x_list)
31 |     return x
32 | 
33 | def seq2tree(x):
34 |     '''
35 |     input: (n 猴)(n 子)(n 用)(n 尾)(n 巴)(n 荡)(n 秋)(n 千)(。 。)
36 |     output: (TOP (S (n 猴)(n 子)(n 用)(n 尾)(n 巴)(n 荡)(n 秋)(n 千)(。 。)))
37 |     '''
38 |     tree = '(' + 'TOP' + ' ' + '(' + 'S' + ' ' + x + ')' + ')'
39 |     return tree
40 | 
41 | def main():
42 |     seq_data_path = './data/inference/raw_data/raw_data.txt'
43 |     tree_data_path = './data/inference/tree_data/tree_data.txt'
44 | 
45 |     line_list = []
46 |     with open(seq_data_path, 'r', encoding='utf-8') as s:
47 |         lines = s.readlines()
48 |         for line in lines:
49 |             line = data_pre_processing(line.strip())
50 |             line = separate_each_character(line)
51 |             line = seq2tree(line)
52 |             line_list.append(line)
53 |     s.close()
54 | 
55 |     with open(tree_data_path, 'w', encoding='utf-8') as t:
56 |         t.write('\n'.join(line_list))
57 |         t.write('\n')
58 | 
59 | if __name__ == '__main__':
60 |     main()
61 | 
--------------------------------------------------------------------------------
/src/learning_rates.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import ReduceLROnPlateau
2 | 
3 | 
4 | class WarmupThenReduceLROnPlateau(ReduceLROnPlateau):
5 |     def __init__(self, optimizer, warmup_steps, *args, **kwargs):
6 |         """
7 |         Args:
8 |             optimizer (Optimizer): Optimizer to wrap
9 |             warmup_steps: number of steps before reaching base learning rate
10 |             *args: Arguments for ReduceLROnPlateau
11 |             **kwargs: Arguments for ReduceLROnPlateau
12 |         """
13 |         super().__init__(optimizer, *args, **kwargs)
14 |         self.warmup_steps = warmup_steps
15 |         self.steps_taken = 0
16 |         self.base_lrs = list(map(lambda group: group["lr"], optimizer.param_groups))
17 |         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
18 |             param_group["lr"] = lr
19 | 
20 |     def get_lr(self):
21 |         assert self.steps_taken <= self.warmup_steps
22 |         return [
23 |             base_lr * (self.steps_taken / self.warmup_steps)
24 |             for base_lr in self.base_lrs
25 |         ]
26 | 
27 |     def step(self, metrics=None):
28 |         self.steps_taken += 1
29 |         if self.steps_taken <= self.warmup_steps:
30 |             for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
31 |                 param_group["lr"] = lr
32 |         elif metrics is not None:
33 |             super().step(metrics)
34 | 
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import functools
3 | import itertools
4 | import os.path
5 | import time
6 | 
7 | import torch
8 | 
9 | import numpy as np
10 | 
11 | from benepar import char_lstm
12 | from benepar import decode_chart
13 | from benepar import nkutil
14 | from benepar import parse_chart
15 | import evaluate
16 | import learning_rates
17 | import treebanks
18 | 
19 | 
20 | def format_elapsed(start_time):
21 |     elapsed_time = 
int(time.time() - start_time) 22 | minutes, seconds = divmod(elapsed_time, 60) 23 | hours, minutes = divmod(minutes, 60) 24 | days, hours = divmod(hours, 24) 25 | elapsed_string = "{}h{:02}m{:02}s".format(hours, minutes, seconds) 26 | if days > 0: 27 | elapsed_string = "{}d{}".format(days, elapsed_string) 28 | return elapsed_string 29 | 30 | def make_hparams(): 31 | return nkutil.HParams( 32 | # Data processing 33 | max_len_train=0, # no length limit 34 | max_len_dev=0, # no length limit 35 | # Optimization 36 | batch_size=32, 37 | learning_rate=0.00005, 38 | learning_rate_warmup_steps=160, 39 | clip_grad_norm=0.0, # no clipping 40 | checks_per_epoch=1, 41 | step_decay_factor=0.5, 42 | step_decay_patience=5, 43 | max_consecutive_decays=3, # establishes a termination criterion 44 | # CharLSTM 45 | use_chars_lstm=False, 46 | d_char_emb=64, 47 | char_lstm_input_dropout=0.2, 48 | # BERT and other pre-trained models 49 | ## use_pretrained=False, 50 | ## pretrained_model="bert-base-uncased", 51 | use_pretrained=True, 52 | pretrained_model="bert-base-chinese", 53 | # Partitioned transformer encoder 54 | ## use_encoder=False, 55 | use_encoder=True, 56 | d_model=1024, 57 | num_layers=8, 58 | num_heads=8, 59 | d_kv=64, 60 | d_ff=2048, 61 | encoder_max_len=512, 62 | # Dropout 63 | morpho_emb_dropout=0.2, 64 | attention_dropout=0.2, 65 | relu_dropout=0.1, 66 | residual_dropout=0.2, 67 | # Output heads and losses 68 | force_root_constituent="auto", 69 | predict_tags=False, 70 | d_label_hidden=256, 71 | d_tag_hidden=256, 72 | tag_loss_scale=5.0, 73 | ) 74 | 75 | 76 | def run_train(args, hparams): 77 | if args.numpy_seed is not None: 78 | print("Setting numpy random seed to {}...".format(args.numpy_seed)) 79 | np.random.seed(args.numpy_seed) 80 | 81 | seed_from_numpy = np.random.randint(2147483648) 82 | print("Manual seed for pytorch:", seed_from_numpy) 83 | torch.manual_seed(seed_from_numpy) 84 | 85 | hparams.set_from_args(args) 86 | print("Hyperparameters:") 87 | hparams.print() 88 | 89 | print("Loading training trees from {}...".format(args.train_path)) 90 | train_treebank = treebanks.load_trees( 91 | args.train_path, args.train_path_text, args.text_processing 92 | ) 93 | if hparams.max_len_train > 0: ## 0 94 | train_treebank = train_treebank.filter_by_length(hparams.max_len_train) 95 | print("Loaded {:,} training examples.".format(len(train_treebank))) 96 | 97 | print("Loading development trees from {}...".format(args.dev_path)) 98 | dev_treebank = treebanks.load_trees( 99 | args.dev_path, args.dev_path_text, args.text_processing 100 | ) 101 | if hparams.max_len_dev > 0: ## 0 102 | dev_treebank = dev_treebank.filter_by_length(hparams.max_len_dev) 103 | print("Loaded {:,} development examples.".format(len(dev_treebank))) 104 | 105 | print("Constructing vocabularies...") 106 | label_vocab = decode_chart.ChartDecoder.build_vocab(train_treebank.trees) # {'': 0, '#1': 1, '#2': 2, '#2::#1': 3, '#3': 4, '#3::#2': 5} 107 | if hparams.use_chars_lstm: 108 | char_vocab = char_lstm.RetokenizerForCharLSTM.build_vocab(train_treebank.sents) 109 | else: 110 | char_vocab = None 111 | 112 | tag_vocab = set() 113 | 114 | for tree in train_treebank.trees: 115 | # print('tree.pos:', tree.pos) 116 | for _, tag in tree.pos(): 117 | tag_vocab.add(tag) 118 | tag_vocab = ["UNK"] + sorted(tag_vocab) 119 | tag_vocab = {label: i for i, label in enumerate(tag_vocab)} 120 | 121 | if hparams.force_root_constituent.lower() in ("true", "yes", "1"): 122 | hparams.force_root_constituent = True 123 | elif 
hparams.force_root_constituent.lower() in ("false", "no", "0"): 124 | hparams.force_root_constituent = False 125 | elif hparams.force_root_constituent.lower() == "auto": 126 | hparams.force_root_constituent = ( 127 | decode_chart.ChartDecoder.infer_force_root_constituent(train_treebank.trees) 128 | ) 129 | print("Set hparams.force_root_constituent to", hparams.force_root_constituent) 130 | 131 | print("Initializing model...") 132 | parser = parse_chart.ChartParser( 133 | tag_vocab=tag_vocab, 134 | label_vocab=label_vocab, 135 | char_vocab=char_vocab, 136 | hparams=hparams, 137 | ) 138 | if args.parallelize: 139 | parser.parallelize() 140 | elif torch.cuda.is_available(): 141 | parser.cuda() 142 | else: 143 | print("Not using CUDA!") 144 | 145 | print("Initializing optimizer...") 146 | trainable_parameters = [ 147 | param for param in parser.parameters() if param.requires_grad 148 | ] 149 | 150 | optimizer = torch.optim.Adam( 151 | trainable_parameters, lr=hparams.learning_rate, betas=(0.9, 0.98), eps=1e-9 152 | ) 153 | 154 | scheduler = learning_rates.WarmupThenReduceLROnPlateau( 155 | optimizer, 156 | hparams.learning_rate_warmup_steps, 157 | mode="max", 158 | factor=hparams.step_decay_factor, 159 | patience=hparams.step_decay_patience * hparams.checks_per_epoch, 160 | verbose=True, 161 | ) 162 | 163 | clippable_parameters = trainable_parameters 164 | grad_clip_threshold = ( 165 | np.inf if hparams.clip_grad_norm == 0 else hparams.clip_grad_norm 166 | ) 167 | 168 | print("Training...") 169 | total_processed = 0 170 | current_processed = 0 171 | check_every = len(train_treebank) / hparams.checks_per_epoch 172 | best_dev_fscore = -np.inf 173 | 174 | best_dev_model_path = None 175 | best_dev_processed = 0 176 | 177 | start_time = time.time() 178 | 179 | def check_dev(): 180 | nonlocal best_dev_fscore 181 | nonlocal best_dev_model_path 182 | nonlocal best_dev_processed 183 | 184 | dev_start_time = time.time() 185 | 186 | dev_predicted = parser.parse( 187 | dev_treebank.without_gold_annotations(), 188 | subbatch_max_tokens=args.subbatch_max_tokens, 189 | ) 190 | dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank.trees, dev_predicted) 191 | 192 | print( 193 | "dev-fscore {} " 194 | "dev-elapsed {} " 195 | "total-elapsed {}".format( 196 | dev_fscore, 197 | format_elapsed(dev_start_time), 198 | format_elapsed(start_time), 199 | ) 200 | ) 201 | 202 | if dev_fscore.fscore > best_dev_fscore: 203 | if best_dev_model_path is not None: 204 | extensions = [".pt"] 205 | for ext in extensions: 206 | path = best_dev_model_path + ext 207 | if os.path.exists(path): 208 | print("Removing previous model file {}...".format(path)) 209 | os.remove(path) 210 | 211 | best_dev_fscore = dev_fscore.fscore 212 | best_dev_model_path = "{}_dev={:.2f}".format( 213 | args.model_path_base, dev_fscore.fscore 214 | ) 215 | best_dev_processed = total_processed 216 | print("Saving new best model to {}...".format(best_dev_model_path)) 217 | torch.save( 218 | { 219 | "config": parser.config, 220 | "state_dict": parser.state_dict(), 221 | "optimizer": optimizer.state_dict(), 222 | }, 223 | best_dev_model_path + ".pt", 224 | ) 225 | 226 | data_loader = torch.utils.data.DataLoader( 227 | train_treebank, 228 | batch_size=hparams.batch_size, 229 | shuffle=True, 230 | collate_fn=functools.partial( 231 | parser.encode_and_collate_subbatches, 232 | subbatch_max_tokens=args.subbatch_max_tokens, 233 | ), 234 | ) 235 | for epoch in itertools.count(start=1): 236 | epoch_start_time = time.time() 237 | 238 | for batch_num, batch in 
enumerate(data_loader, start=1): 239 | optimizer.zero_grad() 240 | parser.train() 241 | 242 | batch_loss_value = 0.0 243 | for subbatch_size, subbatch in batch: 244 | loss = parser.compute_loss(subbatch) 245 | loss_value = float(loss.data.cpu().numpy()) 246 | batch_loss_value += loss_value 247 | if loss_value > 0: 248 | loss.backward() 249 | del loss 250 | total_processed += subbatch_size 251 | current_processed += subbatch_size 252 | 253 | grad_norm = torch.nn.utils.clip_grad_norm_( 254 | clippable_parameters, grad_clip_threshold 255 | ) 256 | 257 | optimizer.step() 258 | 259 | print( 260 | "epoch {:,} " 261 | "batch {:,}/{:,} " 262 | "processed {:,} " 263 | "batch-loss {:.4f} " 264 | "grad-norm {:.4f} " 265 | "epoch-elapsed {} " 266 | "total-elapsed {}".format( 267 | epoch, 268 | batch_num, 269 | int(np.ceil(len(train_treebank) / hparams.batch_size)), 270 | total_processed, 271 | batch_loss_value, 272 | grad_norm, 273 | format_elapsed(epoch_start_time), 274 | format_elapsed(start_time), 275 | ) 276 | ) 277 | 278 | if current_processed >= check_every: 279 | current_processed -= check_every 280 | check_dev() 281 | scheduler.step(metrics=best_dev_fscore) 282 | else: 283 | scheduler.step() 284 | 285 | if (total_processed - best_dev_processed) > ( 286 | (hparams.step_decay_patience + 1) 287 | * hparams.max_consecutive_decays 288 | * len(train_treebank) 289 | ): 290 | print("Terminating due to lack of improvement in dev fscore.") 291 | break 292 | 293 | ############################################################################################## 294 | def run_test(args): 295 | print("Loading test trees from {}...".format(args.test_path)) 296 | test_treebank = treebanks.load_trees( 297 | args.test_path, args.test_path_text, args.text_processing 298 | ) 299 | print("Loaded {:,} test examples.".format(len(test_treebank))) 300 | 301 | if len(args.model_path) != 1: 302 | raise NotImplementedError( 303 | "Ensembling multiple parsers is not " 304 | "implemented in this version of the code." 305 | ) 306 | 307 | model_path = args.model_path[0] 308 | print("Loading model from {}...".format(model_path)) 309 | parser = parse_chart.ChartParser.from_trained(model_path) 310 | if args.no_predict_tags and parser.f_tag is not None: 311 | print("Removing part-of-speech tagging head...") 312 | parser.f_tag = None 313 | if args.parallelize: 314 | parser.parallelize() 315 | elif torch.cuda.is_available(): 316 | parser.cuda() 317 | 318 | print("Parsing test sentences...") 319 | start_time = time.time() 320 | 321 | test_predicted = parser.parse( 322 | test_treebank.without_gold_annotations(), 323 | subbatch_max_tokens=args.subbatch_max_tokens, 324 | ) 325 | 326 | if args.output_path == "-": 327 | for tree in test_predicted: 328 | print(tree.pformat(margin=1e100)) 329 | elif args.output_path: 330 | with open(args.output_path, "w") as outfile: 331 | for tree in test_predicted: 332 | outfile.write("{}\n".format(tree.pformat(margin=1e100))) 333 | 334 | # The tree loader does some preprocessing to the trees (e.g. stripping TOP 335 | # symbols or SPMRL morphological features). We compare with the input file 336 | # directly to be extra careful about not corrupting the evaluation. We also 337 | # allow specifying a separate "raw" file for the gold trees: the inputs to 338 | # our parser have traces removed and may have predicted tags substituted, 339 | # and we may wish to compare against the raw gold trees to make sure we 340 | # haven't made a mistake. 
As far as we can tell all of these variations give 341 | # equivalent results. 342 | ref_gold_path = args.test_path 343 | if args.test_path_raw is not None: 344 | print("Comparing with raw trees from", args.test_path_raw) 345 | ref_gold_path = args.test_path_raw 346 | 347 | test_fscore = evaluate.evalb( 348 | args.evalb_dir, test_treebank.trees, test_predicted, ref_gold_path=ref_gold_path 349 | ) 350 | 351 | print( 352 | "test-fscore {} " 353 | "test-elapsed {}".format( 354 | test_fscore, 355 | format_elapsed(start_time), 356 | ) 357 | ) 358 | 359 | 360 | 361 | ##################################################################################################### 362 | def run_auto_labels(args): 363 | print("Loading test trees from {}...".format(args.test_path)) 364 | test_treebank = treebanks.load_trees( 365 | args.test_path, args.test_path_text, args.text_processing 366 | ) 367 | print("Loaded {:,} test examples.".format(len(test_treebank))) 368 | 369 | if len(args.model_path) != 1: 370 | raise NotImplementedError( 371 | "Ensembling multiple parsers is not " 372 | "implemented in this version of the code." 373 | ) 374 | 375 | model_path = args.model_path[0] 376 | print("Loading model from {}...".format(model_path)) 377 | parser = parse_chart.ChartParser.from_trained(model_path) 378 | if args.no_predict_tags and parser.f_tag is not None: 379 | print("Removing part-of-speech tagging head...") 380 | parser.f_tag = None 381 | if args.parallelize: 382 | parser.parallelize() 383 | elif torch.cuda.is_available(): 384 | parser.cuda() 385 | 386 | print("Parsing test sentences...") 387 | start_time = time.time() 388 | 389 | test_predicted = parser.parse( 390 | test_treebank.without_gold_annotations(), 391 | subbatch_max_tokens=args.subbatch_max_tokens, 392 | ) 393 | 394 | # if args.output_path == "-": 395 | # for tree in test_predicted: 396 | # print(tree.pformat(margin=1e100)) 397 | # elif args.output_path: 398 | # with open(args.output_path, "w") as outfile: 399 | # for tree in test_predicted: 400 | # outfile.write("{}\n".format(tree.pformat(margin=1e100))) 401 | 402 | ref_gold_path = args.test_path 403 | if args.test_path_raw is not None: 404 | print("Comparing with raw trees from", args.test_path_raw) 405 | ref_gold_path = args.test_path_raw 406 | 407 | 408 | import seq_with_label 409 | seq_with_label.output(args.output_path, test_predicted) 410 | 411 | 412 | def main(): 413 | parser = argparse.ArgumentParser() 414 | subparsers = parser.add_subparsers() 415 | 416 | hparams = make_hparams() 417 | subparser = subparsers.add_parser("train") 418 | subparser.set_defaults(callback=lambda args: run_train(args, hparams)) 419 | hparams.populate_arguments(subparser) 420 | subparser.add_argument("--numpy-seed", type=int) 421 | subparser.add_argument("--model-path-base", required=True) 422 | subparser.add_argument("--evalb-dir", default="EVALB/") 423 | subparser.add_argument("--train-path", default="data/train/tree_data/tree_train.txt") 424 | subparser.add_argument("--train-path-text", type=str) 425 | subparser.add_argument("--dev-path", default="data/train/tree_data/tree_validate.txt") 426 | subparser.add_argument("--dev-path-text", type=str) 427 | subparser.add_argument("--text-processing", default="chinese") 428 | subparser.add_argument("--subbatch-max-tokens", type=int, default=2000) 429 | subparser.add_argument("--parallelize", action="store_true") 430 | subparser.add_argument("--print-vocabs", action="store_true") 431 | 432 | subparser = subparsers.add_parser("test") 433 | 
subparser.set_defaults(callback=run_test) 434 | subparser.add_argument("--model-path", nargs="+", required=True) 435 | subparser.add_argument("--evalb-dir", default="EVALB/") 436 | subparser.add_argument("--test-path", default="data/train/tree_data/tree_test.txt") 437 | subparser.add_argument("--test-path-text", type=str) 438 | subparser.add_argument("--test-path-raw", type=str) 439 | subparser.add_argument("--text-processing", default="chinese") 440 | subparser.add_argument("--subbatch-max-tokens", type=int, default=500) 441 | subparser.add_argument("--parallelize", action="store_true") 442 | subparser.add_argument("--output-path", default="") 443 | subparser.add_argument("--no-predict-tags", action="store_true") 444 | 445 | 446 | subparser = subparsers.add_parser("inference") 447 | subparser.set_defaults(callback=run_auto_labels) 448 | subparser.add_argument("--model-path", nargs="+", required=True) 449 | subparser.add_argument("--evalb-dir", default="EVALB/") 450 | subparser.add_argument("--test-path", default="data/inference/tree_data/tree_data.txt") 451 | subparser.add_argument("--test-path-text", type=str) 452 | subparser.add_argument("--test-path-raw", type=str) 453 | subparser.add_argument("--text-processing", default="default") 454 | subparser.add_argument("--subbatch-max-tokens", type=int, default=500) 455 | subparser.add_argument("--parallelize", action="store_true") 456 | subparser.add_argument("--output-path", default="") 457 | subparser.add_argument("--no-predict-tags", action="store_true") 458 | 459 | 460 | args = parser.parse_args() 461 | args.callback(args) 462 | 463 | 464 | if __name__ == "__main__": 465 | main() 466 | -------------------------------------------------------------------------------- /src/seq_with_label.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | 4 | import math 5 | import os.path 6 | import subprocess 7 | import tempfile 8 | 9 | import nltk 10 | 11 | 12 | def remove_top(a): 13 | a = a.replace('(TOP (S ', '') 14 | a = a[::-1].replace('))', '', 1)[::-1] 15 | return a 16 | 17 | def replace_n(a, i, j, num): 18 | a = a.replace(i, '*', num) 19 | a = a.replace('*', i, num-1) 20 | a = a.replace('*', j) 21 | return a 22 | 23 | def replace1(a): 24 | a = re.sub('\n', '', a) 25 | for i in range(len(a)): 26 | num_left = 0 27 | num_right = 0 28 | flag = 0 29 | for j in range(len(a)): 30 | if a[j] == '(' : 31 | num_left += 1 32 | if a[j+1] == 'S' and flag == 0: 33 | b = a[j+1] 34 | flag = 1 35 | if a[j+1] == '#' and flag == 0: 36 | b = a[j+1] + a[j+2] 37 | flag = 1 38 | 39 | elif a[j] == ')' : 40 | num_right += 1 41 | if num_right == num_left and a[j-1] == ')': 42 | a = replace_n(a, ')', b, num_left) 43 | a = a.replace('('+b, '', 1) 44 | break 45 | return a 46 | 47 | def replace2(a): 48 | s = re.sub('\n', '', a) 49 | s = re.sub('#', '', s) 50 | 51 | compileX = re.compile(r'\d+') 52 | num_result = compileX.findall(s) 53 | for i in num_result: 54 | if i != '1': 55 | s = re.sub(i, '#'+ str(max(i)), s, 1) 56 | s = re.sub('#1', '##', s) 57 | s = re.sub('1','#1', s) 58 | s = re.sub('##','#1',s) 59 | 60 | s = s.replace('(n ', '') 61 | s = s.replace(')', '') 62 | 63 | punctuation_list = [',','。','、',';',':','?','!','“','”','‘','’','—','…','(',')','《','》'] 64 | for punc in punctuation_list: 65 | # s = re.sub('('+ punc, '', s) 66 | s = s.replace('('+ punc, '') 67 | 68 | s = re.sub(' ', '', s) 69 | 70 | return s 71 | 72 | 73 | def output_file(temp_data_path, output_data_path): 74 | 75 | line_sen_list 
= [] 76 | 77 | with open(temp_data_path, 'r', encoding='utf-8') as f: 78 | lines = f.readlines() 79 | for line in lines: 80 | sen_token_list = [] 81 | 82 | if line != '\n' and line != '': 83 | 84 | ss = line 85 | sss = remove_top(ss) 86 | sss = replace1(sss) 87 | sss = replace2(sss) 88 | sen_token_list.append(sss+ '\n') 89 | line_sen_list.append(''.join(sen_token_list)) 90 | f.close() 91 | 92 | with open(output_data_path,'w+', encoding='utf-8') as o: 93 | o.write(''.join(line_sen_list)) 94 | o.close() 95 | 96 | 97 | 98 | def output_seq(tem_data_path): 99 | 100 | line_sen_list = [] 101 | 102 | with open(tem_data_path, 'r', encoding='utf-8') as f: 103 | lines = f.readlines() 104 | for line in lines: 105 | sen_token_list = [] 106 | if line != '\n' and line != '': 107 | ss = line 108 | sss = remove_top(ss) 109 | sss = replace1(sss) 110 | sss = replace2(sss) 111 | sen_token_list.append(sss+ '\n') 112 | line_sen_list.append(''.join(sen_token_list)) 113 | f.close() 114 | return line_sen_list 115 | 116 | 117 | 118 | 119 | ################################################################################################################ 120 | def output(output_path, predicted_trees): 121 | for predicted_tree in predicted_trees: 122 | assert isinstance(predicted_tree, nltk.Tree) 123 | 124 | temp_dir = tempfile.TemporaryDirectory(prefix="evalb-") 125 | predicted_path = os.path.join(temp_dir.name, "predicted.txt") 126 | 127 | with open(predicted_path, "w") as outfile: 128 | for tree in predicted_trees: 129 | # print(tree) 130 | outfile.write("{}\n".format(tree.pformat(margin=1e100))) 131 | 132 | output_file(predicted_path, output_path) 133 | -------------------------------------------------------------------------------- /src/train_raw2tree.py: -------------------------------------------------------------------------------- 1 | ''' 2 | author: cxy 3 | date: 2022/04/05 4 | function: change the data format from raw to tree and split into train, validate and test. 
5 | input: 猴子#2用#1尾巴#2荡秋千#3。
6 | output: (TOP (S (#3 (#2 (#1 (n 猴)(n 子)))(#2 (#1 (n 用))(#1 (n 尾)(n 巴)))(#2 (#1 (n 荡)(n 秋)(n 千))))(。 。)))
7 | '''
8 | import re
9 | from sklearn.model_selection import train_test_split
10 | 
11 | 
12 | punctuation_list = [',','。','、',';',':','?','!','“','”','‘','’','—','…','(',')','《','》']
13 | 
14 | def data_pre_processing(x):
15 |     x = re.sub('——','—', x)
16 |     x = re.sub('……', '…', x)
17 |     return x
18 | 
19 | 
20 | 
21 | def separate_each_character(x):
22 |     '''
23 |     input: 猴子#2用#1尾巴#2荡秋千#3。
24 |     output: (n 猴)(n 子)2(n 用)1(n 尾)(n 巴)2(n 荡)(n 秋)(n 千)3(。 。)
25 |     '''
26 |     x = re.sub('#','',x)
27 |     x_list = []
28 |     for i in x:
29 |         if i in ['1', '2', '3']:
30 |             x_list.append(i)
31 |         elif i in punctuation_list:
32 |             i = '(' + i + ' ' + i + ')'
33 |             x_list.append(i)
34 |         else:
35 |             i = '(' + 'n' + ' ' + i + ')'
36 |             x_list.append(i)
37 | 
38 |     x = ''.join(x_list)
39 |     return x
40 | 
41 | 
42 | def seq2tree(x):
43 |     '''
44 |     input: (n 猴)(n 子)2(n 用)1(n 尾)(n 巴)2(n 荡)(n 秋)(n 千)3(。 。)
45 |     output: (TOP (S (#3 (#2 (#1 (n 猴)(n 子)))(#2 (#1 (n 用))(#1 (n 尾)(n 巴)))(#2 (#1 (n 荡)(n 秋)(n 千))))(。 。)))
46 |     '''
47 |     iph_list = x.split('3')
48 |     iph_ = []
49 |     for iph in iph_list[:-1]:
50 |         pph_list = iph.split('2')
51 |         pph_ = []
52 |         for pph in pph_list:
53 |             pw_list = pph.split('1')
54 |             pw_ = []
55 |             for pw in pw_list:
56 |                 pw = '(' + '#1' + ' ' + pw + ')'
57 |                 pw_.append(pw)
58 |             pw_ = ''.join(pw_)
59 |             pw_ = '(' + '#2' + ' ' + pw_ + ')'
60 |             pph_.append(pw_)
61 |         pph_ = ''.join(pph_)
62 |         pph_ = '(' + '#3' + ' ' + pph_ + ')'
63 |         iph_.append(pph_)
64 |     iph_.append(iph_list[-1])
65 |     iph_ = ''.join(iph_)
66 |     tree = '(' + 'TOP' + ' ' + '(' + 'S' + ' ' + iph_ + ')' + ')'
67 |     return tree
68 | 
69 | 
70 | def write_data(output_path, line_sen_list):
71 |     '''
72 |     output_path: path of the file to write
73 |     line_sen_list: list of content lines to write
74 |     '''
75 |     with open(output_path, 'w', encoding = 'utf-8') as o:
76 |         o.write('\n'.join(line_sen_list))
77 |     o.close()
78 | 
79 | 
80 | def main():
81 | 
82 |     seq_data_path = './data/train/raw_data/raw_data.txt'
83 |     train_data_path = './data/train/tree_data/tree_train.txt'
84 |     validate_data_path = './data/train/tree_data/tree_validate.txt'
85 |     test_data_path = './data/train/tree_data/tree_test.txt'
86 | 
87 |     ## raw2tree
88 |     line_list = []
89 |     with open(seq_data_path, 'r', encoding='utf-8') as s:
90 |         lines = s.readlines()
91 |         for line in lines:
92 |             line = data_pre_processing(line.strip())
93 |             line = separate_each_character(line)
94 |             line = seq2tree(line)
95 |             line_list.append(line)
96 |     s.close()
97 | 
98 |     ## divide dataset into train, validate, test with 8:1:1
99 |     X_train, X_validate_test, _, y_validate_test = train_test_split(line_list, [0] * len(line_list), test_size = 0.2, random_state = 42)
100 |     X_validate, X_test, _, _ = train_test_split(X_validate_test, y_validate_test, test_size = 0.5, random_state = 42)
101 | 
102 |     write_data(train_data_path, X_train)
103 |     write_data(validate_data_path, X_validate)
104 |     write_data(test_data_path, X_test)
105 | 
106 | 
107 | if __name__ == '__main__':
108 |     main()
109 | 
110 | 
--------------------------------------------------------------------------------
/src/transliterate.py:
--------------------------------------------------------------------------------
1 | BUCKWALTER_MAP = {
2 |     '\'': '\u0621',
3 |     '|': '\u0622',
4 |     '>': '\u0623',
5 |     'O': '\u0623',
6 |     '&': '\u0624',
7 |     'W': '\u0624',
8 |     '<': '\u0625',
9 |     'I': '\u0625',
10 |     '}': '\u0626',
11 |     'A': '\u0627',
12 |     'b': '\u0628',
13 |     'p': '\u0629',
14 |     't': 
'\u062A', 15 | 'v': '\u062B', 16 | 'j': '\u062C', 17 | 'H': '\u062D', 18 | 'x': '\u062E', 19 | 'd': '\u062F', 20 | '*': '\u0630', 21 | 'r': '\u0631', 22 | 'z': '\u0632', 23 | 's': '\u0633', 24 | '$': '\u0634', 25 | 'S': '\u0635', 26 | 'D': '\u0636', 27 | 'T': '\u0637', 28 | 'Z': '\u0638', 29 | 'E': '\u0639', 30 | 'g': '\u063A', 31 | '_': '\u0640', 32 | 'f': '\u0641', 33 | 'q': '\u0642', 34 | 'k': '\u0643', 35 | 'l': '\u0644', 36 | 'm': '\u0645', 37 | 'n': '\u0646', 38 | 'h': '\u0647', 39 | 'w': '\u0648', 40 | 'Y': '\u0649', 41 | 'y': '\u064A', 42 | 'F': '\u064B', 43 | 'N': '\u064C', 44 | 'K': '\u064D', 45 | 'a': '\u064E', 46 | 'u': '\u064F', 47 | 'i': '\u0650', 48 | '~': '\u0651', 49 | 'o': '\u0652', 50 | '`': '\u0670', 51 | '{': '\u0671', 52 | } 53 | 54 | BUCKWALTER_UNESCAPE = { 55 | "-LRB-": "(", 56 | "-RRB-": ")", 57 | "-LCB-": "{", 58 | "-RCB-": "}", 59 | "-LSB-": "[", 60 | "-RSB-": "]", 61 | '-PLUS-': "+", 62 | '-MINUS-': "-", 63 | } 64 | 65 | BUCKWALTER_UNCHANGED = set('.?!,"%-/:;=') 66 | 67 | HEBREW_MAP = { 68 | 'A': '\u05d0', 69 | 'B': '\u05d1', 70 | 'G': '\u05d2', 71 | 'D': '\u05d3', 72 | 'H': '\u05d4', 73 | 'W': '\u05d5', 74 | 'Z': '\u05d6', 75 | 'X': '\u05d7', 76 | 'J': '\u05d8', 77 | 'I': '\u05d9', 78 | 'K': '\u05db', 79 | 'L': '\u05dc', 80 | 'M': '\u05de', 81 | 'N': '\u05e0', 82 | 'S': '\u05e1', 83 | 'E': '\u05e2', 84 | 'P': '\u05e4', 85 | 'C': '\u05e6', 86 | 'Q': '\u05e7', 87 | 'R': '\u05e8', 88 | 'F': '\u05e9', 89 | 'T': '\u05ea', 90 | '0': '0', 91 | '1': '1', 92 | '2': '2', 93 | '3': '3', 94 | '4': '4', 95 | '5': '5', 96 | '6': '6', 97 | '7': '7', 98 | '8': '8', 99 | '9': '9', 100 | 'U': '"', 101 | 'O': '%', 102 | '.': '.', 103 | ',': ',', 104 | } 105 | 106 | HEBREW_SUFFIX_MAP = { 107 | '\u05db': '\u05da', 108 | '\u05de': '\u05dd', 109 | '\u05e0': '\u05df', 110 | '\u05e4': '\u05e3', 111 | '\u05e6': '\u05e5', 112 | } 113 | 114 | HEBREW_UNESCAPE = { 115 | "yyCLN": ":", 116 | "yyCM": ",", 117 | "yyDASH": "-", 118 | "yyDOT": ".", 119 | "yyELPS": "...", 120 | "yyEXCL": "!", 121 | "yyLRB": "(", 122 | "yyQM": "?", 123 | "yyRRB": ")", 124 | "yySCLN": ";", 125 | } 126 | 127 | 128 | 129 | def arabic(inp): 130 | """ 131 | Undo Buckwalter transliteration 132 | 133 | See: http://languagelog.ldc.upenn.edu/myl/ldc/morph/buckwalter.html 134 | 135 | This code inspired by: 136 | https://github.com/dlwh/epic/blob/master/src/main/scala/epic/util/ArabicNormalization.scala 137 | """ 138 | return "".join( 139 | BUCKWALTER_MAP.get(char, char) 140 | for char in BUCKWALTER_UNESCAPE.get(inp, inp)) 141 | 142 | def hebrew(inp): 143 | """ 144 | Undo Hebrew transliteration 145 | 146 | See: http://www.phil.uu.nl/ozsl/articles/simaan02.pdf 147 | 148 | This code inspired by: 149 | https://github.com/habeanf/yap/blob/b57502364b73ef78f3510eb890319ae268eeacca/nlp/parser/xliter8/types.go 150 | """ 151 | out = "".join( 152 | HEBREW_MAP.get(char, char) 153 | for char in HEBREW_UNESCAPE.get(inp, inp)) 154 | if out and (out[-1] in HEBREW_SUFFIX_MAP): 155 | out = out[:-1] + HEBREW_SUFFIX_MAP[out[-1]] 156 | return out 157 | 158 | TRANSLITERATIONS = { 159 | 'arabic': arabic, 160 | 'hebrew': hebrew, 161 | } 162 | -------------------------------------------------------------------------------- /src/treebanks.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import List, Optional, Tuple 3 | 4 | import nltk 5 | from nltk.corpus.reader.bracket_parse import BracketParseCorpusReader 6 | import tokenizations 7 | import torch 8 | 9 | from benepar import 
ptb_unescape 10 | from benepar.parse_base import BaseInputExample 11 | import transliterate 12 | 13 | 14 | @dataclasses.dataclass 15 | class ParsingExample(BaseInputExample): 16 | """A single parse tree and sentence.""" 17 | 18 | words: List[str] 19 | space_after: List[bool] 20 | tree: Optional[nltk.Tree] = None 21 | _pos: Optional[List[Tuple[str, str]]] = None 22 | 23 | def leaves(self): 24 | if self.tree is not None: 25 | return self.tree.leaves() 26 | elif self._pos is not None: 27 | return [word for word, tag in self._pos] 28 | else: 29 | return None 30 | 31 | def pos(self): 32 | if self.tree is not None: 33 | return self.tree.pos() 34 | else: 35 | return self._pos 36 | 37 | def without_gold_annotations(self): 38 | return dataclasses.replace(self, tree=None, _pos=self.pos()) 39 | 40 | 41 | class Treebank(torch.utils.data.Dataset): 42 | def __init__(self, examples): 43 | self.examples = examples 44 | 45 | def __len__(self): 46 | return len(self.examples) 47 | 48 | def __getitem__(self, index): 49 | return self.examples[index] 50 | 51 | @property 52 | def trees(self): 53 | return [x.tree for x in self.examples] 54 | 55 | @property 56 | def sents(self): 57 | return [x.words for x in self.examples] 58 | 59 | @property 60 | def tagged_sents(self): 61 | return [x.pos() for x in self.examples] 62 | 63 | def filter_by_length(self, max_len): 64 | return Treebank([x for x in self.examples if len(x.leaves()) <= max_len]) 65 | 66 | def without_gold_annotations(self): 67 | return Treebank([x.without_gold_annotations() for x in self.examples]) 68 | 69 | 70 | def read_text(text_path): 71 | sents = [] 72 | sent = [] 73 | end_of_multiword = 0 74 | multiword_combined = "" 75 | multiword_separate = [] 76 | multiword_sp_after = False 77 | with open(text_path) as f: 78 | for line in f: 79 | if not line.strip() or line.startswith("#"): 80 | if sent: 81 | sents.append(([w for w, sp in sent], [sp for w, sp in sent])) 82 | sent = [] 83 | assert end_of_multiword == 0 84 | continue 85 | fields = line.split("\t", 2) 86 | num_or_range = fields[0] 87 | w = fields[1] 88 | 89 | if "-" in num_or_range: 90 | end_of_multiword = int(num_or_range.split("-")[1]) 91 | multiword_combined = w 92 | multiword_separate = [] 93 | multiword_sp_after = "SpaceAfter=No" not in fields[-1] 94 | continue 95 | elif int(num_or_range) <= end_of_multiword: 96 | multiword_separate.append(w) 97 | if int(num_or_range) == end_of_multiword: 98 | _, separate_to_combined = tokenizations.get_alignments( 99 | multiword_combined, multiword_separate 100 | ) 101 | have_up_to = 0 102 | for i, char_idxs in enumerate(separate_to_combined): 103 | if i == len(multiword_separate) - 1: 104 | word = multiword_combined[have_up_to:] 105 | sent.append((word, multiword_sp_after)) 106 | elif char_idxs: 107 | word = multiword_combined[have_up_to : max(char_idxs) + 1] 108 | sent.append((word, False)) 109 | have_up_to = max(char_idxs) + 1 110 | else: 111 | sent.append(("", False)) 112 | assert int(num_or_range) == len(sent) 113 | end_of_multiword = 0 114 | multiword_combined = "" 115 | multiword_separate = [] 116 | multiword_sp_after = False 117 | continue 118 | else: 119 | assert int(num_or_range) == len(sent) + 1 120 | sp = "SpaceAfter=No" not in fields[-1] 121 | sent.append((w, sp)) 122 | return sents 123 | 124 | 125 | def load_trees(const_path, text_path=None, text_processing="default"): 126 | """Load a treebank. 127 | 128 | Args: 129 | const_path: Path to the file with one tree per line. 
130 |         text_path: (optional) Path to a file that provides the correct spelling for all
131 |             tokens (without any escaping, transliteration, or other mangling) and
132 |             information about whether there is whitespace after each token. Files in the
133 |             CoNLL-U format (https://universaldependencies.org/format.html) are accepted,
134 |             but the parser also accepts similarly-formatted files with just three fields
135 |             (ID, FORM, MISC) instead of the usual ten. Text is recovered from the FORM
136 |             field and any "SpaceAfter=No" annotations in the MISC field.
137 |         text_processing: Text processing to use if no text_path is specified:
138 |             - 'default': undo PTB-style escape sequences and attempt to guess whitespace
139 |                 surrounding punctuation
140 |             - 'arabic': guess that all tokens are separated by spaces
141 |             - 'arabic-translit': undo Buckwalter transliteration and guess that all
142 |                 tokens are separated by spaces
143 |             - 'chinese': keep all tokens unchanged (i.e. do not attempt to find any
144 |                 escape sequences), and assume no whitespace between tokens
145 |             - 'hebrew': guess that all tokens are separated by spaces
146 |             - 'hebrew-translit': undo transliteration (see Sima'an et al. 2002) and
147 |                 guess that all tokens are separated by spaces
148 | 
149 |     Returns:
150 |         A Treebank (a torch-compatible dataset) of ParsingExample objects; each example has the following attributes:
151 |         - `tree` is an instance of nltk.Tree
152 |         - `words` is a list of strings
153 |         - `space_after` is a list of booleans
154 |     """
155 |     reader = BracketParseCorpusReader("", [const_path])
156 |     trees = reader.parsed_sents()
157 | 
158 |     if text_path is not None:
159 |         sents = read_text(text_path)
160 |     elif text_processing in ("arabic-translit", "hebrew-translit"):
161 |         translit = transliterate.TRANSLITERATIONS[
162 |             text_processing.replace("-translit", "")
163 |         ]
164 |         sents = []
165 |         for tree in trees:
166 |             words = [translit(word) for word in tree.leaves()]
167 |             sp_after = [True for _ in words]
168 |             sents.append((words, sp_after))
169 |     elif text_processing in ("arabic", "hebrew"):
170 |         sents = []
171 |         for tree in trees:
172 |             words = tree.leaves()
173 |             sp_after = [True for _ in words]
174 |             sents.append((words, sp_after))
175 |     elif text_processing == "chinese":
176 |         sents = []
177 |         for tree in trees:
178 |             words = tree.leaves()
179 |             sp_after = [False for _ in words]
180 |             sents.append((words, sp_after))
181 |     elif text_processing == "default":
182 |         sents = []
183 |         for tree in trees:
184 |             words = ptb_unescape.ptb_unescape(tree.leaves())
185 |             sp_after = ptb_unescape.guess_space_after(tree.leaves())
186 |             sents.append((words, sp_after))
187 |     else:
188 |         raise ValueError(f"Bad value for text_processing: {text_processing}")
189 | 
190 |     assert len(trees) == len(sents)
191 |     treebank = Treebank(
192 |         [
193 |             ParsingExample(tree=tree, words=words, space_after=space_after)
194 |             for tree, (words, space_after) in zip(trees, sents)
195 |         ]
196 |     )
197 |     for example in treebank:
198 |         assert len(example.words) == len(example.leaves()), (
199 |             "Constituency tree has a different number of tokens than the CONLL-U or "
200 |             "other file used to specify reversible tokenization."
201 |         )
202 |     return treebank
203 | 
--------------------------------------------------------------------------------
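
A minimal usage sketch (not a file in the repository) showing how `treebanks.load_trees` and `parse_chart.ChartParser` fit together outside of `main.py`; it mirrors what `run_test()` does. `models/my_model.pt` is a placeholder for a checkpoint produced by `python src/main.py train ...`, and the script is assumed to run from the repository root:
```python
# Minimal sketch, run from the repository root.
import sys

sys.path.insert(0, "src")  # so src/treebanks.py and src/benepar resolve

import torch

import treebanks
from benepar import parse_chart

# "chinese" keeps tokens unchanged and assumes no whitespace between them,
# matching the --text-processing default used for SpanPSP data.
treebank = treebanks.load_trees(
    "data/train/tree_data/tree_test.txt", text_processing="chinese"
)

# Placeholder path: substitute your own trained checkpoint.
parser = parse_chart.ChartParser.from_trained("models/my_model.pt")
if torch.cuda.is_available():
    parser.cuda()

# Strip gold annotations before parsing, exactly as run_test() does.
predicted = parser.parse(
    treebank.without_gold_annotations(), subbatch_max_tokens=500
)
for tree in predicted:
    print(tree.pformat(margin=1e100))
```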