├── README.md
├── const.py
├── data
│   ├── .gitignore
│   ├── bht
│   │   ├── Chinese.bht.dev
│   │   ├── Chinese.bht.test
│   │   ├── Chinese.bht.train
│   │   ├── Chinese.lex.dev
│   │   ├── Chinese.lex.test
│   │   ├── Chinese.lex.train
│   │   ├── English.bht.dev
│   │   ├── English.bht.test
│   │   ├── English.bht.train
│   │   ├── English.lex.dev
│   │   ├── English.lex.test
│   │   └── English.lex.train
│   ├── dep2bht.py
│   ├── pos
│   │   ├── pos.chinese.json
│   │   └── pos.english.json
│   └── vocab
│       └── .gitignore
├── header-hexa.png
├── header.jpg
├── learning
│   ├── __init__.py
│   ├── crf.py
│   ├── dataset.py
│   ├── decode.py
│   ├── evaluate.py
│   ├── learn.py
│   └── util.py
├── load_dataset.sh
├── pat-models
│   └── .gitignore
├── requirements.txt
├── run.py
└── tagging
    ├── __init__.py
    ├── hexatagger.py
    ├── srtagger.py
    ├── tagger.py
    ├── test.py
    ├── tetratagger.py
    ├── transform.py
    └── tree_tools.py
/README.md:
--------------------------------------------------------------------------------
1 | # Parsing as Tagging
2 |
3 |
4 |
5 |
6 | This repository contains the training and evaluation code for two papers:
7 |
8 | - On Parsing as Tagging
9 | - Hexatagging: Projective Dependency Parsing as Tagging
10 |
11 | ## Setting Up The Environment
12 | Set up a virtual environment and install the dependencies:
13 | ```bash
14 | pip install -r requirements.txt
15 | ```
16 |
17 | ## Getting The Data
18 | ### Constituency Parsing
19 | Follow the instructions in this [repo](https://github.com/nikitakit/self-attentive-parser/tree/master/data) to do the initial preprocessing of the English WSJ and SPMRL datasets. The default data path is the `data/spmrl` folder, where each file is named in the `[LANGUAGE].[train/dev/test]` format.
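For example, with the default `DATA_PATH` in `const.py`, the English splits are expected at paths like the following:
```
data/spmrl/English.train
data/spmrl/English.dev
data/spmrl/English.test
```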
20 | ### Dependency Parsing with Hexatagger
21 | 1. Convert CoNLL to Binary Headed Trees:
22 | ```bash
23 | python data/dep2bht.py
24 | ```
25 | This will generate the phrase-structured BHT trees in the `data/bht` directory.
26 | The processed files are already provided under the `data/bht` directory.
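The script can also be pointed at a single CoNLL file of your own (the path below is a placeholder); in that case the converted trees are written to `data/bht/input.bht.test`:
```bash
python data/dep2bht.py path/to/your_treebank.conll
```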
27 |
28 | ## Building The Tagging Vocab
29 | In order to use the taggers, we first need to build the tag vocabulary for the in-order, pre-order, and post-order linearizations. You can cache these vocabularies using:
30 | ```bash
31 | python run.py vocab --lang [LANGUAGE] --tagger [TAGGER]
32 | ```
33 | The tagger can be `td-sr` for the top-down (pre-order) shift-reduce linearization, `bu-sr` for the bottom-up (post-order) shift-reduce linearization, `tetra` for the in-order linearization, and `hexa` for the hexatagging linearization.
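For example, to cache the hexatagging vocabulary for English:
```bash
python run.py vocab --lang English --tagger hexa
```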
34 |
35 | ## Training
36 | Train the model and store the best checkpoint.
37 | ```bash
38 | python run.py train --batch-size [BATCH_SIZE] --tagger [TAGGER] --lang [LANGUAGE] --model [MODEL] --epochs [EPOCHS] --lr [LR] --model-path [MODEL_PATH] --output-path [PATH] --max-depth [DEPTH] --keep-per-depth [KPD] [--use-tensorboard]
39 | ```
40 | - batch size: use 32 to reproduce the results
41 | - tagger: `td-sr`, `bu-sr`, `tetra`, or `hexa`
42 | - lang: language, one of the nine languages reported in the paper
43 | - model: `bert`, `bert+crf`, `bert+lstm`
44 | - model path: path where the pretrained model is saved
45 | - output path: path to save the best trained model
46 | - max depth: maximum depth to keep in the decoding lattice
47 | - keep per depth: number of elements to keep track of in the decoding step
48 | - use-tensorboard: whether to store the logs in TensorBoard (true or false)
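For example, a constituency-parsing run with the in-order (`tetra`) linearization on English could look like the following; the hyperparameters and model paths here are illustrative rather than the exact settings from the paper (the reproduced Hexatagger commands are listed further below):
```bash
python run.py train --batch-size 32 --tagger tetra --lang English --model bert --epochs 50 --lr 2e-5 --model-path bert-base-cased --output-path ./checkpoints/ --max-depth 6 --keep-per-depth 1 --use-tensorboard True
```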
49 |
50 | ## Evaluation
51 | Calculate the evaluation metrics: F-score, precision, recall, and loss.
52 | ```bash
53 | python run.py evaluate --lang [LANGUAGE] --tagger [TAGGER] --model-name [MODEL] --model-path [MODEL_PATH] --bert-model-path [BERT_PATH] --max-depth [DEPTH] --keep-per-depth [KPD] [--is-greedy]
54 | ```
55 | - lang: language, one of the nine languages reported in the paper
56 | - model name: name of the checkpoint
57 | - model path: path of the checkpoint
58 | - bert model path: path to the pretrained model
59 | - max depth: maximum depth to keep in the decoding lattice
60 | - keep per depth: number of elements to keep track of in the decoding step
61 | - is greedy: whether or not to use greedy decoding; defaults to false
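For example, an evaluation call for a `tetra` checkpoint could look like the following; the checkpoint name and paths are again illustrative (the exact Hexatagger evaluation commands are listed below):
```bash
python run.py evaluate --lang English --tagger tetra --model-name English-tetra-bert-2e-05-50 --model-path ./checkpoints/ --bert-model-path bert-base-cased --max-depth 10 --keep-per-depth 1
```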
62 |
63 | # Exact Commands for Hexatagging
64 | The above commands can be used with different taggers, models, and languages. To reproduce our Hexatagging results, the exact commands used for training and evaluating the Hexatagger are listed below.
65 | ## Train
66 | ### PTB (English)
67 | ```bash
68 | CUDA_VISIBLE_DEVICES=0 python run.py train --lang English --max-depth 6 --tagger hexa --model bert --epochs 50 --batch-size 32 --lr 2e-5 --model-path xlnet-large-cased --output-path ./checkpoints/ --use-tensorboard True
69 | # model saved at ./checkpoints/English-hexa-bert-2e-05-50
70 | ```
71 | ### CTB (Chinese)
72 | ```bash
73 | CUDA_VISIBLE_DEVICES=0 python run.py train --lang Chinese --max-depth 6 --tagger hexa --model bert --epochs 50 --batch-size 32 --lr 2e-5 --model-path hfl/chinese-xlnet-mid --output-path ./checkpoints/ --use-tensorboard True
74 | # model saved at ./checkpoints/Chinese-hexa-bert-2e-05-50
75 | ```
76 |
77 | ### UD
78 | ```bash
79 | CUDA_VISIBLE_DEVICES=0 python run.py train --lang bg --max-depth 6 --tagger hexa --model bert --epochs 50 --batch-size 32 --lr 2e-5 --model-path bert-base-multilingual-cased --output-path ./checkpoints/ --use-tensorboard True
80 | ```
81 |
82 | ## Evaluate
83 | ### PTB
84 | ```bash
85 | python run.py evaluate --lang English --max-depth 10 --tagger hexa --bert-model-path xlnet-large-cased --model-name English-hexa-bert-2e-05-50 --batch-size 64 --model-path ./checkpoints/
86 | ```
87 |
88 | ### CTB
89 | ```bash
90 | python run.py evaluate --lang Chinese --max-depth 10 --tagger hexa --bert-model-path hfl/chinese-xlnet-mid --model-name Chinese-hexa-bert-2e-05-50 --batch-size 64 --model-path ./checkpoints/
91 | ```
92 | ### UD
93 | ```bash
94 | python run.py evaluate --lang bg --max-depth 10 --tagger hexa --bert-model-path bert-base-multilingual-cased --model-name bg-hexa-bert-2e-05-50 --batch-size 64 --model-path ./checkpoints/
95 | ```
96 |
97 |
98 | ## Predict
99 | ### PTB
100 | ```bash
101 | python run.py predict --lang English --max-depth 10 --tagger hexa --bert-model-path xlnet-large-cased --model-name English-hexa-bert-2e-05-50 --batch-size 64 --model-path ./checkpoints/
102 | ```
103 |
104 | # Citation
105 | If you find this repository useful, please cite our papers:
106 | ```bibtex
107 | @inproceedings{amini-cotterell-2022-parsing,
108 | title = "On Parsing as Tagging",
109 | author = "Amini, Afra and
110 | Cotterell, Ryan",
111 | booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing",
112 | month = dec,
113 | year = "2022",
114 | address = "Abu Dhabi, United Arab Emirates",
115 | publisher = "Association for Computational Linguistics",
116 | url = "https://aclanthology.org/2022.emnlp-main.607",
117 | pages = "8884--8900",
118 | }
119 | ```
120 |
121 | ```bibtex
122 | @inproceedings{amini-etal-2023-hexatagging,
123 | title = "Hexatagging: Projective Dependency Parsing as Tagging",
124 | author = "Amini, Afra and
125 | Liu, Tianyu and
126 | Cotterell, Ryan",
127 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)",
128 | month = jul,
129 | year = "2023",
130 | address = "Toronto, Canada",
131 | publisher = "Association for Computational Linguistics",
132 | url = "https://aclanthology.org/2023.acl-short.124",
133 | pages = "1453--1464",
134 | }
135 | ```
136 |
137 |
--------------------------------------------------------------------------------
/const.py:
--------------------------------------------------------------------------------
1 | DUMMY_LABEL = "Y|X"
2 |
3 | DATA_PATH = "data/spmrl/"
4 | DEP_DATA_PATH = "data/bht/"
5 |
6 |
7 | TETRATAGGER = "tetra"
8 | HEXATAGGER = "hexa"
9 |
10 | TD_SR = "td-sr"
11 | BU_SR = "bu-sr"
12 | BERT = ["bert", "roberta", "robertaL"]
13 | BERTCRF = ["bert+crf", "roberta+crf", "robertaL+crf"]
14 | BERTLSTM = ["bert+lstm", "roberta+lstm", "robertaL+lstm"]
15 |
16 | BAQ = "Basque"
17 | CHN = "Chinese"
18 | CHN09 = "Chinese-conll09"
19 | FRE = "French"
20 | GER = "German"
21 | HEB = "Hebrew"
22 | HUN = "Hungarian"
23 | KOR = "Korean"
24 | POL = "Polish"
25 | SWE = "swedish"
26 | ENG = "English"
27 | LANG = [BAQ, CHN, CHN09, FRE, GER, HEB, HUN, KOR, POL, SWE, ENG,
28 | "bg","ca","cs","de","en","es", "fr","it","nl","no","ro","ru"]
29 |
30 | UD_LANG_TO_DIR = {
31 | "bg": "/UD_Bulgarian-BTB/bg_btb-ud-{split}.conllu",
32 | "ca": "/UD_Catalan-AnCora/ca_ancora-ud-{split}.conllu",
33 | "cs": "/UD_Czech-PDT/cs_pdt-ud-{split}.conllu",
34 | "de": "/UD_German-GSD/de_gsd-ud-{split}.conllu",
35 | "en": "/UD_English-EWT/en_ewt-ud-{split}.conllu",
36 | "es": "/UD_Spanish-AnCora/es_ancora-ud-{split}.conllu",
37 | "fr": "/UD_French-GSD/fr_gsd-ud-{split}.conllu",
38 | "it": "/UD_Italian-ISDT/it_isdt-ud-{split}.conllu",
39 | "nl": "/UD_Dutch-Alpino/nl_alpino-ud-{split}.conllu",
40 | "no": "/UD_Norwegian-Bokmaal/no_bokmaal-ud-{split}.conllu",
41 | "ro": "/UD_Romanian-RRT/ro_rrt-ud-{split}.conllu",
42 | "ru": "/UD_Russian-SynTagRus/ru_syntagrus-ud-{split}.conllu",
43 | }
44 |
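# Illustration (not part of the original module): each UD template above is filled in
# with a split name and appended to wherever the UD 2.2 treebanks are stored, e.g.
#   UD_LANG_TO_DIR["bg"].format(split="train")
#   -> "/UD_Bulgarian-BTB/bg_btb-ud-train.conllu"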
--------------------------------------------------------------------------------
/data/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 | !bht/*
6 | !pos/*
7 | !vocab/*
8 |
9 |
--------------------------------------------------------------------------------
/data/dep2bht.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re  # used below to match UD treebank file names
3 | import sys
4 | import tempfile
5 | from nltk.corpus.reader import DependencyCorpusReader
6 | from nltk.corpus.reader.util import *
7 | from nltk.tree import Tree
8 | from tqdm import tqdm as tq
9 | LANG_TO_DIR = {
10 | "bg": "/UD_Bulgarian-BTB/bg_btb-ud-{split}.conllu",
11 | "ca": "/UD_Catalan-AnCora/ca_ancora-ud-{split}.conllu",
12 | "cs": "/UD_Czech-PDT/cs_pdt-ud-{split}.conllu",
13 | "de": "/UD_German-GSD/de_gsd-ud-{split}.conllu",
14 | "en": "/UD_English-EWT/en_ewt-ud-{split}.conllu",
15 | "es": "/UD_Spanish-AnCora/es_ancora-ud-{split}.conllu",
16 | "fr": "/UD_French-GSD/fr_gsd-ud-{split}.conllu",
17 | "it": "/UD_Italian-ISDT/it_isdt-ud-{split}.conllu",
18 | "nl": "/UD_Dutch-Alpino/nl_alpino-ud-{split}.conllu",
19 | "no": "/UD_Norwegian-Bokmaal/no_bokmaal-ud-{split}.conllu",
20 | "ro": "/UD_Romanian-RRT/ro_rrt-ud-{split}.conllu",
21 | "ru": "/UD_Russian-SynTagRus/ru_syntagrus-ud-{split}.conllu",
22 | }
23 |
24 |
25 | # using ^^^ as delimiter
26 | # since ^ never appears
27 |
28 | def augment_constituent_tree(const_tree, dep_tree):
29 | # augment constituent tree leaves into dicts
30 |
31 | assert len(const_tree.leaves()) == len(dep_tree.nodes) - 1
32 |
33 | leaf_nodes = list(const_tree.treepositions('leaves'))
34 | for i, pos in enumerate(leaf_nodes):
35 | x = dep_tree.nodes[1 + i]
36 | y = const_tree[pos].replace("\\", "")
37 | assert (x['word'] == y), (const_tree, dep_tree)
38 |
39 | # expanding leaves with dependency info
40 | const_tree[pos] = {
41 | "word": dep_tree.nodes[1 + i]["word"],
42 | "head": dep_tree.nodes[1 + i]["head"],
43 | "rel": dep_tree.nodes[1 + i]["rel"]
44 | }
45 |
46 | return const_tree
47 |
48 |
49 | def get_bht(root, offset=0):
50 | # Return:
51 | # The index of the head in this tree
52 | # and:
53 | # The dependant that this tree points to
54 | if type(root) is dict:
55 | # leaf node already, return its head
56 | return 0, root['head']
57 |
58 | # word offset in the current tree
59 | words_seen = 0
60 | root_projection = (offset, offset + len(root.leaves()))
61 | # init the return values to be None
62 | head, root_points_to = None, None
63 |
64 | # traverse the consituent tree
65 | for idx, child in enumerate(root):
66 | head_of_child, child_points_to = get_bht(child, offset + words_seen)
67 | if type(child) == type(root):
68 | words_seen += len(child.leaves())
69 | else:
70 | # leaf node visited
71 | words_seen += 1
72 |
73 | if child_points_to < root_projection[0] or child_points_to >= root_projection[1]:
74 | # pointing to outside of the current tree
75 | if head is not None:
76 | print("error! Non-projectivity detected.", root_projection, idx)
77 | continue # choose the first child as head
78 | head = idx
79 | root_points_to = child_points_to
80 |
81 | if root_points_to is None:
82 | # self contained sub-sentence
83 | print("multiple roots detected", root)
84 | root_points_to = 0
85 |
86 | original_label = root.label()
87 | root.set_label(f"{original_label}^^^{head}")
88 |
89 | return head, root_points_to
90 |
91 |
92 | def dep2lex(dep_tree, language="English"):
93 | # left-first attachment
94 | def dfs(node_idx):
95 | dependencies = []
96 | for rel in dep_tree.nodes[node_idx]["deps"]:
97 | for dependent in dep_tree.nodes[node_idx]["deps"][rel]:
98 | dependencies.append((dependent, rel))
99 |
100 | dependencies.append((node_idx, "SELF"))
101 | if len(dependencies) == 1:
102 | # no dependent at all, leaf node
103 | return Tree(
104 | f"X^^^{dep_tree.nodes[node_idx]['rel']}",
105 | [
106 | Tree(
107 | f"{dep_tree.nodes[node_idx]['ctag']}",
108 | [
109 | f"{dep_tree.nodes[node_idx]['word']}"
110 | ]
111 | )
112 | ]
113 | )
114 | # Now, len(dependencies) >= 2, sort dependents
115 | dependencies = sorted(dependencies)
116 |
117 | lex_tree_root = Tree(f"X^^^{0}", [])
118 | empty_slot = lex_tree_root # place to fill in the next subtree
119 | for idx, dependency in enumerate(dependencies):
120 | if dependency[0] < node_idx:
121 | # waiting for a head in the right child
122 | lex_tree_root.set_label(f"X^^^{1}")
123 | if len(lex_tree_root) == 0:
124 | # the first non-head child
125 | lex_tree_root.insert(0, dfs(dependency[0]))
126 | else:
127 | # not the first non-head child
128 | # insert a sub tree: \
129 | # X^^^1
130 | # / \
131 | # word [empty_slot]
132 | empty_slot.insert(1, Tree(f"X^^^{1}", [dfs(dependency[0])]))
133 | empty_slot = empty_slot[1]
134 | elif dependency[0] == node_idx:
135 | tree_piece = Tree(
136 | f"X^^^{dep_tree.nodes[dependency[0]]['rel']}",
137 | [
138 | Tree(
139 | f"{dep_tree.nodes[dependency[0]]['ctag']}",
140 | [
141 | f"{dep_tree.nodes[dependency[0]]['word']}"
142 | ]
143 | )
144 | ]
145 | )
146 | if len(empty_slot) == 1:
147 | # This is the head
148 | empty_slot.insert(1, tree_piece)
149 | else: # len(empty_slot) == 0
150 | lex_tree_root = tree_piece
151 | pass
152 | else:
153 | # moving on to the right of the head
154 | lex_tree_root = Tree(f"X^^^{0}", [lex_tree_root, dfs(dependency[0])])
155 | return lex_tree_root
156 |
157 | return dfs(
158 | dep_tree.nodes[0]["deps"]["root"][0] if "root" in dep_tree.nodes[0]["deps"] else
159 | dep_tree.nodes[0]["deps"]["ROOT"][0]
160 | )
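# Illustration (assumed example, not in the original file): for a two-word sentence in
# which word 2 heads word 1 with relation "det", dep2lex produces roughly
#   (X^^^1 (X^^^det (DT word1)) (X^^^root (NN word2)))
# i.e. each internal label X^^^<i> records which child (0 = left, 1 = right) contains the
# lexical head, and each leaf-level label X^^^<rel> (above the POS tag) records the relation.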
161 |
162 |
163 | def dep2lex_right_first(dep_tree, language="English"):
164 | # right-first attachment
165 |
166 | def dfs(node_idx):
167 | dependencies = []
168 | for rel in dep_tree.nodes[node_idx]["deps"]:
169 | for dependent in dep_tree.nodes[node_idx]["deps"][rel]:
170 | dependencies.append((dependent, rel))
171 |
172 | dependencies.append((node_idx, "SELF"))
173 | if len(dependencies) == 1:
174 | # no dependent at all, leaf node
175 | return Tree(
176 | f"X^^^{dep_tree.nodes[node_idx]['rel']}",
177 | [
178 | Tree(
179 | f"{dep_tree.nodes[node_idx]['ctag']}",
180 | [
181 | f"{dep_tree.nodes[node_idx]['word']}"
182 | ]
183 | )
184 | ]
185 | )
186 | # Now, len(dependencies) >= 2, sort dependents
187 | dependencies = sorted(dependencies)
188 |
189 | lex_tree_root = Tree(f"X^^^{0}", [])
190 | empty_slot = lex_tree_root # place to fill in the next subtree
191 | for idx, dependency in enumerate(dependencies):
192 | if dependency[0] < node_idx:
193 | # waiting for a head in the right child
194 | lex_tree_root.set_label(f"X^^^{1}")
195 | if len(lex_tree_root) == 0:
196 | # the first non-head child
197 | lex_tree_root.insert(
198 | 0, dfs(dependency[0])
199 | )
200 | else:
201 | # not the first non-head child
202 | # insert a sub tree: \
203 | # X^^^1
204 | # / \
205 | # word [empty_slot]
206 | empty_slot.insert(
207 | 1,
208 | Tree(f"X^^^{1}", [
209 | dfs(dependency[0])
210 | ])
211 | )
212 | empty_slot = empty_slot[1]
213 | elif dependency[0] == node_idx:
214 | # This is the head
215 | right_branch_root = Tree(
216 | f"X^^^{dep_tree.nodes[dependency[0]]['rel']}",
217 | [
218 | Tree(
219 | f"{dep_tree.nodes[dependency[0]]['ctag']}",
220 | [
221 | f"{dep_tree.nodes[dependency[0]]['word']}"
222 | ]
223 | )
224 | ]
225 | )
226 | else:
227 | # moving on to the right of the head
228 | right_branch_root = Tree(
229 | f"X^^^{0}",
230 | [
231 | right_branch_root,
232 | dfs(dependency[0])
233 | ]
234 | )
235 |
236 | if len(empty_slot) == 0:
237 | lex_tree_root = right_branch_root
238 | else:
239 | empty_slot.insert(1, right_branch_root)
240 |
241 | return lex_tree_root
242 |
243 | return dfs(
244 | dep_tree.nodes[0]["deps"]["root"][0] if "root" in dep_tree.nodes[0]["deps"] else
245 | dep_tree.nodes[0]["deps"]["ROOT"][0]
246 | )
247 |
248 |
249 | if __name__ == "__main__":
250 | repo_directory = os.path.abspath(__file__)
251 |
252 | if len(sys.argv) > 1:
253 | path = sys.argv[1]
254 | reader = DependencyCorpusReader(
255 | os.path.dirname(repo_directory),
256 | [path],
257 | )
258 | dep_trees = reader.parsed_sents(path)
259 |
260 | bhts = []
261 | for dep_tree in tq(dep_trees):
262 | lex_tree = dep2lex(dep_tree, language="input")
263 | bhts.append(lex_tree)
264 | lex_tree_leaves = tuple(lex_tree.leaves())
265 | dep_tree_leaves = tuple(
266 | [str(node["word"]) for _, node in sorted(dep_tree.nodes.items())])
267 |
268 | dep_tree_leaves = dep_tree_leaves[1:]
269 |
270 | print(f"Writing BHTs to {os.path.dirname(repo_directory)}/input.bht.test")
271 | with open(os.path.dirname(repo_directory) + f"/bht/input.bht.test",
272 | "w") as fout:
273 | for lex_tree in bhts:
274 | fout.write(lex_tree._pformat_flat("", "()", False) + "\n")
275 |
276 | exit(0)
277 |
278 | for language in [
279 | "English", # PTB
280 | "Chinese", # CTB
281 | "bg","ca","cs","de","en","es","fr","it","nl","no","ro","ru" # UD2.2
282 | ]:
283 | print(f"Processing {language}...")
284 | if language == "English":
285 | path = os.path.dirname(repo_directory) + "/ptb/ptb_{split}_3.3.0.sd.clean"
286 | paths = [path.format(split=split) for split in ["train", "dev", "test"]]
287 | elif language == "Chinese":
288 | path = os.path.dirname(repo_directory) + "/ctb/{split}.ctb.conll"
289 | paths = [path.format(split=split) for split in ["train", "dev", "test"]]
290 | elif language in ["bg", "ca","cs","de","en","es","fr","it","nl","no","ro","ru"]:
291 | path = os.path.dirname(repo_directory)+f"/ctb_ptb_ud22/ud2.2/{LANG_TO_DIR[language]}"
292 | paths = []
293 | groups = re.match(r'(\w+)_\w+-ud-(\w+)\.conllu', os.path.split(path)[-1])
294 |
295 | for split in ["train", "dev", "test"]:
296 | conll_path = path.format(split=split)
297 | command = f"cd ../malt/maltparser-1.9.2/; java -jar maltparser-1.9.2.jar -c {language}_{split} -m proj" \
298 | f" -i {conll_path} -o {conll_path}.proj -pp head"
299 | os.system(command)
300 | paths.append(conll_path + ".proj")
301 |
302 | reader = DependencyCorpusReader(
303 | os.path.dirname(repo_directory),
304 | paths,
305 | )
306 |
307 | for path, split in zip(paths, ["train", "dev", "test"]):
308 | print(f"Converting {path} to lexicalized tree")
309 |
310 | with tempfile.NamedTemporaryFile(mode="w") as tmp_file:
311 | with open(path, "r", encoding='utf-8-sig') as fin:
312 | lines = [line for line in fin.readlines() if not line.startswith("#")]
313 | for idl, line in enumerate(lines):
314 | if line.strip() == "":
315 | continue
316 | assert len(line.split("\t")) == 10, line
317 | if '\xad' in line:
318 | lines[idl] = lines[idl].replace('\xad', '')
319 | if '\ufeff' in line:
320 | lines[idl] = lines[idl].replace('\ufeff', '')
321 | if " " in line:
322 | lines[idl] = lines[idl].replace(" ", ",")
323 | if ")" in line:
324 | lines[idl] = lines[idl].replace(")", "-RRB-")
325 | if "(" in line:
326 | lines[idl] = lines[idl].replace("(", "-LRB-")
327 |
328 | tmp_file.writelines(lines)
329 | tmp_file.flush()
330 | tmp_file.seek(0)
331 |
332 | dep_trees = reader.parsed_sents(tmp_file.name)
333 |
334 | bhts = []
335 | for dep_tree in tq(dep_trees):
336 | lex_tree = dep2lex(dep_tree, language=language)
337 | bhts.append(lex_tree)
338 | lex_tree_leaves = tuple(lex_tree.leaves())
339 | dep_tree_leaves = tuple(
340 | [str(node["word"]) for _, node in sorted(dep_tree.nodes.items())])
341 |
342 | dep_tree_leaves = dep_tree_leaves[1:]
343 |
344 | print(f"Writing BHTs to {os.path.dirname(repo_directory)}/{language}.bht.{split}")
345 | with open(os.path.dirname(repo_directory) + f"/bht/{language}.bht.{split}",
346 | "w") as fout:
347 | for lex_tree in bhts:
348 | fout.write(lex_tree._pformat_flat("", "()", False) + "\n")
349 |
--------------------------------------------------------------------------------
/data/pos/pos.chinese.json:
--------------------------------------------------------------------------------
1 | {"TOP": 1, "NN": 2, "NR": 3, "VV": 4, "PU": 5, "LC": 6, "JJ": 7, "AD": 8, "CD": 9, "VA": 10, "OD": 11, "DEG": 12, "VE": 13, "NT": 14, "PN": 15, "DEC": 16, "CC": 17, "DT": 18, "P": 19, "M": 20, "VC": 21, "BA": 22, "SP": 23, "AS": 24, "SB": 25, "FW": 26, "CS": 27, "X": 28, "ETC": 29, "DER": 30, "DEV": 31, "LB": 32, "MSP": 33, "IJ": 34}
--------------------------------------------------------------------------------
/data/pos/pos.english.json:
--------------------------------------------------------------------------------
1 | {"IN": 1, "DT": 2, "NNP": 3, "CD": 4, "NN": 5, "``": 6, "''": 7, "POS": 8, "-LRB-": 9, "JJ": 10, "NNS": 11, "VBP": 12, ",": 13, "CC": 14, "-RRB-": 15, "VBN": 16, "VBD": 17, "RB": 18, "TO": 19, ".": 20, "VBZ": 21, "NNPS": 22, "PRP": 23, "PRP$": 24, "VB": 25, "MD": 26, "VBG": 27, "RBR": 28, ":": 29, "WP": 30, "WDT": 31, "JJR": 32, "PDT": 33, "RBS": 34, "JJS": 35, "WRB": 36, "$": 37, "RP": 38, "FW": 39, "EX": 40, "#": 41, "WP$": 42, "UH": 43, "SYM": 44, "LS": 45}
--------------------------------------------------------------------------------
/data/vocab/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 |
6 |
--------------------------------------------------------------------------------
/header-hexa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rycolab/parsing-as-tagging/e7a0be2d92dc44c8c5920a33db10dcfab8573dc1/header-hexa.png
--------------------------------------------------------------------------------
/header.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rycolab/parsing-as-tagging/e7a0be2d92dc44c8c5920a33db10dcfab8573dc1/header.jpg
--------------------------------------------------------------------------------
/learning/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rycolab/parsing-as-tagging/e7a0be2d92dc44c8c5920a33db10dcfab8573dc1/learning/__init__.py
--------------------------------------------------------------------------------
/learning/crf.py:
--------------------------------------------------------------------------------
1 | # Code is inspired by https://github.com/mtreviso/linear-chain-crf
2 | import torch
3 | from torch import nn
4 |
5 |
6 | class CRF(nn.Module):
7 | """
8 | Linear-chain Conditional Random Field (CRF).
9 | Args:
10 | nb_labels (int): number of labels in your tagset, including special symbols.
11 | device: cpu or gpu,
12 | batch_first (bool): Whether the first dimension represents the batch dimension.
13 | """
14 |
15 | def __init__(
16 | self, nb_labels, device=None, batch_first=True
17 | ):
18 | super(CRF, self).__init__()
19 |
20 | self.nb_labels = nb_labels
21 | self.batch_first = batch_first
22 | self.device = device
23 |
24 | self.transitions = nn.Parameter(torch.empty(self.nb_labels, self.nb_labels))
25 | self.start_transitions = nn.Parameter(torch.empty(self.nb_labels))
26 | self.end_transitions = nn.Parameter(torch.empty(self.nb_labels))
27 |
28 | self.init_weights()
29 |
30 | def init_weights(self):
31 | nn.init.uniform_(self.start_transitions, -0.1, 0.1)
32 | nn.init.uniform_(self.end_transitions, -0.1, 0.1)
33 | nn.init.uniform_(self.transitions, -0.1, 0.1)
34 |
35 | def forward(self, emissions, tags, mask=None):
36 | """Compute the negative log-likelihood. See `log_likelihood` method."""
37 | nll = -self.log_likelihood(emissions, tags, mask=mask)
38 | return nll
39 |
40 | def log_likelihood(self, emissions, tags, mask=None):
41 | """Compute the probability of a sequence of tags given a sequence of
42 | emissions scores.
43 | Args:
44 | emissions (torch.Tensor): Sequence of emissions for each label.
45 | Shape of (batch_size, seq_len, nb_labels) if batch_first is True,
46 | (seq_len, batch_size, nb_labels) otherwise.
47 | tags (torch.LongTensor): Sequence of labels.
48 | Shape of (batch_size, seq_len) if batch_first is True,
49 | (seq_len, batch_size) otherwise.
50 | mask (torch.FloatTensor, optional): Tensor representing valid positions.
51 | If None, all positions are considered valid.
52 | Shape of (batch_size, seq_len) if batch_first is True,
53 | (seq_len, batch_size) otherwise.
54 | Returns:
55 | torch.Tensor: the log-likelihood summed over all sequences in the batch
56 | (a scalar tensor).
57 | """
58 |
59 | # fix tensors order by setting batch as the first dimension
60 | if not self.batch_first:
61 | emissions = emissions.transpose(0, 1)
62 | tags = tags.transpose(0, 1)
63 |
64 | if mask is None:
65 | mask = torch.ones(emissions.shape[:2], dtype=torch.float).to(self.device)
66 |
67 | scores = self._compute_scores(emissions, tags, mask=mask)
68 | partition = self._compute_log_partition(emissions, mask=mask)
69 | return torch.sum(scores - partition)
70 |
71 | def _compute_scores(self, emissions, tags, mask, last_idx=None):
72 | """Compute the scores for a given batch of emissions with their tags.
73 | Args:
74 | emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
75 | tags (Torch.LongTensor): (batch_size, seq_len)
76 | mask (Torch.FloatTensor): (batch_size, seq_len)
77 | Returns:
78 | torch.Tensor: Scores for each batch.
79 | Shape of (batch_size,)
80 | """
81 | batch_size, seq_length = tags.shape
82 | scores = torch.zeros(batch_size).to(self.device)
83 |
84 | alpha_mask = torch.zeros((batch_size,), dtype=int).to(self.device)
85 | previous_tags = torch.zeros((batch_size,), dtype=int).to(self.device)
86 |
87 | for i in range(0, seq_length):
88 | is_valid = mask[:, i]
89 |
90 | current_tags = tags[:, i]
91 |
92 | e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze()
93 |
94 | first_t_scores = self.start_transitions[current_tags]
95 | t_scores = self.transitions[previous_tags, current_tags]
96 | t_scores = (1 - alpha_mask) * first_t_scores + alpha_mask * t_scores
97 | alpha_mask = is_valid + (1 - is_valid) * alpha_mask
98 |
99 | e_scores = e_scores * is_valid
100 | t_scores = t_scores * is_valid
101 |
102 | scores += e_scores + t_scores
103 |
104 | previous_tags = current_tags * is_valid + previous_tags * (1 - is_valid)
105 |
106 | scores += self.end_transitions[previous_tags]
107 |
108 | return scores
109 |
110 | def _compute_log_partition(self, emissions, mask):
111 | """Compute the partition function in log-space using the forward-algorithm.
112 | Args:
113 | emissions (torch.Tensor): (batch_size, seq_len, nb_labels)
114 | mask (Torch.FloatTensor): (batch_size, seq_len)
115 | Returns:
116 | torch.Tensor: the partition scores for each batch.
117 | Shape of (batch_size,)
118 | """
119 | batch_size, seq_length, nb_labels = emissions.shape
120 |
121 | alphas = torch.zeros((batch_size, nb_labels)).to(self.device)
122 | alpha_mask = torch.zeros((batch_size, 1), dtype=int).to(self.device)
123 |
124 | for i in range(0, seq_length):
125 | is_valid = mask[:, i].unsqueeze(-1)
126 |
127 | first_alphas = self.start_transitions + emissions[:, i]
128 |
129 | # (bs, nb_labels) -> (bs, 1, nb_labels)
130 | e_scores = emissions[:, i].unsqueeze(1)
131 | # (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels)
132 | t_scores = self.transitions.unsqueeze(0)
133 | # (bs, nb_labels) -> (bs, nb_labels, 1)
134 | a_scores = alphas.unsqueeze(2)
135 | scores = e_scores + t_scores + a_scores
136 |
137 | new_alphas = torch.logsumexp(scores, dim=1)
138 |
139 | new_alphas = (1 - alpha_mask) * first_alphas + alpha_mask * new_alphas
140 | alpha_mask = is_valid + (1 - is_valid) * alpha_mask
141 |
142 | alphas = is_valid * new_alphas + (1 - is_valid) * alphas
143 |
144 | end_scores = alphas + self.end_transitions
145 |
146 | return torch.logsumexp(end_scores, dim=1)
147 |
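# Usage sketch (illustrative, not part of the original module):
#   crf = CRF(nb_labels=10, device="cpu")
#   emissions = torch.randn(2, 5, 10)          # (batch_size, seq_len, nb_labels)
#   tags = torch.randint(0, 10, (2, 5))        # (batch_size, seq_len)
#   mask = torch.ones(2, 5, dtype=torch.long)  # 1 for valid positions
#   nll = crf(emissions, tags, mask=mask)      # negative log-likelihood, summed over the batch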
--------------------------------------------------------------------------------
/learning/dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.nn.utils.rnn import pad_sequence
8 |
9 | BERT_TOKEN_MAPPING = {
10 | "-LRB-": "(",
11 | "-RRB-": ")",
12 | "-LCB-": "{",
13 | "-RCB-": "}",
14 | "-LSB-": "[",
15 | "-RSB-": "]",
16 | "``": '"',
17 | "''": '"',
18 | "`": "'",
19 | '«': '"',
20 | '»': '"',
21 | '‘': "'",
22 | '’': "'",
23 | '“': '"',
24 | '”': '"',
25 | '„': '"',
26 | '‹': "'",
27 | '›': "'",
28 | "\u2013": "--", # en dash
29 | "\u2014": "--", # em dash
30 | }
31 |
32 |
33 | def ptb_unescape(sent):
34 | cleaned_words = []
35 | for word in sent:
36 | word = BERT_TOKEN_MAPPING.get(word, word)
37 | word = word.replace('\\/', '/').replace('\\*', '*')
38 | # Mid-token punctuation occurs in biomedical text
39 | word = word.replace('-LSB-', '[').replace('-RSB-', ']')
40 | word = word.replace('-LRB-', '(').replace('-RRB-', ')')
41 | if word == "n't" and cleaned_words:
42 | cleaned_words[-1] = cleaned_words[-1] + "n"
43 | word = "'t"
44 | cleaned_words.append(word)
45 | return cleaned_words
46 |
47 |
48 | class TaggingDataset(torch.utils.data.Dataset):
49 | def __init__(self, split, tokenizer, tag_system, reader, device, is_tetratags=False,
50 | language="english", pad_to_len=None,
51 | max_train_len=350):
52 | self.reader = reader
53 | self.split = split
54 | self.trees = self.reader.parsed_sents(split)
55 | self.tokenizer = tokenizer
56 | self.language = language
57 | self.tag_system = tag_system
58 | self.pad_token_id = self.tokenizer.pad_token_id
59 | self.pad_to_len = pad_to_len
60 | self.device = device
61 | self.is_tetratags = is_tetratags
62 |
63 | if "train" in split and max_train_len is not None:
64 | # To speed up training, we only train on short sentences.
65 | print(len(self.trees), "trees before filtering")
66 | self.trees = [
67 | tree for tree in self.trees if
68 | (len(tree.leaves()) <= max_train_len and len(tree.leaves()) >= 2)]
69 | print(len(self.trees), "trees after filtering")
70 | else:
71 | # speed up!
72 | self.trees = [
73 | tree for tree in self.trees if len(tree.leaves()) <= max_train_len]
74 |
75 | if not os.path.exists(
76 | f"./data/pos.{language.lower()}.json") and "train" in split:
77 | self.pos_dict = self.get_pos_dict()
78 | with open(f"./data/pos.{language.lower()}.json", 'w') as fp:
79 | json.dump(self.pos_dict, fp)
80 | else:
81 | with open(f"./data/pos.{language.lower()}.json", 'r') as fp:
82 | self.pos_dict = json.load(fp)
83 |
84 | def get_pos_dict(self):
85 | pos_dict = {}
86 | for t in self.trees:
87 | for _, x in t.pos():
88 | pos_dict[x] = pos_dict.get(x, 1 + len(pos_dict))
89 | return pos_dict
90 |
91 | def __len__(self):
92 | return len(self.trees)
93 |
94 | def __getitem__(self, index):
95 | tree = self.trees[index]
96 | words = ptb_unescape(tree.leaves())
97 | pos_tags = [self.pos_dict.get(x[1], 0) for x in tree.pos()]
98 |
99 | if 'albert' in str(type(self.tokenizer)):
100 | # albert is case insensitive!
101 | words = [w.lower() for w in words]
102 |
103 | encoded = self.tokenizer.encode_plus(' '.join(words))
104 | word_end_positions = [
105 | encoded.char_to_token(i)
106 | for i in np.cumsum([len(word) + 1 for word in words]) - 2]
107 | word_start_positions = [
108 | encoded.char_to_token(i)
109 | for i in np.cumsum([0] + [len(word) + 1 for word in words])[:-1]]
110 |
111 | input_ids = torch.tensor(encoded['input_ids'], dtype=torch.long)
112 | pair_ids = torch.zeros_like(input_ids)
113 | end_of_word = torch.zeros_like(input_ids)
114 | pos_ids = torch.zeros_like(input_ids)
115 |
116 | tag_ids = self.tag_system.tree_to_ids_pipeline(tree)
117 |
118 | # Pack both leaf and internal tag ids into a single "label" field.
119 | # (The huggingface API isn't flexible enough to use multiple label fields)
120 | tag_ids = [tag_id + 1 for tag_id in tag_ids] + [0]
121 | tag_ids = torch.tensor(tag_ids, dtype=torch.long)
122 | labels = torch.zeros_like(input_ids)
123 |
124 | odd_labels = tag_ids[1::2]
125 | if self.is_tetratags:
126 | even_labels = tag_ids[
127 | ::2] - self.tag_system.decode_moderator.internal_tag_vocab_size
128 | labels[word_end_positions] = (
129 | odd_labels * (
130 | self.tag_system.decode_moderator.leaf_tag_vocab_size + 1) + even_labels)
131 | pos_ids[word_end_positions] = torch.tensor(pos_tags, dtype=torch.long)
132 | pair_ids[word_end_positions] = torch.as_tensor(word_start_positions)
133 |
134 | end_of_word[word_end_positions] = 1
135 | end_of_word[word_end_positions[-1]] = 2 # last word
136 | else:
137 | even_labels = tag_ids[::2]
138 | labels[word_end_positions] = (
139 | odd_labels * (len(self.tag_system.tag_vocab) + 1) + even_labels)
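# Note on the packing above (illustrative): with a leaf-tag vocabulary of size L, an
# internal tag i and a leaf tag e are stored in a single integer as i * (L + 1) + e;
# learning/evaluate.py recovers them with label % (L + 1) and label // (L + 1)
# (followed by a -1 shift). E.g. L = 4, i = 3, e = 2 -> packed = 3 * 5 + 2 = 17,
# and 17 % 5 = 2, 17 // 5 = 3.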
140 |
141 | if self.pad_to_len is not None:
142 | pad_amount = self.pad_to_len - input_ids.shape[0]
143 | if pad_amount >= 0:
144 | input_ids = F.pad(input_ids, [0, pad_amount], value=self.pad_token_id)
145 | pos_ids = F.pad(pos_ids, [0, pad_amount], value=0)
146 | labels = F.pad(labels, [0, pad_amount], value=0)
147 |
148 | return {
149 | 'input_ids': input_ids,
150 | 'pos_ids': pos_ids,
151 | 'pair_ids': pair_ids,
152 | 'end_of_word': end_of_word,
153 | 'labels': labels
154 | }
155 |
156 | def collate(self, batch):
157 | # for GPT-2, self.pad_token_id is None
158 | pad_token_id = self.pad_token_id if self.pad_token_id is not None else -100
159 | input_ids = pad_sequence(
160 | [item['input_ids'] for item in batch],
161 | batch_first=True, padding_value=pad_token_id)
162 |
163 | attention_mask = (input_ids != pad_token_id).float()
164 | # for GPT-2, change -100 back into 0
165 | input_ids = torch.where(
166 | input_ids == -100,
167 | 0,
168 | input_ids
169 | )
170 |
171 | pair_ids = pad_sequence(
172 | [item['pair_ids'] for item in batch],
173 | batch_first=True, padding_value=pad_token_id)
174 | end_of_word = pad_sequence(
175 | [item['end_of_word'] for item in batch],
176 | batch_first=True, padding_value=0)
177 | pos_ids = pad_sequence(
178 | [item['pos_ids'] for item in batch],
179 | batch_first=True, padding_value=0)
180 |
181 | labels = pad_sequence(
182 | [item['labels'] for item in batch],
183 | batch_first=True, padding_value=0)
184 |
185 | return {
186 | 'input_ids': input_ids,
187 | 'pos_ids': pos_ids,
188 | 'pair_ids': pair_ids,
189 | 'end_of_word': end_of_word,
190 | 'attention_mask': attention_mask,
191 | 'labels': labels,
192 | }
193 |
--------------------------------------------------------------------------------
/learning/decode.py:
--------------------------------------------------------------------------------
1 | ## code adapted from: https://github.com/nikitakit/tetra-tagging
2 |
3 | import numpy as np
4 |
5 |
6 | class Beam:
7 | def __init__(self, scores, stack_depths, prev, backptrs, labels):
8 | self.scores = scores
9 | self.stack_depths = stack_depths
10 | self.prev = prev
11 | self.backptrs = backptrs
12 | self.labels = labels
13 |
14 |
15 | class BeamSearch:
16 | def __init__(
17 | self,
18 | tag_moderator,
19 | initial_stack_depth,
20 | max_depth=5,
21 | min_depth=1,
22 | keep_per_depth=1,
23 | crf_transitions=None,
24 | initial_label=None,
25 | ):
26 | # Save parameters
27 | self.tag_moderator = tag_moderator
28 | self.valid_depths = np.arange(min_depth, max_depth)
29 | self.keep_per_depth = keep_per_depth
30 | self.max_depth = max_depth
31 | self.min_depth = min_depth
32 | self.crf_transitions = crf_transitions
33 |
34 | # Initialize the beam
35 | scores = np.zeros(1, dtype=float)
36 | stack_depths = np.full(1, initial_stack_depth)
37 | prev = backptrs = labels = None
38 | if initial_label is not None:
39 | labels = np.full(1, initial_label)
40 | self.beam = Beam(scores, stack_depths, prev, backptrs, labels)
41 |
42 | def compute_new_scores(self, label_log_probs, is_last):
43 | if self.crf_transitions is None:
44 | return self.beam.scores[:, None] + label_log_probs
45 | else:
46 | if self.beam.labels is not None:
47 | all_new_scores = self.beam.scores[:, None] + label_log_probs + \
48 | self.crf_transitions["transitions"][self.beam.labels]
49 | else:
50 | all_new_scores = self.beam.scores[:, None] + label_log_probs + \
51 | self.crf_transitions["start_transitions"]
52 | if is_last:
53 | all_new_scores += self.crf_transitions["end_transitions"]
54 | return all_new_scores
55 |
56 | # This extra mask layer takes care of invalid reduce actions when there is no empty
57 | # slot in the tree, which is needed in the top-down shift-reduce tagging schema
58 | def extra_mask_layer(self, all_new_scores, all_new_stack_depths):
59 | depth_mask = np.zeros(all_new_stack_depths.shape)
60 | depth_mask[all_new_stack_depths < 0] = -np.inf
61 | depth_mask[all_new_stack_depths > self.max_depth] = -np.inf
62 | all_new_scores = all_new_scores + depth_mask
63 |
64 | all_new_stack_depths = (
65 | all_new_stack_depths
66 | + self.tag_moderator.stack_depth_change_by_id
67 | )
68 | return all_new_scores, all_new_stack_depths
69 |
70 | def advance(self, label_logits, is_last=False):
71 | label_log_probs = label_logits
72 |
73 | all_new_scores = self.compute_new_scores(label_log_probs, is_last)
74 | if self.tag_moderator.mask_binarize and self.beam.labels is not None:
75 | labels = self.beam.labels
76 | all_new_scores = self.tag_moderator.mask_scores_for_binarization(labels,
77 | all_new_scores)
78 |
79 | if self.tag_moderator.stack_depth_change_by_id_l2 is not None:
80 | all_new_stack_depths = (
81 | self.beam.stack_depths[:, None]
82 | + self.tag_moderator.stack_depth_change_by_id_l2[None, :]
83 | )
84 | all_new_scores, all_new_stack_depths = self.extra_mask_layer(all_new_scores,
85 | all_new_stack_depths)
86 | else:
87 | all_new_stack_depths = (
88 | self.beam.stack_depths[:, None]
89 | + self.tag_moderator.stack_depth_change_by_id[None, :]
90 | )
91 |
92 | masked_scores = all_new_scores[None, :, :] + np.where(
93 | all_new_stack_depths[None, :, :] == self.valid_depths[:, None, None],
94 | 0.0, -np.inf,
95 | )
96 | masked_scores = masked_scores.reshape(self.valid_depths.shape[0], -1)
97 | idxs = np.argsort(-masked_scores)[:, : self.keep_per_depth].flatten()
98 | backptrs, labels = np.unravel_index(idxs, all_new_scores.shape)
99 |
100 | transition_valid = all_new_stack_depths[
101 | backptrs, labels
102 | ] == self.valid_depths.repeat(self.keep_per_depth)
103 |
104 | backptrs = backptrs[transition_valid]
105 | labels = labels[transition_valid]
106 |
107 | self.beam = Beam(
108 | all_new_scores[backptrs, labels],
109 | all_new_stack_depths[backptrs, labels],
110 | self.beam,
111 | backptrs,
112 | labels,
113 | )
114 |
115 | def get_path(self, idx=0, required_stack_depth=1):
116 | if required_stack_depth is not None:
117 | assert self.beam.stack_depths[idx] == required_stack_depth
118 | score = self.beam.scores[idx]
119 | assert score > -np.inf
120 |
121 | beam = self.beam
122 | label_idxs = []
123 | while beam.prev is not None:
124 | label_idxs.insert(0, beam.labels[idx])
125 | idx = beam.backptrs[idx]
126 | beam = beam.prev
127 |
128 | return score, label_idxs
129 |
130 |
131 | class GreedySearch(BeamSearch):
132 | def advance(self, label_logits, is_last=False):
133 | label_log_probs = label_logits
134 |
135 | all_new_scores = self.compute_new_scores(label_log_probs, is_last)
136 | if self.tag_moderator.mask_binarize and self.beam.labels is not None:
137 | labels = self.beam.labels
138 | all_new_scores = self.tag_moderator.mask_scores_for_binarization(labels,
139 | all_new_scores)
140 |
141 | if self.tag_moderator.stack_depth_change_by_id_l2 is not None:
142 | all_new_stack_depths = (
143 | self.beam.stack_depths[:, None]
144 | + self.tag_moderator.stack_depth_change_by_id_l2[None, :]
145 | )
146 |
147 | all_new_scores, all_new_stack_depths = self.extra_mask_layer(all_new_scores,
148 | all_new_stack_depths)
149 | else:
150 | all_new_stack_depths = (
151 | self.beam.stack_depths[:, None]
152 | + self.tag_moderator.stack_depth_change_by_id[None, :]
153 | )
154 |
155 | masked_scores = all_new_scores + np.where((all_new_stack_depths >= self.min_depth)
156 | & (all_new_stack_depths <= self.max_depth),
157 | 0.0,
158 | -np.inf)
159 |
160 | masked_scores = masked_scores.reshape(-1)
161 | idxs = np.argsort(-masked_scores)[:self.keep_per_depth].flatten()
162 |
163 | backptrs, labels = np.unravel_index(idxs, all_new_scores.shape)
164 |
165 | transition_valid = (all_new_stack_depths[
166 | backptrs, labels
167 | ] >= self.min_depth) & (all_new_stack_depths[
168 | backptrs, labels
169 | ] <= self.max_depth)
170 |
171 | backptrs = backptrs[transition_valid]
172 | labels = labels[transition_valid]
173 |
174 | self.beam = Beam(
175 | all_new_scores[backptrs, labels],
176 | all_new_stack_depths[backptrs, labels],
177 | self.beam,
178 | backptrs,
179 | labels,
180 | )
181 |
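# Usage sketch (illustrative; the actual call sites are presumably the tagger classes
# under tagging/):
#   searcher = BeamSearch(tag_moderator, initial_stack_depth=0, max_depth=5, keep_per_depth=1)
#   for t, logits in enumerate(per_position_logits):          # one row of label logits per word
#       searcher.advance(logits, is_last=(t == len(per_position_logits) - 1))
#   score, label_idxs = searcher.get_path()                   # best-scoring tag sequence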
--------------------------------------------------------------------------------
/learning/evaluate.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import math
3 | import os.path
4 | import re
5 | import subprocess
6 | import tempfile
7 | from copy import deepcopy
8 | from typing import Tuple
9 |
10 |
11 | import numpy as np
12 | import torch
13 | from sklearn.metrics import precision_recall_fscore_support
14 | from tqdm import tqdm as tq
15 |
16 | from tagging.tree_tools import create_dummy_tree
17 |
18 | repo_directory = os.path.abspath(__file__)
19 |
20 | class ParseMetrics(object):
21 | def __init__(self, recall, precision, fscore, complete_match, tagging_accuracy=100):
22 | self.recall = recall
23 | self.precision = precision
24 | self.fscore = fscore
25 | self.complete_match = complete_match
26 | self.tagging_accuracy = tagging_accuracy
27 |
28 | def __str__(self):
29 | if self.tagging_accuracy < 100:
30 | return "(Recall={:.4f}, Precision={:.4f}, ParseMetrics={:.4f}, CompleteMatch={:.4f}, TaggingAccuracy={:.4f})".format(
31 | self.recall, self.precision, self.fscore, self.complete_match,
32 | self.tagging_accuracy)
33 | else:
34 | return "(Recall={:.4f}, Precision={:.4f}, ParseMetrics={:.4f}, CompleteMatch={:.4f})".format(
35 | self.recall, self.precision, self.fscore, self.complete_match)
36 |
37 |
38 | def report_eval_loss(model, eval_dataloader, device, n_iter, writer) -> np.ndarray:
39 | loss = []
40 | for batch in eval_dataloader:
41 | batch = {k: v.to(device) for k, v in batch.items()}
42 | with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
43 | outputs = model(**batch)
44 | loss.append(torch.mean(outputs[0]).cpu())
45 |
46 | mean_loss = np.mean(loss)
47 | logging.info("Eval Loss: {}".format(mean_loss))
48 | if writer is not None:
49 | writer.add_scalar('eval_loss', mean_loss, n_iter)
50 | return mean_loss
51 |
52 |
53 | def predict(
54 | model, eval_dataloader, dataset_size, num_tags, batch_size, device
55 | ) -> Tuple[np.array, np.array]:
56 | model.eval()
57 | predictions = []
58 | eval_labels = []
59 | max_len = 0
60 | idx = 0
61 |
62 | for batch in tq(eval_dataloader, disable=True):
63 | batch = {k: v.to(device) for k, v in batch.items()}
64 |
65 | with torch.no_grad(), torch.cuda.amp.autocast(
66 | enabled=True, dtype=torch.bfloat16
67 | ):
68 | outputs = model(**batch)
69 |
70 | logits = outputs[1].float().cpu().numpy()
71 | max_len = max(max_len, logits.shape[1])
72 | predictions.append(logits)
73 | labels = batch['labels'].int().cpu().numpy()
74 | eval_labels.append(labels)
75 | idx += 1
76 |
77 | predictions = np.concatenate([np.pad(logits,
78 | ((0, 0), (0, max_len - logits.shape[1]), (0, 0)),
79 | 'constant', constant_values=0) for logits in
80 | predictions], axis=0)
81 | eval_labels = np.concatenate([np.pad(labels, ((0, 0), (0, max_len - labels.shape[1])),
82 | 'constant', constant_values=0) for labels in
83 | eval_labels], axis=0)
84 |
85 | return predictions, eval_labels
86 |
87 |
88 | def calc_tag_accuracy(
89 | predictions, eval_labels, num_leaf_labels, writer, use_tensorboard
90 | ) -> Tuple[float, float]:
91 | even_predictions = predictions[..., -num_leaf_labels:]
92 | odd_predictions = predictions[..., :-num_leaf_labels]
93 | even_labels = eval_labels % (num_leaf_labels + 1) - 1
94 | odd_labels = eval_labels // (num_leaf_labels + 1) - 1
95 |
96 | odd_predictions = odd_predictions[odd_labels != -1].argmax(-1)
97 | even_predictions = even_predictions[even_labels != -1].argmax(-1)
98 |
99 | odd_labels = odd_labels[odd_labels != -1]
100 | even_labels = even_labels[even_labels != -1]
101 |
102 | odd_acc = (odd_predictions == odd_labels).mean()
103 | even_acc = (even_predictions == even_labels).mean()
104 |
105 | logging.info('odd_tags_accuracy: {}'.format(odd_acc))
106 | logging.info('even_tags_accuracy: {}'.format(even_acc))
107 |
108 | if use_tensorboard:
109 | writer.add_pr_curve('odd_tags_pr_curve', odd_labels, odd_predictions, 0)
110 | writer.add_pr_curve('even_tags_pr_curve', even_labels, even_predictions, 1)
111 | return even_acc, odd_acc
112 |
113 |
114 | def get_dependency_from_lexicalized_tree(lex_tree, triple_dict, offset=0):
115 | # this recursion assumes projectivity
116 | # Input:
117 | # root of lex-tree
118 | # Output:
119 | # the global index of the dependency root
120 | if type(lex_tree) not in {str, dict} and len(lex_tree) == 1:
121 | # unary rule
122 | # returning the global index of the head
123 | return offset
124 |
125 | head_branch_index = int(lex_tree.label().split("^^^")[1])
126 | head_global_index = None
127 | branch_to_global_dict = {}
128 |
129 | for branch_id_child, child in enumerate(lex_tree):
130 | global_id_child = get_dependency_from_lexicalized_tree(
131 | child, triple_dict, offset=offset
132 | )
133 | offset = offset + len(child.leaves())
134 | branch_to_global_dict[branch_id_child] = global_id_child
135 | if branch_id_child == head_branch_index:
136 | head_global_index = global_id_child
137 |
138 | for branch_id_child, child in enumerate(lex_tree):
139 | if branch_id_child != head_branch_index:
140 | triple_dict[branch_to_global_dict[branch_id_child]] = head_global_index
141 |
142 | return head_global_index
143 |
144 |
145 | def is_punctuation(pos):
146 | punct_set = '.' '``' "''" ':' ','
147 | return (pos in punct_set) or (pos.lower() in ['pu', 'punct']) # for CTB & UD
148 |
149 |
150 | def tree_to_dep_triples(lex_tree):
151 | triple_dict = {}
152 | dep_triples = []
153 | sent_root = get_dependency_from_lexicalized_tree(
154 | lex_tree, triple_dict
155 | )
156 | # the root of the whole sentence should refer to ROOT
157 | assert sent_root not in triple_dict
158 | # the root of the sentence
159 | triple_dict[sent_root] = -1
160 | for head, tail in sorted(triple_dict.items()):
161 | dep_triples.append((
162 | head, tail,
163 | lex_tree.pos()[head][1].split("^^^")[1].split("+")[0],
164 | lex_tree.pos()[head][1].split("^^^")[1].split("+")[1]
165 | ))
166 | return dep_triples
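# Illustration (assumed): after collapse_unary(collapsePOS=True), each preterminal label
# has the form "X^^^<rel>+<POS>", so every triple above is
#   (dependent_index, head_index, dependency_relation, POS_tag)
# with the sentence root's head_index set to -1.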
167 |
168 |
169 |
170 | def dependency_eval(
171 | predictions, eval_labels, eval_dataset, tag_system, output_path,
172 | model_name, max_depth, keep_per_depth, is_greedy
173 | ) -> Tuple[ParseMetrics, ParseMetrics]:
174 | ud_flag = eval_dataset.language not in {'English', 'Chinese'}
175 |
176 | # This can be parallelized!
177 | predicted_dev_triples, predicted_dev_triples_unlabeled = [], []
178 | gold_dev_triples, gold_dev_triples_unlabeled = [], []
179 | c_err = 0
180 |
181 | gt_triple_data, pred_triple_data = [], []
182 |
183 | for i in tq(range(predictions.shape[0]), disable=True):
184 | logits = predictions[i]
185 | is_word = (eval_labels[i] != 0)
186 |
187 | original_tree = deepcopy(eval_dataset.trees[i])
188 | original_tree.collapse_unary(collapsePOS=True, collapseRoot=True)
189 |
190 | try: # ignore the ones that failed in unchomsky_normal_form
191 | tree = tag_system.logits_to_tree(
192 | logits, original_tree.pos(),
193 | mask=is_word,
194 | max_depth=max_depth,
195 | keep_per_depth=keep_per_depth,
196 | is_greedy=is_greedy
197 | )
198 | tree.collapse_unary(collapsePOS=True, collapseRoot=True)
199 | except Exception as ex:
200 | template = "An exception of type {0} occurred. Arguments:\n{1!r}"
201 | message = template.format(type(ex).__name__, ex.args)
202 | c_err += 1
203 | predicted_dev_triples.append(create_dummy_tree(original_tree.pos()))
204 | continue
205 | if tree.leaves() != original_tree.leaves():
206 | c_err += 1
207 | predicted_dev_triples.append(create_dummy_tree(original_tree.pos()))
208 | continue
209 |
210 | gt_triples = tree_to_dep_triples(original_tree)
211 | pred_triples = tree_to_dep_triples(tree)
212 |
213 | gt_triple_data.append(gt_triples)
214 | pred_triple_data.append(pred_triples)
215 |
216 | assert len(gt_triples) == len(
217 | pred_triples), f"wrong length {len(gt_triples)} vs. {len(pred_triples)}!"
218 |
219 | for x, y in zip(sorted(gt_triples), sorted(pred_triples)):
220 | if is_punctuation(x[3]) and not ud_flag:
221 | # ignoring punctuations for evaluation
222 | continue
223 | assert x[0] == y[0], f"wrong tree {gt_triples} vs. {pred_triples}!"
224 |
225 | gold_dev_triples.append(f"{x[0]}-{x[1]}-{x[2].split(':')[0]}")
226 | gold_dev_triples_unlabeled.append(f"{x[0]}-{x[1]}")
227 |
228 | predicted_dev_triples.append(f"{y[0]}-{y[1]}-{y[2].split(':')[0]}")
229 | predicted_dev_triples_unlabeled.append(f"{y[0]}-{y[1]}")
230 |
231 | if ud_flag:
232 | # UD
233 | predicted_dev_triples, predicted_dev_triples_unlabeled = [], []
234 | gold_dev_triples, gold_dev_triples_unlabeled = [], []
235 |
236 | language, split = eval_dataset.language, eval_dataset.split.split(".")[-1]
237 |
238 | gold_temp_out, pred_temp_out = tempfile.mktemp(dir=os.path.dirname(repo_directory)), \
239 | tempfile.mktemp(dir=os.path.dirname(repo_directory))
240 | gold_temp_in, pred_temp_in = gold_temp_out + ".deproj", pred_temp_out + ".deproj"
241 |
242 |
243 | save_triplets(gt_triple_data, gold_temp_out)
244 | save_triplets(pred_triple_data, pred_temp_out)
245 |
246 | for filename, tgt_filename in zip([gold_temp_out, pred_temp_out], [gold_temp_in, pred_temp_in]):
247 | command = f"cd ./malt/maltparser-1.9.2/; java -jar maltparser-1.9.2.jar -c {language}_{split} -m deproj" \
248 | f" -i {filename} -o {tgt_filename} ; cd ../../"
249 | os.system(command)
250 |
251 | loaded_gold_dev_triples = load_triplets(gold_temp_in)
252 | loaded_pred_dev_triples = load_triplets(pred_temp_in)
253 |
254 | for gt_triples, pred_triples in zip(loaded_gold_dev_triples, loaded_pred_dev_triples):
255 | for x, y in zip(sorted(gt_triples), sorted(pred_triples)):
256 | if is_punctuation(x[3]):
257 | # ignoring punctuations for evaluation
258 | continue
259 | assert x[0] == y[0], f"wrong tree {gt_triples} vs. {pred_triples}!"
260 |
261 | gold_dev_triples.append(f"{x[0]}-{x[1]}-{x[2].split(':')[0]}")
262 | gold_dev_triples_unlabeled.append(f"{x[0]}-{x[1]}")
263 |
264 | predicted_dev_triples.append(f"{y[0]}-{y[1]}-{y[2].split(':')[0]}")
265 | predicted_dev_triples_unlabeled.append(f"{y[0]}-{y[1]}")
266 |
267 |
268 | logging.warning("Number of binarization error: {}\n".format(c_err))
269 | las_recall, las_precision, las_fscore, _ = precision_recall_fscore_support(
270 | gold_dev_triples, predicted_dev_triples, average='micro'
271 | )
272 | uas_recall, uas_precision, uas_fscore, _ = precision_recall_fscore_support(
273 | gold_dev_triples_unlabeled, predicted_dev_triples_unlabeled, average='micro'
274 | )
275 |
276 | return (ParseMetrics(las_recall, las_precision, las_fscore, complete_match=1),
277 | ParseMetrics(uas_recall, uas_precision, uas_fscore, complete_match=1))
278 |
279 |
280 | def save_triplets(triplet_data, file_path):
281 | # save triplets to file in conll format
282 | with open(file_path, 'w') as f:
283 | for triplets in triplet_data:
284 | for triplet in triplets:
285 | # 8 Витоша витоша PROPN Npfsi Definite=Ind|Gender=Fem|Number=Sing 6 nmod _ _
286 | head, tail, label, pos = triplet
287 | f.write(f"{head+1}\t-\t-\t{pos}\t-\t-\t{tail+1}\t{label}\t_\t_\n")
288 | f.write('\n')
289 |
290 | return
291 |
292 |
293 | def load_triplets(file_path):
294 | # load triplets from file in conll format
295 | triplet_data = []
296 | with open(file_path, 'r') as f:
297 | triplets = []
298 | for line in f.readlines():
299 | if line.startswith('#') or line == '\n':
300 | if triplets:
301 | triplet_data.append(triplets)
302 | triplets = []
303 | continue
304 | line_list = line.strip().split('\t')
305 | head, tail, label, pos = line_list[0], line_list[6], line_list[7], line_list[3]
306 | triplets.append((head, tail, label, pos))
307 | if triplets:
308 | triplet_data.append(triplets)
309 | return triplet_data
310 |
311 |
312 | def calc_parse_eval(predictions, eval_labels, eval_dataset, tag_system, output_path,
313 | model_name, max_depth, keep_per_depth, is_greedy) -> ParseMetrics:
314 | predicted_dev_trees = []
315 | gold_dev_trees = []
316 | c_err = 0
317 | for i in tq(range(predictions.shape[0])):
318 | logits = predictions[i]
319 | is_word = eval_labels[i] != 0
320 | original_tree = eval_dataset.trees[i]
321 | gold_dev_trees.append(original_tree)
322 | try: # ignore the ones that failed in unchomsky_normal_form
323 | tree = tag_system.logits_to_tree(logits, original_tree.pos(), mask=is_word,
324 | max_depth=max_depth,
325 | keep_per_depth=keep_per_depth,
326 | is_greedy=is_greedy)
327 | except Exception as ex:
328 | template = "An exception of type {0} occurred. Arguments:\n{1!r}"
329 | message = template.format(type(ex).__name__, ex.args)
330 | # print(message)
331 | c_err += 1
332 | predicted_dev_trees.append(create_dummy_tree(original_tree.pos()))
333 | continue
334 | if tree.leaves() != original_tree.leaves():
335 | c_err += 1
336 | predicted_dev_trees.append(create_dummy_tree(original_tree.pos()))
337 | continue
338 | predicted_dev_trees.append(tree)
339 |
340 | logging.warning("Number of binarization error: {}".format(c_err))
341 |
342 | return evalb("EVALB_SPMRL/", gold_dev_trees, predicted_dev_trees)
343 |
344 |
345 | def save_predictions(predicted_trees, file_path):
346 | with open(file_path, 'w') as f:
347 | for tree in predicted_trees:
348 | f.write(' '.join(str(tree).split()) + '\n')
349 |
350 |
351 | def evalb(evalb_dir, gold_trees, predicted_trees, ref_gold_path=None) -> ParseMetrics:
352 | # Code from: https://github.com/nikitakit/self-attentive-parser/blob/master/src/evaluate.py
353 | assert os.path.exists(evalb_dir)
354 | evalb_program_path = os.path.join(evalb_dir, "evalb")
355 | evalb_spmrl_program_path = os.path.join(evalb_dir, "evalb_spmrl")
356 | assert os.path.exists(evalb_program_path) or os.path.exists(evalb_spmrl_program_path)
357 |
358 | if os.path.exists(evalb_program_path):
359 | evalb_param_path = os.path.join(evalb_dir, "nk.prm")
360 | else:
361 | evalb_program_path = evalb_spmrl_program_path
362 | evalb_param_path = os.path.join(evalb_dir, "spmrl.prm")
363 |
364 | assert os.path.exists(evalb_program_path)
365 | assert os.path.exists(evalb_param_path)
366 |
367 | assert len(gold_trees) == len(predicted_trees)
368 | for gold_tree, predicted_tree in zip(gold_trees, predicted_trees):
369 | gold_leaves = list(gold_tree.leaves())
370 | predicted_leaves = list(predicted_tree.leaves())
371 |
372 | temp_dir = tempfile.TemporaryDirectory(prefix="evalb-")
373 | gold_path = os.path.join(temp_dir.name, "gold.txt")
374 | predicted_path = os.path.join(temp_dir.name, "predicted.txt")
375 | output_path = os.path.join(temp_dir.name, "output.txt")
376 |
377 | with open(gold_path, "w") as outfile:
378 | if ref_gold_path is None:
379 | for tree in gold_trees:
380 | outfile.write(' '.join(str(tree).split()) + '\n')
381 | else:
382 | # For the SPMRL dataset our data loader performs some modifications
383 | # (like stripping morphological features), so we compare to the
384 | # raw gold file to be certain that we haven't spoiled the evaluation
385 | # in some way.
386 | with open(ref_gold_path) as goldfile:
387 | outfile.write(goldfile.read())
388 |
389 | with open(predicted_path, "w") as outfile:
390 | for tree in predicted_trees:
391 | outfile.write(' '.join(str(tree).split()) + '\n')
392 |
393 | command = "{} -p {} {} {} > {}".format(
394 | evalb_program_path,
395 | evalb_param_path,
396 | gold_path,
397 | predicted_path,
398 | output_path,
399 | )
400 | subprocess.run(command, shell=True)
401 |
402 | fscore = ParseMetrics(math.nan, math.nan, math.nan, math.nan)
403 | with open(output_path) as infile:
404 | for line in infile:
405 | match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line)
406 | if match:
407 | fscore.recall = float(match.group(1))
408 | match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line)
409 | if match:
410 | fscore.precision = float(match.group(1))
411 | match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
412 | if match:
413 | fscore.fscore = float(match.group(1))
414 | match = re.match(r"Complete match\s+=\s+(\d+\.\d+)", line)
415 | if match:
416 | fscore.complete_match = float(match.group(1))
417 | match = re.match(r"Tagging accuracy\s+=\s+(\d+\.\d+)", line)
418 | if match:
419 | fscore.tagging_accuracy = float(match.group(1))
420 | break
421 |
422 | success = (
423 | not math.isnan(fscore.fscore) or
424 | fscore.recall == 0.0 or
425 | fscore.precision == 0.0)
426 |
427 | if success:
428 | temp_dir.cleanup()
429 | else:
430 | print("Error reading EVALB results.")
431 | print("Gold path: {}".format(gold_path))
432 | print("Predicted path: {}".format(predicted_path))
433 | print("Output path: {}".format(output_path))
434 |
435 | return fscore
436 |
437 |
438 |
439 |
440 | def dependency_decoding(
441 | predictions, eval_labels, eval_dataset, tag_system, output_path,
442 | model_name, max_depth, keep_per_depth, is_greedy
443 | ) -> ParseMetrics:
444 | ud_flag = eval_dataset.language not in {'English', 'Chinese'}
445 |
446 | # This can be parallelized!
447 | predicted_dev_triples, predicted_dev_triples_unlabeled = [], []
448 | gold_dev_triples, gold_dev_triples_unlabeled = [], []
449 | pred_hexa_tags = []
450 | c_err = 0
451 |
452 | gt_triple_data, pred_triple_data = [], []
453 |
454 | for i in tq(range(predictions.shape[0]), disable=True):
455 | logits = predictions[i]
456 | is_word = (eval_labels[i] != 0)
457 |
458 | original_tree = deepcopy(eval_dataset.trees[i])
459 | original_tree.collapse_unary(collapsePOS=True, collapseRoot=True)
460 |
461 | try: # ignore the ones that failed in unchomsky_normal_form
462 | tree = tag_system.logits_to_tree(
463 | logits, original_tree.pos(),
464 | mask=is_word,
465 | max_depth=max_depth,
466 | keep_per_depth=keep_per_depth,
467 | is_greedy=is_greedy
468 | )
469 | hexa_ids = tag_system.logits_to_ids(
470 | logits, is_word, max_depth, keep_per_depth, is_greedy=is_greedy
471 | )
472 | pred_hexa_tags.append(hexa_ids)
473 |
474 | tree.collapse_unary(collapsePOS=True, collapseRoot=True)
475 | except Exception as ex:
476 | template = "An exception of type {0} occurred. Arguments:\n{1!r}"
477 | message = template.format(type(ex).__name__, ex.args)
478 | c_err += 1
479 | predicted_dev_triples.append(create_dummy_tree(original_tree.pos()))
480 | continue
481 | if tree.leaves() != original_tree.leaves():
482 | c_err += 1
483 | predicted_dev_triples.append(create_dummy_tree(original_tree.pos()))
484 | continue
485 |
486 | gt_triples = tree_to_dep_triples(original_tree)
487 | pred_triples = tree_to_dep_triples(tree)
488 |
489 | gt_triple_data.append(gt_triples)
490 | pred_triple_data.append(pred_triples)
491 |
492 | assert len(gt_triples) == len(
493 | pred_triples), f"wrong length {len(gt_triples)} vs. {len(pred_triples)}!"
494 |
495 | for x, y in zip(sorted(gt_triples), sorted(pred_triples)):
496 | if is_punctuation(x[3]) and not ud_flag:
497 | # ignoring punctuations for evaluation
498 | continue
499 | assert x[0] == y[0], f"wrong tree {gt_triples} vs. {pred_triples}!"
500 |
501 | gold_dev_triples.append(f"{x[0]}-{x[1]}-{x[2].split(':')[0]}")
502 | gold_dev_triples_unlabeled.append(f"{x[0]}-{x[1]}")
503 |
504 | predicted_dev_triples.append(f"{y[0]}-{y[1]}-{y[2].split(':')[0]}")
505 | predicted_dev_triples_unlabeled.append(f"{y[0]}-{y[1]}")
506 |
507 | if ud_flag:
508 | # UD
509 | predicted_dev_triples, predicted_dev_triples_unlabeled = [], []
510 | gold_dev_triples, gold_dev_triples_unlabeled = [], []
511 |
512 | language, split = eval_dataset.language, eval_dataset.split.split(".")[-1]
513 |
514 | gold_temp_out, pred_temp_out = tempfile.mktemp(dir=os.path.dirname(repo_directory)), \
515 | tempfile.mktemp(dir=os.path.dirname(repo_directory))
516 | gold_temp_in, pred_temp_in = gold_temp_out + ".deproj", pred_temp_out + ".deproj"
517 |
518 |
519 | save_triplets(gt_triple_data, gold_temp_out)
520 | save_triplets(pred_triple_data, pred_temp_out)
521 |
522 | for filename, tgt_filename in zip([gold_temp_out, pred_temp_out], [gold_temp_in, pred_temp_in]):
523 | command = f"cd ./malt/maltparser-1.9.2/; java -jar maltparser-1.9.2.jar -c {language}_{split} -m deproj" \
524 | f" -i {filename} -o {tgt_filename} ; cd ../../"
525 | os.system(command)
526 |
527 | loaded_gold_dev_triples = load_triplets(gold_temp_in)
528 | loaded_pred_dev_triples = load_triplets(pred_temp_in)
529 |
530 | for gt_triples, pred_triples in zip(loaded_gold_dev_triples, loaded_pred_dev_triples):
531 | for x, y in zip(sorted(gt_triples), sorted(pred_triples)):
532 | if is_punctuation(x[3]):
533 | # ignoring punctuations for evaluation
534 | continue
535 | assert x[0] == y[0], f"wrong tree {gt_triples} vs. {pred_triples}!"
536 |
537 | gold_dev_triples.append(f"{x[0]}-{x[1]}-{x[2].split(':')[0]}")
538 | gold_dev_triples_unlabeled.append(f"{x[0]}-{x[1]}")
539 |
540 | predicted_dev_triples.append(f"{y[0]}-{y[1]}-{y[2].split(':')[0]}")
541 | predicted_dev_triples_unlabeled.append(f"{y[0]}-{y[1]}")
542 |
543 | return {
544 | "predicted_dev_triples": predicted_dev_triples,
545 | "predicted_hexa_tags": pred_hexa_tags
546 | }
--------------------------------------------------------------------------------
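
The labeled and unlabeled triple strings assembled above (one entry per non-punctuation token, with fields ordered as returned by `tree_to_dep_triples` earlier in this file) are what the attachment-score metrics are computed from. Below is a minimal sketch of that computation with made-up triple values; the repo's own scoring lives in `dependency_eval`, so treat this only as an illustration of the metric:

```python
# Hypothetical, self-contained illustration of labeled/unlabeled attachment scores
# over aligned "index-index-label" strings like the ones built in dependency_decoding.
def attachment_scores(gold_labeled, pred_labeled, gold_unlabeled, pred_unlabeled):
    las = sum(g == p for g, p in zip(gold_labeled, pred_labeled)) / len(gold_labeled)
    uas = sum(g == p for g, p in zip(gold_unlabeled, pred_unlabeled)) / len(gold_unlabeled)
    return las, uas

las, uas = attachment_scores(
    gold_labeled=["1-2-det", "2-0-root"], pred_labeled=["1-2-det", "2-0-nsubj"],
    gold_unlabeled=["1-2", "2-0"], pred_unlabeled=["1-2", "2-0"],
)
assert (las, uas) == (0.5, 1.0)  # one wrong label, both attachments correct
```
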
/learning/learn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
5 |
6 | from transformers import AutoModelForTokenClassification, AutoModel
7 |
8 |
9 | def calc_loss_helper(logits, labels, attention_mask, num_even_tags, num_odd_tags):
10 | # shape: (batch_size, seq_len, num_tags) -> (batch_size, num_tags, seq_len)
11 | logits = torch.movedim(logits, -1, 1)
12 | odd_logits, even_logits = torch.split(logits, [num_odd_tags, num_even_tags], dim=1)
13 | odd_labels = (labels // (num_even_tags + 1)) - 1
14 | even_labels = (labels % (num_even_tags + 1)) - 1
15 |     # The last word will have only an even label
16 |
17 | # Only keep active parts of the loss
18 | active_even_labels = torch.where(
19 | attention_mask, even_labels, -1
20 | )
21 | active_odd_labels = torch.where(
22 | attention_mask, odd_labels, -1
23 | )
24 | loss = (F.cross_entropy(even_logits, active_even_labels, ignore_index=-1)
25 | + F.cross_entropy(odd_logits, active_odd_labels, ignore_index=-1))
26 |
27 | return loss
28 |
29 |
30 | class ModelForTetratagging(nn.Module):
31 | def __init__(self, config):
32 | super().__init__()
33 | self.num_even_tags = config.task_specific_params['num_even_tags']
34 | self.num_odd_tags = config.task_specific_params['num_odd_tags']
35 | self.model_path = config.task_specific_params['model_path']
36 | self.use_pos = config.task_specific_params.get('use_pos', False)
37 | self.num_pos_tags = config.task_specific_params.get('num_pos_tags', 50)
38 |
39 | self.pos_emb_dim = config.task_specific_params['pos_emb_dim']
40 | self.dropout_rate = config.task_specific_params['dropout']
41 |
42 | self.bert = AutoModel.from_pretrained(self.model_path, config=config)
43 | if self.use_pos:
44 | self.pos_encoder = nn.Sequential(
45 | nn.Embedding(self.num_pos_tags, self.pos_emb_dim, padding_idx=0)
46 | )
47 |
48 | self.endofword_embedding = nn.Embedding(2, self.pos_emb_dim)
49 | self.lstm = nn.LSTM(
50 | 2 * config.hidden_size + self.pos_emb_dim * (1 + self.use_pos),
51 | config.hidden_size,
52 | config.task_specific_params['lstm_layers'],
53 | dropout=self.dropout_rate,
54 | batch_first=True, bidirectional=True
55 | )
56 | self.projection = nn.Sequential(
57 | nn.Linear(2 * config.hidden_size, config.num_labels)
58 | )
59 |
60 | def forward(
61 | self,
62 | input_ids=None,
63 | pair_ids=None,
64 | pos_ids=None,
65 | end_of_word=None,
66 | attention_mask=None,
67 | head_mask=None,
68 | inputs_embeds=None,
69 | labels=None,
70 | output_attentions=None,
71 | output_hidden_states=None,
72 | ):
73 | outputs = self.bert(
74 | input_ids,
75 | attention_mask=attention_mask,
76 | head_mask=head_mask,
77 | inputs_embeds=inputs_embeds,
78 | output_attentions=output_attentions,
79 | output_hidden_states=output_hidden_states,
80 | )
81 | if self.use_pos:
82 | pos_encodings = self.pos_encoder(pos_ids)
83 | token_repr = torch.cat([outputs[0], pos_encodings], dim=-1)
84 | else:
85 | token_repr = outputs[0]
86 |
87 | start_repr = outputs[0].take_along_dim(pair_ids.unsqueeze(-1), dim=1)
88 | token_repr = torch.cat([token_repr, start_repr], dim=-1)
89 | token_repr = torch.cat([token_repr, self.endofword_embedding((pos_ids != 0).long())],
90 | dim=-1)
91 |
92 | lens = attention_mask.sum(dim=-1).cpu()
93 | token_repr = pack_padded_sequence(token_repr, lens, batch_first=True,
94 | enforce_sorted=False)
95 | token_repr = self.lstm(token_repr)[0]
96 | token_repr, _ = pad_packed_sequence(token_repr, batch_first=True)
97 |
98 | tag_logits = self.projection(token_repr)
99 |
100 |         loss = None
101 |         if labels is not None and self.training:
102 |             loss = calc_loss_helper(
103 |                 tag_logits, labels, attention_mask.bool(),
104 |                 self.num_even_tags, self.num_odd_tags
105 |             )
106 |         # at inference time loss stays None and only the logits are used
107 |         return loss, tag_logits
108 | 
--------------------------------------------------------------------------------
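
`calc_loss_helper` above (and `hexa_loss` in `learning/util.py`) assume that every combined tag id packs an internal ("odd") tag and a leaf ("even") tag into one integer, with 0 reserved for padding. The packing itself happens in the dataset/tagger code, not here, so the snippet below only illustrates the arithmetic implied by the two floor-divide/modulo lines (hypothetical helper names):

```python
# Mirror of the decoding arithmetic in calc_loss_helper; pack() is the inverse
# implied by it, not a function defined in this repository.
def pack(odd_id, even_id, num_even_tags):
    return (odd_id + 1) * (num_even_tags + 1) + (even_id + 1)

def unpack(label, num_even_tags):
    odd_id = label // (num_even_tags + 1) - 1
    even_id = label % (num_even_tags + 1) - 1
    return odd_id, even_id

assert unpack(pack(3, 7, num_even_tags=10), num_even_tags=10) == (3, 7)
# Padding (label 0) decodes to (-1, -1), which cross_entropy skips via ignore_index=-1;
# a last word carrying only an even tag likewise gets -1 for its odd part.
assert unpack(0, num_even_tags=10) == (-1, -1)
```
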
/learning/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 |
5 | from typing import Optional, Tuple, Any, Dict, Iterable, Union
6 | import math
7 |
8 |
9 | class TakeLSTMOutput(nn.Module):
10 |     # Keep only the output sequence from the (output, (h_n, c_n)) tuple returned by nn.LSTM
11 | def forward(self, x):
12 | tensor, _ = x
13 | return tensor
14 |
15 |
16 | def hexa_loss(logits, labels, attention_mask, num_even_tags, num_odd_tags):
17 | # shape: (batch_size, seq_len, num_tags) -> (batch_size, num_tags, seq_len)
18 | logits = torch.movedim(logits, -1, 1)
19 | odd_logits, even_logits = torch.split(logits, [num_odd_tags, num_even_tags], dim=1)
20 |
21 | odd_labels = (labels // (num_even_tags + 1)) - 1
22 | even_labels = (labels % (num_even_tags + 1)) - 1
23 |     # The last word will have only an even label
24 |
25 | # Only keep active parts of the loss
26 | active_even_labels = torch.where(attention_mask, even_labels, -1)
27 | active_odd_labels = torch.where(attention_mask, odd_labels, -1)
28 | loss = (F.cross_entropy(even_logits, active_even_labels, ignore_index=-1)
29 | + F.cross_entropy(odd_logits, active_odd_labels, ignore_index=-1))
30 |
31 | ground_truth_likelihood = 0.
32 |
33 | return loss, ground_truth_likelihood
34 |
35 |
36 | def calc_loss_helper(logits, labels, attention_mask):
37 | # Only keep active parts of the loss
38 |
39 | # logits: shape: (batch_size, *, num_tags)
40 | # active_logits: shape: (batch_size, num_tags, *)
41 | active_logits = torch.movedim(logits, -1, 1)
42 | if active_logits.dim() == 4: # for permute logits
43 | attention_mask = attention_mask.unsqueeze(2)
44 |
45 | # shape: (batch_size, seq_len, ...)
46 | active_labels = torch.where(
47 | (attention_mask == 1), labels, -1
48 | )
49 | loss = F.cross_entropy(active_logits, active_labels, ignore_index=-1)
50 | ground_truth_likelihood = 0.
51 |
52 | return loss, ground_truth_likelihood
53 |
54 |
55 | def interval_sort(
56 | values, # shape: (batch_size, seq_len)
57 | gt_values, # shape: (batch_size, seq_len)
58 | mask, # shape: (batch_size, seq_len)
59 | intervals # shape: (batch_size, num_intervals, 4)
60 | ):
61 | # Args:
62 | # value: (batch_size, seq_len)
63 | # intervals: (batch_size, num_intervals, 4)
64 | # tuples of (start, end, split, self) intervals. end is inclusive.
65 | # Returns:
66 | # interval_stats: (batch_size, num_intervals, 3)
67 | # tuples of (min, median, max) values in the interval
68 |
69 | # shape: (batch_size, seq_len)
70 | batch_size, seq_len = values.size()
71 | num_intervals = intervals.size(1)
72 |
73 | _, sorted_indices = torch.where(
74 | mask, gt_values, float('inf') # ascending
75 | ).sort(dim=1)
76 | sorted_values = values.gather(dim=1, index=sorted_indices)
77 |
78 | range_vec = torch.arange( # shape: (1, seq_len, 1)
79 | 0, values.size(1), device=intervals.device, dtype=torch.long
80 | )[None, :, None]
81 | # shape: (batch_size, num_intervals)
82 | relative_indices = (intervals.select(dim=-1, index=3) -
83 | intervals.select(dim=-1, index=0))
84 | # shape: (batch_size, seq_len, num_intervals)
85 | in_interval = ((range_vec >= intervals[:, None, :, 0]) &
86 | (range_vec <= intervals[:, None, :, 1]))
87 | # shape: (batch_size, seq_len, num_intervals)
88 | projected_indices = torch.where(
89 | in_interval, sorted_indices[:, :, None], seq_len - 1
90 | ).sort(dim=1)[0]
91 |
92 | # shape (batch_size, num_intervals)
93 | projected_indices = projected_indices.take_along_dim(
94 | relative_indices[:, None, :], dim=1
95 | ).squeeze(1)
96 |
97 | # shape: (batch_size, num_intervals)
98 | # projected_values = values.take_along_dim(projected_indices, dim=1)
99 | projected_values = sorted_values.take_along_dim(projected_indices, dim=1)
100 |
101 | # shape: (batch_size, num_intervals)
102 | split_rank = intervals.select(dim=-1, index=2).long()
103 |
104 | # shape: (batch_size, num_intervals), avoid out of range
105 | safe_split_rank = torch.where(
106 | split_rank == 0, 1, split_rank
107 | )
108 | # shape: (batch_size, num_intervals)
109 | split_thresholds = (sorted_values.gather(dim=1, index=safe_split_rank - 1) +
110 | sorted_values.gather(dim=1, index=safe_split_rank)) / 2.0
111 | # shape: (batch_size, num_intervals)
112 | # > 0: go right, < 0: go left
113 | split_logits = projected_values - split_thresholds
114 |
115 | return split_logits
116 |
117 |
118 | def logsoftperm(
119 | input: torch.Tensor, # shape: (*, num_elements)
120 | perm: torch.Tensor, # shape: (*, num_elements)
121 | mask: Optional[torch.BoolTensor] = None # shape: (*, num_elements)
122 | ):
123 | # Args:
124 | # input: (*, num_elements)
125 | # perm: (*, num_elements)
126 | # Returns:
127 | # output: (*, num_elements)
128 | max_value = input.max().detach() + 1.0
129 | if mask is not None:
130 | input = input.masked_fill(~mask, max_value)
131 | perm = perm.masked_fill(~mask, max_value)
132 |
133 | # shape: (*, num_elements)
134 | sorted_input, _ = input.sort(dim=-1)
135 | # shape: (*, num_elements, num_elements)
136 | logits_matrix = -torch.abs(perm[:, None, :] - sorted_input[:, :, None])
137 | # shape: (*, num_elements)
138 | log_likelihood_perm = F.log_softmax(logits_matrix, dim=-1).diagonal(dim1=-2, dim2=-1)
139 |
140 | if mask is not None:
141 | log_likelihood_perm = log_likelihood_perm.masked_fill(~mask, 0.0)
142 |
143 | return log_likelihood_perm
144 |
145 |
146 | def _batched_index_select(
147 | target: torch.Tensor,
148 | indices: torch.LongTensor,
149 | flattened_indices: Optional[torch.LongTensor] = None,
150 | ) -> torch.Tensor:
151 | # Args:
152 | # target: (..., seq_len, ...)
153 | # indices: (..., num_indices)
154 | # Returns:
155 | # selected_targets (..., num_indices, ...)
156 |
157 | # dim is the index of the last dimension of indices
158 | dim = indices.dim() - 1
159 | unidim = False
160 | if target.dim() == 2:
161 | # (batch_size, sequence_length) -> (batch_size, sequence_length, 1)
162 | unidim = True
163 | target = target.unsqueeze(-1)
164 |
165 | target_size, indices_size = target.size(), indices.size()
166 | # flatten dimensions before dim, make a pseudo batch dimension
167 | indices = indices.view(math.prod([*indices_size[:dim]]), indices_size[dim])
168 |
169 | if flattened_indices is None:
170 | # Shape: (batch_size * d_1 * ... * d_n)
171 | flattened_indices = flatten_and_batch_shift_indices(
172 | indices, target_size[dim]
173 | )
174 |
175 | # Shape: (batch_size * sequence_length, embedding_size)
176 | flattened_target = target.reshape(-1, *target_size[dim + 1:])
177 | # Shape: (batch_size * d_1 * ... * d_n, embedding_size)
178 | flattened_selected = flattened_target.index_select(0, flattened_indices)
179 | selected_shape = list(indices_size) + ([] if unidim else list(target_size[dim + 1:]))
180 | # Shape: (batch_size, d_1, ..., d_n, embedding_size)
181 | selected_targets = flattened_selected.reshape(*selected_shape)
182 | return selected_targets
183 |
184 |
185 | def flatten_and_batch_shift_indices(
186 | indices: torch.Tensor, sequence_length: int
187 | ) -> torch.Tensor:
188 | # Input:
189 | # indices: (batch_size, num_indices)
190 |     #   sequence_length: int, size of the dimension being indexed into
191 | # Returns:
192 | # offset_indices Shape: (batch_size*d_1*d_2*...*d_n)
193 | offsets = torch.arange(0, indices.size(0), dtype=torch.long,
194 | device=indices.device) * sequence_length
195 |
196 | for _ in range(indices.dim() - 1):
197 | offsets = offsets.unsqueeze(1)
198 |
199 | # Shape: (batch_size, d_1, ..., d_n)
200 | offset_indices = indices + offsets
201 | # Shape: (batch_size * d_1 * ... * d_n)
202 | offset_indices = offset_indices.view(-1)
203 |
204 | return offset_indices
205 |
206 |
207 | def onehot_with_ignore_label(labels, num_class, ignore_label):
208 | # One-hot encode the modified labels
209 | one_hot_labels = torch.nn.functional.one_hot(
210 | labels.masked_fill((labels == ignore_label), num_class),
211 | num_classes=num_class + 1
212 | )
213 | # Remove the last row in the one-hot encoding
214 | # shape: (*, num_class+1 -> num_class)
215 | one_hot_labels = one_hot_labels[..., :-1]
216 | return one_hot_labels
217 |
--------------------------------------------------------------------------------
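
Two of the helpers above are easiest to understand from a toy call. `_batched_index_select` flattens the batch and sequence dimensions, offsets each batch's indices via `flatten_and_batch_shift_indices`, and gathers everything with a single `index_select`; `onehot_with_ignore_label` yields an all-zero row for ignored positions. A small usage sketch (assuming the repository root is on `PYTHONPATH`; the tensors are made up):

```python
import torch
from learning.util import _batched_index_select, onehot_with_ignore_label

# (batch=2, seq_len=5, dim=3): select positions [0, 4] from the first example
# and [2, 2] from the second in one call.
target = torch.arange(2 * 5 * 3, dtype=torch.float).view(2, 5, 3)
indices = torch.tensor([[0, 4], [2, 2]])
selected = _batched_index_select(target, indices)  # shape: (2, 2, 3)
assert torch.equal(selected[0, 1], target[0, 4])
assert torch.equal(selected[1, 0], target[1, 2])

# labels equal to the ignore label become an all-zero one-hot row
labels = torch.tensor([0, 2, -1])
one_hot = onehot_with_ignore_label(labels, num_class=3, ignore_label=-1)
assert one_hot.tolist() == [[1, 0, 0], [0, 0, 1], [0, 0, 0]]
```
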
/load_dataset.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Based on https://github.com/nikitakit/tetra-tagging/blob/master/examples/training.ipynb
3 | 
4 | if [ ! -e self-attentive-parser ]; then
5 | git clone https://github.com/nikitakit/self-attentive-parser &> /dev/null
6 | fi
7 | rm -rf train dev test EVALB/
8 | cp self-attentive-parser/data/02-21.10way.clean ./data/train
9 | cp self-attentive-parser/data/22.auto.clean ./data/dev
10 | cp self-attentive-parser/data/23.auto.clean ./data/test
11 | # The evalb program needs to be compiled
12 | cp -R self-attentive-parser/EVALB EVALB
13 | rm -rf self-attentive-parser
14 | cd EVALB && make &> /dev/null
15 | # To test that everything works as intended, we check that the F1 score when
16 | # comparing the dev set with itself is 100.
17 | ./evalb -p nk.prm ../data/dev ../data/dev | grep FMeasure | head -n 1
--------------------------------------------------------------------------------
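
A hypothetical invocation, assuming the script is run from the repository root with `git`, `make`, and a C compiler available (EVALB has to compile before the final self-check line can run):

```bash
bash load_dataset.sh
# afterwards data/train, data/dev, data/test and a compiled EVALB/evalb should exist;
# the script's last line prints an FMeasure of 100 when the dev set is scored against itself
ls data/train data/dev data/test EVALB/evalb
```
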
/pat-models/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | appnope==0.1.3
3 | argon2-cffi==21.3.0
4 | argon2-cffi-bindings==21.2.0
5 | asttokens==2.0.5
6 | attrs==21.4.0
7 | backcall==0.2.0
8 | beautifulsoup4==4.10.0
9 | bleach==4.1.0
10 | cachetools==5.0.0
11 | certifi==2021.10.8
12 | cffi==1.15.0
13 | charset-normalizer==2.0.12
14 | click==8.1.2
15 | debugpy==1.6.0
16 | decorator==5.1.1
17 | defusedxml==0.7.1
18 | entrypoints==0.4
19 | executing==0.8.3
20 | fastjsonschema==2.15.3
21 | filelock==3.6.0
22 | google-auth==2.6.2
23 | google-auth-oauthlib==0.4.6
24 | grpcio==1.44.0
25 | huggingface-hub==0.5.1
26 | idna==3.3
27 | importlib-metadata==4.11.3
28 | importlib-resources==5.6.0
29 | ipykernel==6.12.1
30 | ipython==8.2.0
31 | ipython-genutils==0.2.0
32 | ipywidgets==7.7.0
33 | jedi==0.18.1
34 | Jinja2==3.1.1
35 | joblib==1.1.1
36 | jsonschema==4.4.0
37 | jupyter==1.0.0
38 | jupyter-client==7.2.1
39 | jupyter-console==6.4.3
40 | jupyter-core==4.9.2
41 | jupyterlab-pygments==0.1.2
42 | jupyterlab-widgets==1.1.0
43 | Markdown==3.3.6
44 | MarkupSafe==2.1.1
45 | matplotlib-inline==0.1.3
46 | mistune==0.8.4
47 | nbclient==0.5.13
48 | nbconvert==6.4.5
49 | nbformat==5.3.0
50 | nest-asyncio==1.5.5
51 | nltk==3.7
52 | notebook==6.4.10
53 | numpy==1.22.3
54 | oauthlib==3.2.0
55 | packaging==21.3
56 | pandas==1.4.2
57 | pandocfilters==1.5.0
58 | parso==0.8.3
59 | pexpect==4.8.0
60 | pickleshare==0.7.5
61 | plotly==5.7.0
62 | prometheus-client==0.14.0
63 | prompt-toolkit==3.0.29
64 | protobuf==3.20.0
65 | psutil==5.9.0
66 | ptyprocess==0.7.0
67 | pure-eval==0.2.2
68 | pyasn1==0.4.8
69 | pyasn1-modules==0.2.8
70 | pycparser==2.21
71 | Pygments==2.11.2
72 | pyparsing==3.0.7
73 | pyrsistent==0.18.1
74 | python-dateutil==2.8.2
75 | pytz==2022.1
76 | PyYAML==6.0
77 | pyzmq==22.3.0
78 | qtconsole==5.3.0
79 | QtPy==2.0.1
80 | regex==2022.3.15
81 | requests==2.27.1
82 | requests-oauthlib==1.3.1
83 | rsa==4.8
84 | sacremoses==0.0.49
85 | scipy==1.8.1
86 | Send2Trash==1.8.0
87 | six==1.16.0
88 | soupsieve==2.3.1
89 | spicy==0.16.0
90 | stack-data==0.2.0
91 | tenacity==8.0.1
92 | tensorboard==2.8.0
93 | tensorboard-data-server==0.6.1
94 | tensorboard-plugin-wit==1.8.1
95 | terminado==0.13.3
96 | testpath==0.6.0
97 | tokenizers==0.11.6
98 | torch==1.11.0
99 | tornado==6.1
100 | tqdm==4.64.0
101 | traitlets==5.1.1
102 | transformers==4.18.0
103 | typing_extensions==4.1.1
104 | urllib3==1.26.9
105 | wcwidth==0.2.5
106 | webencodings==0.5.1
107 | Werkzeug==2.1.1
108 | widgetsnbextension==3.6.0
109 | zipp==3.8.0
110 | bitsandbytes==0.39.0
111 | scikit-learn==1.2.2
112 |
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import logging
4 | import pickle
5 | import random
6 | import sys
7 | import json
8 |
9 | import numpy as np
10 | import torch
11 | import transformers
12 | from bitsandbytes.optim import AdamW
13 | from nltk.corpus.reader.bracket_parse import BracketParseCorpusReader
14 | from torch.utils.data import DataLoader
15 | from torch.utils.tensorboard import SummaryWriter
16 | from tqdm import tqdm as tq
17 |
18 | from const import *
19 | from learning.dataset import TaggingDataset
20 | from learning.evaluate import predict, dependency_eval, calc_parse_eval, calc_tag_accuracy, dependency_decoding
21 | from learning.learn import ModelForTetratagging
22 | from tagging.hexatagger import HexaTagger
23 |
24 | # Set random seed
25 | RANDOM_SEED = 1
26 | torch.manual_seed(RANDOM_SEED)
27 | random.seed(RANDOM_SEED)
28 | np.random.seed(RANDOM_SEED)
29 | print('Random seed: {}'.format(RANDOM_SEED), file=sys.stderr)
30 |
31 | logging.basicConfig(
32 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
33 | datefmt='%m/%d/%Y %H:%M:%S',
34 | level=logging.INFO
35 | )
36 | logger = logging.getLogger(__file__)
37 |
38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39 |
40 | parser = argparse.ArgumentParser()
41 | subparser = parser.add_subparsers(dest='command')
42 | train = subparser.add_parser('train')
43 | evaluate = subparser.add_parser('evaluate')
44 | predict_parser = subparser.add_parser('predict')
45 | vocab = subparser.add_parser('vocab')
46 |
47 | vocab.add_argument('--tagger', choices=[HEXATAGGER, TETRATAGGER, TD_SR, BU_SR], required=True,
48 | help="Tagging schema")
49 | vocab.add_argument('--lang', choices=LANG, default=ENG, help="Language")
50 | vocab.add_argument('--output-path', type=str,
51 |                    default="data/vocab/")
52 |
53 | train.add_argument('--tagger', choices=[HEXATAGGER, TETRATAGGER, TD_SR, BU_SR], required=True,
54 | help="Tagging schema")
55 | train.add_argument('--lang', choices=LANG, default=ENG, help="Language")
56 | train.add_argument('--tag-vocab-path', type=str, default="data/vocab/")
57 | train.add_argument('--model', choices=BERT, required=True, help="Model architecture")
58 |
59 | train.add_argument('--model-path', type=str, default='bertlarge',
60 |                    help="Pretrained model path or name: "
61 |                         "xlnet-large-cased for English, hfl/chinese-xlnet-mid for Chinese")
62 | train.add_argument('--output-path', type=str, default='pat-models/',
63 | help="Path to save trained models")
64 | train.add_argument('--use-tensorboard', type=bool, default=False,
65 |                    help="Whether to log results to TensorBoard; make sure to "
66 |                         "add credentials to run.py if set to true")
67 |
68 | train.add_argument('--max-depth', type=int, default=10,
69 | help="Max stack depth used for decoding")
70 | train.add_argument('--keep-per-depth', type=int, default=1,
71 | help="Max elements to keep per depth")
72 |
73 | train.add_argument('--lr', type=float, default=2e-5)
74 | train.add_argument('--epochs', type=int, default=50)
75 | train.add_argument('--batch-size', type=int, default=32)
76 | train.add_argument('--num-warmup-steps', type=int, default=200)
77 | train.add_argument('--weight-decay', type=float, default=0.01)
78 |
79 | evaluate.add_argument('--model-name', type=str, required=True)
80 | evaluate.add_argument('--lang', choices=LANG, default=ENG, help="Language")
81 | evaluate.add_argument('--tagger', choices=[HEXATAGGER, TETRATAGGER, TD_SR, BU_SR],
82 | required=True,
83 | help="Tagging schema")
84 | evaluate.add_argument('--tag-vocab-path', type=str, default="data/vocab/")
85 | evaluate.add_argument('--model-path', type=str, default='pat-models/')
86 | evaluate.add_argument('--bert-model-path', type=str, default='mbert/')
87 | evaluate.add_argument('--output-path', type=str, default='results/')
88 | evaluate.add_argument('--batch-size', type=int, default=16)
89 | evaluate.add_argument('--max-depth', type=int, default=10,
90 | help="Max stack depth used for decoding")
91 | evaluate.add_argument('--is-greedy', type=bool, default=False,
92 | help="Whether or not to use greedy decoding")
93 | evaluate.add_argument('--keep-per-depth', type=int, default=1,
94 | help="Max elements to keep per depth")
95 | evaluate.add_argument('--use-tensorboard', type=bool, default=False,
96 |                       help="Whether to log results to TensorBoard; make sure "
97 |                            "to add credentials to run.py if set to true")
98 |
99 |
100 | predict_parser.add_argument('--model-name', type=str, required=True)
101 | predict_parser.add_argument('--lang', choices=LANG, default=ENG, help="Language")
102 | predict_parser.add_argument('--tagger', choices=[HEXATAGGER, TETRATAGGER, TD_SR, BU_SR],
103 | required=True,
104 | help="Tagging schema")
105 | predict_parser.add_argument('--tag-vocab-path', type=str, default="data/vocab/")
106 | predict_parser.add_argument('--model-path', type=str, default='pat-models/')
107 | predict_parser.add_argument('--bert-model-path', type=str, default='mbert/')
108 | predict_parser.add_argument('--output-path', type=str, default='results/')
109 | predict_parser.add_argument('--batch-size', type=int, default=16)
110 | predict_parser.add_argument('--max-depth', type=int, default=10,
111 | help="Max stack depth used for decoding")
112 | predict_parser.add_argument('--is-greedy', type=bool, default=False,
113 | help="Whether or not to use greedy decoding")
114 | predict_parser.add_argument('--keep-per-depth', type=int, default=1,
115 | help="Max elements to keep per depth")
116 | predict_parser.add_argument('--use-tensorboard', type=bool, default=False,
117 |                             help="Whether to log results to TensorBoard; make sure "
118 |                             "to add credentials to run.py if set to true")
119 |
120 |
121 | def initialize_tag_system(reader, tagging_schema, lang, tag_vocab_path="",
122 | add_remove_top=False):
123 | tag_vocab = None
124 | if tag_vocab_path != "":
125 | with open(tag_vocab_path + lang + "-" + tagging_schema + '.pkl', 'rb') as f:
126 | tag_vocab = pickle.load(f)
127 |
128 | if tagging_schema == HEXATAGGER:
129 | # for BHT
130 | tag_system = HexaTagger(
131 | trees=reader.parsed_sents(lang + '.bht.train'),
132 | tag_vocab=tag_vocab, add_remove_top=False
133 | )
134 | else:
135 | logging.error("Please specify the tagging schema")
136 | return
137 | return tag_system
138 |
139 |
140 | def get_data_path(tagger):
141 | if tagger == HEXATAGGER:
142 | return DEP_DATA_PATH
143 | return DATA_PATH
144 |
145 |
146 | def save_vocab(args):
147 | data_path = get_data_path(args.tagger)
148 | if args.tagger == HEXATAGGER:
149 | prefix = args.lang + ".bht"
150 | else:
151 | prefix = args.lang
152 | reader = BracketParseCorpusReader(
153 | data_path, [prefix + '.train', prefix + '.dev', prefix + '.test'])
154 | tag_system = initialize_tag_system(
155 | reader, args.tagger, args.lang, add_remove_top=True)
156 | with open(args.output_path + args.lang + "-" + args.tagger + '.pkl', 'wb+') as f:
157 | pickle.dump(tag_system.tag_vocab, f)
158 |
159 |
160 | def prepare_training_data(reader, tag_system, tagging_schema, model_name, batch_size, lang):
161 | is_tetratags = True if tagging_schema == TETRATAGGER or tagging_schema == HEXATAGGER else False
162 | prefix = lang + ".bht" if tagging_schema == HEXATAGGER else lang
163 |
164 | tokenizer = transformers.AutoTokenizer.from_pretrained(
165 | model_name, truncation=True, use_fast=True)
166 | train_dataset = TaggingDataset(prefix + '.train', tokenizer, tag_system, reader, device,
167 | is_tetratags=is_tetratags, language=lang)
168 | eval_dataset = TaggingDataset(prefix + '.test', tokenizer, tag_system, reader, device,
169 | is_tetratags=is_tetratags, language=lang)
170 | train_dataloader = DataLoader(
171 | train_dataset, shuffle=True, batch_size=batch_size, collate_fn=train_dataset.collate,
172 | pin_memory=True
173 | )
174 | eval_dataloader = DataLoader(
175 | eval_dataset, batch_size=batch_size, collate_fn=eval_dataset.collate, pin_memory=True
176 | )
177 | return train_dataset, eval_dataset, train_dataloader, eval_dataloader
178 |
179 |
180 | def prepare_test_data(reader, tag_system, tagging_schema, model_name, batch_size, lang):
181 | is_tetratags = True if tagging_schema == TETRATAGGER or tagging_schema == HEXATAGGER else False
182 | prefix = lang + ".bht" if tagging_schema == HEXATAGGER else lang
183 |
184 | print(f"Evaluating {model_name}, {tagging_schema}")
185 | tokenizer = transformers.AutoTokenizer.from_pretrained(
186 | model_name, truncation=True, use_fast=True)
187 | test_dataset = TaggingDataset(
188 | prefix + '.test', tokenizer, tag_system, reader, device,
189 | is_tetratags=is_tetratags, language=lang
190 | )
191 | test_dataloader = DataLoader(
192 | test_dataset, batch_size=batch_size, collate_fn=test_dataset.collate
193 | )
194 | return test_dataset, test_dataloader
195 |
196 |
197 | def generate_config(model_type, tagging_schema, tag_system, model_path, is_eng):
198 | if model_type in BERTCRF or model_type in BERTLSTM:
199 | config = transformers.AutoConfig.from_pretrained(
200 | model_path,
201 | num_labels=2 * len(tag_system.tag_vocab),
202 | task_specific_params={
203 | 'model_path': model_path,
204 | 'num_tags': len(tag_system.tag_vocab),
205 | 'is_eng': is_eng,
206 | }
207 | )
208 | elif model_type in BERT and tagging_schema in [TETRATAGGER, HEXATAGGER]:
209 | config = transformers.AutoConfig.from_pretrained(
210 | model_path,
211 | num_labels=len(tag_system.tag_vocab),
212 | id2label={i: label for i, label in enumerate(tag_system.tag_vocab)},
213 | label2id={label: i for i, label in enumerate(tag_system.tag_vocab)},
214 | task_specific_params={
215 | 'model_path': model_path,
216 | 'num_even_tags': tag_system.decode_moderator.leaf_tag_vocab_size,
217 | 'num_odd_tags': tag_system.decode_moderator.internal_tag_vocab_size,
218 | 'pos_emb_dim': 256,
219 | 'num_pos_tags': 50,
220 | 'lstm_layers': 3,
221 | 'dropout': 0.33,
222 | 'is_eng': is_eng,
223 | 'use_pos': True
224 | }
225 | )
226 | elif model_type in BERT and tagging_schema != TETRATAGGER and tagging_schema != HEXATAGGER:
227 | config = transformers.AutoConfig.from_pretrained(
228 | model_path,
229 | num_labels=2 * len(tag_system.tag_vocab),
230 | task_specific_params={
231 | 'model_path': model_path,
232 | 'num_even_tags': len(tag_system.tag_vocab),
233 | 'num_odd_tags': len(tag_system.tag_vocab),
234 | 'is_eng': is_eng
235 | }
236 | )
237 | else:
238 | logging.error("Invalid combination of model type and tagging schema")
239 | return
240 | return config
241 |
242 |
243 | def initialize_model(model_type, tagging_schema, tag_system, model_path, is_eng):
244 | config = generate_config(
245 | model_type, tagging_schema, tag_system, model_path, is_eng
246 | )
247 | if model_type in BERT:
248 | model = ModelForTetratagging(config=config)
249 | else:
250 | logging.error("Invalid model type")
251 | return
252 | return model
253 |
254 |
255 | def initialize_optimizer_and_scheduler(model, dataset_size, lr=5e-5, num_epochs=4,
256 | num_warmup_steps=160, weight_decay_rate=0.0):
257 | num_training_steps = num_epochs * dataset_size
258 | no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
259 | grouped_parameters = [
260 | {
261 | "params": [p for n, p in model.named_parameters() if "bert" not in n],
262 | "weight_decay": 0.0,
263 | "lr": lr * 50, "betas": (0.9, 0.9),
264 | },
265 | {
266 | "params": [p for n, p in model.named_parameters() if
267 | "bert" in n and any(nd in n for nd in no_decay)],
268 | "weight_decay": 0.0,
269 | "lr": lr, "betas": (0.9, 0.999),
270 | },
271 | {
272 | "params": [p for n, p in model.named_parameters() if
273 | "bert" in n and not any(nd in n for nd in no_decay)],
274 | "weight_decay": 0.1,
275 | "lr": lr, "betas": (0.9, 0.999),
276 | },
277 | ]
278 |
279 | optimizer = AdamW(
280 | grouped_parameters, lr=lr
281 | )
282 | scheduler = transformers.get_linear_schedule_with_warmup(
283 | optimizer=optimizer,
284 | num_warmup_steps=num_warmup_steps,
285 | num_training_steps=num_training_steps
286 | )
287 |
288 | return optimizer, scheduler, num_training_steps
289 |
290 |
291 | def register_run_metrics(writer, run_name, lr, epochs, eval_loss, even_tag_accuracy,
292 | odd_tag_accuracy):
293 | writer.add_hparams({'run_name': run_name, 'lr': lr, 'epochs': epochs},
294 | {'eval_loss': eval_loss, 'odd_tag_accuracy': odd_tag_accuracy,
295 | 'even_tag_accuracy': even_tag_accuracy})
296 |
297 |
298 | def train_command(args):
299 | if args.tagger == HEXATAGGER:
300 | prefix = args.lang + ".bht"
301 | else:
302 | prefix = args.lang
303 | data_path = get_data_path(args.tagger)
304 | reader = BracketParseCorpusReader(data_path, [prefix + '.train', prefix + '.dev',
305 | prefix + '.test'])
306 | logging.info("Initializing Tag System")
307 | tag_system = initialize_tag_system(
308 | reader, args.tagger, args.lang,
309 | tag_vocab_path=args.tag_vocab_path, add_remove_top=True
310 | )
311 | logging.info("Preparing Data")
312 | train_dataset, eval_dataset, train_dataloader, eval_dataloader = prepare_training_data(
313 | reader, tag_system, args.tagger, args.model_path, args.batch_size, args.lang)
314 | logging.info("Initializing The Model")
315 | is_eng = True if args.lang == ENG else False
316 | model = initialize_model(
317 | args.model, args.tagger, tag_system, args.model_path, is_eng
318 | )
319 | model.to(device)
320 |
321 | train_set_size = len(train_dataloader)
322 | optimizer, scheduler, num_training_steps = initialize_optimizer_and_scheduler(
323 | model, train_set_size, args.lr, args.epochs,
324 | args.num_warmup_steps, args.weight_decay
325 | )
326 | optimizer.zero_grad()
327 | run_name = args.lang + "-" + args.tagger + "-" + args.model + "-" + str(
328 | args.lr) + "-" + str(args.epochs)
329 | writer = None
330 | if args.use_tensorboard:
331 | writer = SummaryWriter(comment=run_name)
332 |
333 | num_leaf_labels, num_tags = calc_num_tags_per_task(args.tagger, tag_system)
334 |
335 | logging.info("Starting The Training Loop")
336 | model.train()
337 | n_iter = 0
338 |
339 | last_fscore = 0
340 | best_fscore = 0
341 | tol = 99999
342 |
343 | for epo in tq(range(args.epochs)):
344 | logging.info(f"*******************EPOCH {epo}*******************")
345 | t = 1
346 | model.train()
347 |
348 | with tq(train_dataloader, disable=False) as progbar:
349 | for batch in progbar:
350 | batch = {k: v.to(device) for k, v in batch.items()}
351 |
352 | if device == "cuda":
353 | with torch.cuda.amp.autocast(
354 | enabled=True, dtype=torch.bfloat16
355 | ):
356 | outputs = model(**batch)
357 | else:
358 | with torch.cpu.amp.autocast(
359 | enabled=True, dtype=torch.bfloat16
360 | ):
361 | outputs = model(**batch)
362 |
363 | loss = outputs[0]
364 | loss.mean().backward()
365 | if args.use_tensorboard:
366 | writer.add_scalar('Loss/train', torch.mean(loss), n_iter)
367 | progbar.set_postfix(loss=torch.mean(loss).item())
368 |
369 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
370 | optimizer.step()
371 | scheduler.step()
372 | optimizer.zero_grad()
373 |
374 | n_iter += 1
375 | t += 1
376 |
377 | if True: # evaluation at the end of epoch
378 | predictions, eval_labels = predict(
379 | model, eval_dataloader, len(eval_dataset),
380 | num_tags, args.batch_size, device
381 | )
382 | calc_tag_accuracy(
383 | predictions, eval_labels,
384 | num_leaf_labels, writer, args.use_tensorboard)
385 |
386 | if args.tagger == HEXATAGGER:
387 | dev_metrics_las, dev_metrics_uas = dependency_eval(
388 | predictions, eval_labels, eval_dataset,
389 | tag_system, None, "", args.max_depth,
390 | args.keep_per_depth, False)
391 | else:
392 | dev_metrics = calc_parse_eval(
393 | predictions, eval_labels, eval_dataset,
394 | tag_system, None, "",
395 | args.max_depth, args.keep_per_depth, False)
396 |
397 | eval_loss = 0.5
398 | if args.tagger == HEXATAGGER:
399 | writer.add_scalar('LAS_Fscore/dev',
400 | dev_metrics_las.fscore, n_iter)
401 | writer.add_scalar('LAS_Precision/dev',
402 | dev_metrics_las.precision, n_iter)
403 | writer.add_scalar('LAS_Recall/dev',
404 | dev_metrics_las.recall, n_iter)
405 | writer.add_scalar('loss/dev', eval_loss, n_iter)
406 |
407 | logging.info("current LAS {}".format(dev_metrics_las))
408 | logging.info("current UAS {}".format(dev_metrics_uas))
409 | logging.info("last LAS fscore {}".format(last_fscore))
410 | logging.info("best LAS fscore {}".format(best_fscore))
411 | # setting main metric for model selection
412 | dev_metrics = dev_metrics_las
413 | else:
414 | writer.add_scalar('Fscore/dev', dev_metrics.fscore, n_iter)
415 | writer.add_scalar('Precision/dev', dev_metrics.precision, n_iter)
416 | writer.add_scalar('Recall/dev', dev_metrics.recall, n_iter)
417 | writer.add_scalar('loss/dev', eval_loss, n_iter)
418 |
419 | logging.info("current fscore {}".format(dev_metrics.fscore))
420 | logging.info("last fscore {}".format(last_fscore))
421 | logging.info("best fscore {}".format(best_fscore))
422 |
423 | # if dev_metrics.fscore > last_fscore or dev_loss < last...
424 | if dev_metrics.fscore > last_fscore:
425 | tol = 5
426 | if dev_metrics.fscore > best_fscore: # if dev_metrics.fscore > best_fscore:
427 | logging.info("save the best model")
428 | best_fscore = dev_metrics.fscore
429 | _save_best_model(model, args.output_path, run_name)
430 | elif dev_metrics.fscore > 0: # dev_metrics.fscore
431 | tol -= 1
432 |
433 | if tol < 0:
434 | _finish_training(model, tag_system, eval_dataloader,
435 | eval_dataset, eval_loss, run_name, writer, args)
436 | return
437 | if dev_metrics.fscore > 0: # not propagating the nan
438 | last_eval_loss = eval_loss
439 | last_fscore = dev_metrics.fscore
440 |
441 | # if dev_metrics.fscore > last_fscore or dev_loss < last...
442 | if dev_metrics.fscore > best_fscore:
443 | tol = 99999
444 | logging.info("tol refill")
445 | logging.info("save the best model")
446 | best_eval_loss = eval_loss
447 | best_fscore = dev_metrics.fscore
448 | _save_best_model(model, args.output_path, run_name)
449 | elif eval_loss > 0:
450 | tol -= 1
451 |
452 | if tol < 0:
453 | _finish_training(model, tag_system, eval_dataloader,
454 | eval_dataset, eval_loss, run_name, writer, args)
455 | return
456 | if eval_loss > 0: # not propagating the nan
457 | last_eval_loss = eval_loss
458 | # end of epoch
459 | pass
460 |
461 | _finish_training(model, tag_system, eval_dataloader, eval_dataset, eval_loss,
462 | run_name, writer, args)
463 |
464 |
465 | def _save_best_model(model, output_path, run_name):
466 | logging.info("Saving The Newly Found Best Model")
467 | os.makedirs(output_path, exist_ok=True)
468 | to_save_file = os.path.join(output_path, run_name)
469 | torch.save(model.state_dict(), to_save_file)
470 |
471 |
472 | def _finish_training(model, tag_system, eval_dataloader, eval_dataset, eval_loss,
473 | run_name, writer, args):
474 | num_leaf_labels, num_tags = calc_num_tags_per_task(args.tagger, tag_system)
475 | predictions, eval_labels = predict(model, eval_dataloader, len(eval_dataset),
476 | num_tags, args.batch_size,
477 | device)
478 | even_acc, odd_acc = calc_tag_accuracy(predictions, eval_labels, num_leaf_labels, writer,
479 | args.use_tensorboard)
480 | register_run_metrics(writer, run_name, args.lr,
481 | args.epochs, eval_loss, even_acc, odd_acc)
482 |
483 |
484 | def decode_model_name(model_name):
485 | name_chunks = model_name.split("-")
486 | name_chunks = name_chunks[1:]
487 | if name_chunks[0] == "td" or name_chunks[0] == "bu":
488 | tagging_schema = name_chunks[0] + "-" + name_chunks[1]
489 | model_type = name_chunks[2]
490 | else:
491 | tagging_schema = name_chunks[0]
492 | model_type = name_chunks[1]
493 | return tagging_schema, model_type
494 |
495 |
496 | def calc_num_tags_per_task(tagging_schema, tag_system):
497 | if tagging_schema == TETRATAGGER or tagging_schema == HEXATAGGER:
498 | num_leaf_labels = tag_system.decode_moderator.leaf_tag_vocab_size
499 | num_tags = len(tag_system.tag_vocab)
500 | else:
501 | num_leaf_labels = len(tag_system.tag_vocab)
502 | num_tags = 2 * len(tag_system.tag_vocab)
503 | return num_leaf_labels, num_tags
504 |
505 |
506 | def evaluate_command(args):
507 | tagging_schema, model_type = decode_model_name(args.model_name)
508 | data_path = get_data_path(tagging_schema) # HexaTagger or others
509 | print("Evaluation Args", args)
510 | if args.tagger == HEXATAGGER:
511 | prefix = args.lang + ".bht"
512 | else:
513 | prefix = args.lang
514 | reader = BracketParseCorpusReader(
515 | data_path,
516 | [prefix + '.train', prefix + '.dev', prefix + '.test'])
517 | writer = SummaryWriter(comment=args.model_name)
518 | logging.info("Initializing Tag System")
519 | tag_system = initialize_tag_system(reader, tagging_schema, args.lang,
520 | tag_vocab_path=args.tag_vocab_path,
521 | add_remove_top=True)
522 | logging.info("Preparing Data")
523 | eval_dataset, eval_dataloader = prepare_test_data(
524 | reader, tag_system, tagging_schema,
525 | args.bert_model_path, args.batch_size,
526 | args.lang)
527 |
528 | is_eng = True if args.lang == ENG else False
529 | model = initialize_model(
530 | model_type, tagging_schema, tag_system, args.bert_model_path, is_eng
531 | )
532 | model.load_state_dict(torch.load(args.model_path + args.model_name))
533 | model.to(device)
534 |
535 | num_leaf_labels, num_tags = calc_num_tags_per_task(tagging_schema, tag_system)
536 |
537 | predictions, eval_labels = predict(
538 | model, eval_dataloader, len(eval_dataset),
539 | num_tags, args.batch_size, device)
540 | calc_tag_accuracy(predictions, eval_labels,
541 | num_leaf_labels, writer, args.use_tensorboard)
542 | if tagging_schema == HEXATAGGER:
543 | dev_metrics_las, dev_metrics_uas = dependency_eval(
544 | predictions, eval_labels, eval_dataset,
545 | tag_system, args.output_path, args.model_name,
546 | args.max_depth, args.keep_per_depth, False)
547 | print(
548 | "LAS: ", dev_metrics_las, "\n",
549 | "UAS: ", dev_metrics_uas, sep=""
550 | )
551 | else:
552 | parse_metrics = calc_parse_eval(predictions, eval_labels, eval_dataset, tag_system,
553 | args.output_path,
554 | args.model_name,
555 | args.max_depth,
556 | args.keep_per_depth,
557 | args.is_greedy)
558 | print(parse_metrics)
559 |
560 |
561 | def predict_command(args):
562 | tagging_schema, model_type = decode_model_name(args.model_name)
563 | data_path = get_data_path(tagging_schema) # HexaTagger or others
564 | print("predict Args", args)
565 |
566 | if args.tagger == HEXATAGGER:
567 | prefix = args.lang + ".bht"
568 | else:
569 | prefix = args.lang
570 | reader = BracketParseCorpusReader(data_path, [])
571 | writer = SummaryWriter(comment=args.model_name)
572 | logging.info("Initializing Tag System")
573 | tag_system = initialize_tag_system(None, tagging_schema, args.lang,
574 | tag_vocab_path=args.tag_vocab_path,
575 | add_remove_top=True)
576 | logging.info("Preparing Data")
577 | eval_dataset, eval_dataloader = prepare_test_data(
578 | reader, tag_system, tagging_schema,
579 | args.bert_model_path, args.batch_size,
580 | "input")
581 |
582 | is_eng = True if args.lang == ENG else False
583 | model = initialize_model(
584 | model_type, tagging_schema, tag_system, args.bert_model_path, is_eng
585 | )
586 | model.load_state_dict(torch.load(args.model_path + args.model_name))
587 | model.to(device)
588 |
589 | num_leaf_labels, num_tags = calc_num_tags_per_task(tagging_schema, tag_system)
590 |
591 | predictions, eval_labels = predict(
592 | model, eval_dataloader, len(eval_dataset),
593 | num_tags, args.batch_size, device)
594 |
595 | if tagging_schema == HEXATAGGER:
596 | pred_output = dependency_decoding(
597 | predictions, eval_labels, eval_dataset,
598 | tag_system, args.output_path, args.model_name,
599 | args.max_depth, args.keep_per_depth, False)
600 |
601 | with open(args.output_path + args.model_name + ".pred.json", "w") as fout:
602 | print("Saving predictions to", args.output_path + args.model_name + ".pred.json")
603 | json.dump(pred_output, fout)
604 |
605 | """pred_output contains: {
606 | "predicted_dev_triples": predicted_dev_triples,
607 | "predicted_hexa_tags": pred_hexa_tags
608 | }"""
609 |
610 |
611 |
612 |
613 | def main():
614 | args = parser.parse_args()
615 | if args.command == 'train':
616 | train_command(args)
617 | elif args.command == 'evaluate':
618 | evaluate_command(args)
619 | elif args.command == 'predict':
620 | predict_command(args)
621 | elif args.command == 'vocab':
622 | save_vocab(args)
623 |
624 |
625 | if __name__ == '__main__':
626 | main()
627 |
--------------------------------------------------------------------------------
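
One detail worth spelling out: `train_command` saves checkpoints under `run_name = "<lang>-<tagger>-<model>-<lr>-<epochs>"`, and `evaluate`/`predict` recover the tagging schema and model type from that filename via `decode_model_name`. A quick sketch of the round trip (assumes the repository root is on `PYTHONPATH` and the pinned dependencies, including `bitsandbytes`, are importable, since importing `run` pulls them in):

```python
from run import decode_model_name

# checkpoint names as produced by train_command's run_name
assert decode_model_name("English-hexa-bert-2e-05-50") == ("hexa", "bert")
# "td-sr"/"bu-sr" schemas span two dash-separated chunks, which decode_model_name handles
assert decode_model_name("English-td-sr-bert-2e-05-50") == ("td-sr", "bert")
```
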
/tagging/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rycolab/parsing-as-tagging/e7a0be2d92dc44c8c5920a33db10dcfab8573dc1/tagging/__init__.py
--------------------------------------------------------------------------------
/tagging/hexatagger.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | from nltk import ParentedTree as PTree
4 | from nltk import Tree
5 |
6 | from const import DUMMY_LABEL
7 | from tagging.tetratagger import BottomUpTetratagger
8 | from tagging.transform import RightCornerTransformer
9 | from tagging.tree_tools import debinarize_lex_tree, expand_unary
10 |
11 |
12 | class HexaTagger(BottomUpTetratagger, ABC):
13 | def preprocess(self, original_tree: Tree) -> PTree:
14 | tree = original_tree.copy(deep=True)
15 | tree.collapse_unary(collapsePOS=True, collapseRoot=True)
16 |
17 | ptree = PTree.convert(tree)
18 | root_label = ptree.label()
19 | tree_lc = PTree(root_label, [])
20 | RightCornerTransformer.transform(tree_lc, ptree, ptree)
21 | return tree_lc
22 |
23 | @staticmethod
24 | def create_shift_tag(label: str, left_or_right: str) -> str:
25 | arc_label = label.split("^^^")[-1]
26 | arc_label = arc_label.split("+")[0]
27 | return left_or_right + "/" + arc_label
28 |
29 | @staticmethod
30 | def _create_bi_reduce_tag(label: str, left_or_right: str) -> str:
31 | label = label.split("\\")[1]
32 | head_idx = label.split("^^^")[-1]
33 | # label = label.split("^^^")[0]
34 | if label.find("|") != -1: # drop extra node labels created after binarization
35 | return f'{left_or_right}' + "/" + f"{DUMMY_LABEL}^^^{head_idx}"
36 | else:
37 | return f'{left_or_right}' + "/" + label.replace("+", "/")
38 |
39 | @staticmethod
40 | def _create_unary_reduce_tag(label: str, left_or_right: str) -> str:
41 | label = label.split("\\")[0]
42 | head_idx = label.split("^^^")[-1]
43 | # label = label.split("^^^")[0]
44 | if label.find("|") != -1: # drop extra node labels created after binarization
45 | return f'{left_or_right}' + f"/{DUMMY_LABEL}^^^{head_idx}"
46 | else:
47 | return f'{left_or_right}' + "/" + label
48 |
49 | def tags_to_tree_pipeline(self, tags: [str], input_seq: []) -> Tree:
50 | ptree = self.tags_to_tree(tags, input_seq)
51 | return self.postprocess(ptree)
52 |
53 | @staticmethod
54 | def _create_pre_terminal_label(tag: str, default="X") -> str:
55 | arc_label = tag.split("/")[1]
56 | return f"X^^^{arc_label}+"
57 |
58 | @staticmethod
59 | def _create_unary_reduce_label(tag: str) -> str:
60 | idx = tag.find("/")
61 | if idx == -1:
62 | return DUMMY_LABEL
63 | return tag[idx + 1:].replace("/", "+")
64 |
65 | @staticmethod
66 | def _create_reduce_label(tag: str) -> str:
67 | idx = tag.find("/")
68 | if idx == -1:
69 |             label = "X\\|" # to mark the second part as an extra node created via binarization
70 | else:
71 | label = "X\\" + tag[idx + 1:].replace("/", "+")
72 | return label
73 |
74 | def postprocess(self, transformed_tree: PTree) -> Tree:
75 | tree = PTree("X", ["", ""])
76 | tree = RightCornerTransformer.rev_transform(tree, transformed_tree)
77 | tree = Tree.convert(tree)
78 | if len(tree.leaves()) == 1:
79 | expand_unary(tree)
80 | # edge case with one node
81 | return tree
82 |
83 | debinarized_tree = Tree(tree.label(), [])
84 | debinarize_lex_tree(tree, debinarized_tree)
85 | expand_unary(debinarized_tree)
86 | # debinarized_tree.pretty_print()
87 | return debinarized_tree
88 |
--------------------------------------------------------------------------------
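
The static helpers above are pure string manipulation on labels whose `^^^` suffix carries the dependency arc label or head index and whose `+` joins collapsed unary chains. A mechanical illustration with hypothetical inputs (real labels come from `data/dep2bht.py` and the right-corner transform; assumes the repository root is on `PYTHONPATH`):

```python
from tagging.hexatagger import HexaTagger

# shift tags keep only the arc label after "^^^"
assert HexaTagger.create_shift_tag("NP^^^nsubj", "l") == "l/nsubj"

# _create_reduce_label and _create_bi_reduce_tag invert each other for ordinary labels
assert HexaTagger._create_reduce_label("R/NP^^^1") == "X\\NP^^^1"
assert HexaTagger._create_bi_reduce_tag("X\\NP^^^1", "R") == "R/NP^^^1"
```
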
/tagging/srtagger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from abc import ABC
3 |
4 | from nltk import ParentedTree as PTree
5 | from tqdm import tqdm as tq
6 |
7 | from learning.decode import BeamSearch, GreedySearch
8 | from tagging.tagger import Tagger, TagDecodeModerator
9 | from tagging.transform import LeftCornerTransformer
10 |
11 | import numpy as np
12 |
13 | from tagging.tree_tools import find_node_type, NodeType
14 |
15 |
16 | class SRTagDecodeModerator(TagDecodeModerator, ABC):
17 | def __init__(self, tag_vocab):
18 | super().__init__(tag_vocab)
19 | self.reduce_tag_size = len([tag for tag in tag_vocab if tag.startswith("r")])
20 | self.shift_tag_size = len([tag for tag in tag_vocab if tag.startswith("s")])
21 |
22 | self.rr_tag_size = len([tag for tag in tag_vocab if tag.startswith("rr")])
23 | self.sr_tag_size = len([tag for tag in tag_vocab if tag.startswith("sr")])
24 |
25 | self.rl_tag_size = self.reduce_tag_size - self.rr_tag_size # left reduce tag size
26 | self.sl_tag_size = self.shift_tag_size - self.sr_tag_size # left shift tag size
27 |
28 | self.mask_binarize = True
29 |
30 | def mask_scores_for_binarization(self, labels, scores) -> []:
31 | raise NotImplementedError
32 |
33 |
34 | class BUSRTagDecodeModerator(SRTagDecodeModerator):
35 | def __init__(self, tag_vocab):
36 | super().__init__(tag_vocab)
37 | stack_depth_change_by_id = [None] * len(tag_vocab)
38 | for i, tag in enumerate(tag_vocab):
39 | if tag.startswith("s"):
40 | stack_depth_change_by_id[i] = +1
41 | elif tag.startswith("r"):
42 | stack_depth_change_by_id[i] = -1
43 | assert None not in stack_depth_change_by_id
44 | self.stack_depth_change_by_id = np.array(
45 | stack_depth_change_by_id, dtype=int)
46 |
47 | self.reduce_only_mask = np.full((len(tag_vocab),), -np.inf)
48 | self.shift_only_mask = np.full((len(tag_vocab),), -np.inf)
49 | self.reduce_only_mask[:self.reduce_tag_size] = 0.0
50 | self.shift_only_mask[-self.shift_tag_size:] = 0.0
51 |
52 | def mask_scores_for_binarization(self, labels, scores) -> []:
53 | # after rr(right) -> only reduce, after rl(left) -> only shift
54 | # after sr(right) -> only reduce, after sl(left) -> only shift
55 | mask1 = np.where(
56 | (labels[:, None] >= self.rl_tag_size) & (labels[:, None] < self.reduce_tag_size),
57 | self.reduce_only_mask, 0.0)
58 | mask2 = np.where(labels[:, None] < self.rl_tag_size, self.shift_only_mask, 0.0)
59 | mask3 = np.where(
60 | labels[:, None] >= (self.sl_tag_size + self.reduce_tag_size),
61 | self.reduce_only_mask, 0.0)
62 | mask4 = np.where((labels[:, None] >= self.reduce_tag_size) & (
63 | labels[:, None] < (self.sl_tag_size + self.reduce_tag_size)),
64 | self.shift_only_mask, 0.0)
65 | all_new_scores = scores + mask1 + mask2 + mask3 + mask4
66 | return all_new_scores
67 |
68 |
69 | class TDSRTagDecodeModerator(SRTagDecodeModerator):
70 | def __init__(self, tag_vocab):
71 | super().__init__(tag_vocab)
72 | is_shift_mask = np.concatenate(
73 | [
74 | np.zeros(self.reduce_tag_size),
75 | np.ones(self.shift_tag_size),
76 | ]
77 | )
78 | self.reduce_tags_only = np.asarray(-1e9 * is_shift_mask, dtype=float)
79 |
80 | stack_depth_change_by_id = [None] * len(tag_vocab)
81 | stack_depth_change_by_id_l2 = [None] * len(tag_vocab)
82 | for i, tag in enumerate(tag_vocab):
83 | if tag.startswith("s"):
84 | stack_depth_change_by_id_l2[i] = 0
85 | stack_depth_change_by_id[i] = -1
86 | elif tag.startswith("r"):
87 | stack_depth_change_by_id_l2[i] = -1
88 | stack_depth_change_by_id[i] = +2
89 | assert None not in stack_depth_change_by_id
90 | assert None not in stack_depth_change_by_id_l2
91 | self.stack_depth_change_by_id = np.array(
92 | stack_depth_change_by_id, dtype=int)
93 | self.stack_depth_change_by_id_l2 = np.array(
94 | stack_depth_change_by_id_l2, dtype=int)
95 | self._initialize_binarize_mask(tag_vocab)
96 |
97 | def _initialize_binarize_mask(self, tag_vocab) -> None:
98 | self.right_only_mask = np.full((len(tag_vocab),), -np.inf)
99 | self.left_only_mask = np.full((len(tag_vocab),), -np.inf)
100 |
101 | self.right_only_mask[self.rl_tag_size:self.reduce_tag_size] = 0.0
102 | self.right_only_mask[-self.sr_tag_size:] = 0.0
103 |
104 | self.left_only_mask[:self.rl_tag_size] = 0.0
105 | self.left_only_mask[
106 | self.reduce_tag_size:self.reduce_tag_size + self.sl_tag_size] = 0.0
107 |
108 | def mask_scores_for_binarization(self, labels, scores) -> []:
109 | # if shift -> rr and sr, if reduce -> rl and sl
110 | mask1 = np.where(labels[:, None] >= self.reduce_tag_size, self.right_only_mask, 0.0)
111 | mask2 = np.where(labels[:, None] < self.reduce_tag_size, self.left_only_mask, 0.0)
112 | all_new_scores = scores + mask1 + mask2
113 | return all_new_scores
114 |
115 |
116 | class SRTagger(Tagger, ABC):
117 | def __init__(self, trees=None, tag_vocab=None, add_remove_top=False):
118 | super().__init__(trees, tag_vocab, add_remove_top)
119 |
120 | def add_trees_to_vocab(self, trees: []) -> None:
121 | self.label_vocab = set()
122 | for tree in tq(trees):
123 | for tag in self.tree_to_tags_pipeline(tree)[0]:
124 | self.tag_vocab.add(tag)
125 | idx = tag.find("/")
126 | if idx != -1:
127 | self.label_vocab.add(tag[idx + 1:])
128 | else:
129 | self.label_vocab.add("")
130 | self.tag_vocab = sorted(self.tag_vocab)
131 | self.label_vocab = sorted(self.label_vocab)
132 |
133 | @staticmethod
134 | def create_shift_tag(label: str, is_right_child=False) -> str:
135 | suffix = "r" if is_right_child else ""
136 | if label.find("+") != -1:
137 | tag = "s" + suffix + "/" + "/".join(label.split("+")[:-1])
138 | else:
139 | tag = "s" + suffix
140 | return tag
141 |
142 | @staticmethod
143 | def create_shift_label(tag: str) -> str:
144 | idx = tag.find("/")
145 | if idx != -1:
146 | return tag[idx + 1:].replace("/", "+") + "+"
147 | else:
148 | return ""
149 |
150 | @staticmethod
151 | def create_reduce_tag(label: str, is_right_child=False) -> str:
152 | if label.find("|") != -1: # drop extra node labels created after binarization
153 | tag = "r"
154 | else:
155 | tag = "r" + "/" + label.replace("+", "/")
156 | return "r" + tag if is_right_child else tag
157 |
158 | @staticmethod
159 | def _create_reduce_label(tag: str) -> str:
160 | idx = tag.find("/")
161 | if idx == -1:
162 |             label = "|" # to mark the second part as an extra node created via binarization
163 | else:
164 | label = tag[idx + 1:].replace("/", "+")
165 | return label
166 |
167 |
168 | class SRTaggerBottomUp(SRTagger):
169 | def __init__(self, trees=None, tag_vocab=None, add_remove_top=False):
170 | super().__init__(trees, tag_vocab, add_remove_top)
171 | self.decode_moderator = BUSRTagDecodeModerator(self.tag_vocab)
172 |
173 | def tree_to_tags(self, root: PTree) -> ([str], int):
174 | tags = []
175 | lc = LeftCornerTransformer.extract_left_corner_no_eps(root)
176 | if len(root) == 1: # edge case
177 | tags.append(self.create_shift_tag(lc.label(), False))
178 | return tags, 1
179 |
180 | is_right_child = lc.left_sibling() is not None
181 | tags.append(self.create_shift_tag(lc.label(), is_right_child))
182 |
183 | logging.debug("SHIFT {}".format(lc.label()))
184 | stack = [lc]
185 | max_stack_len = 1
186 |
187 | while len(stack) > 0:
188 | node = stack[-1]
189 | max_stack_len = max(max_stack_len, len(stack))
190 |
191 | if node.left_sibling() is None and node.right_sibling() is not None:
192 | lc = LeftCornerTransformer.extract_left_corner_no_eps(node.right_sibling())
193 | stack.append(lc)
194 | logging.debug("SHIFT {}".format(lc.label()))
195 | is_right_child = lc.left_sibling() is not None
196 | tags.append(self.create_shift_tag(lc.label(), is_right_child))
197 |
198 | elif len(stack) >= 2 and (
199 | node.right_sibling() == stack[-2] or node.left_sibling() == stack[-2]):
200 | prev_node = stack[-2]
201 | logging.debug("REDUCE[ {0} {1} --> {2} ]".format(
202 | *(prev_node.label(), node.label(), node.parent().label())))
203 |
204 | parent_is_right = node.parent().left_sibling() is not None
205 | tags.append(self.create_reduce_tag(node.parent().label(), parent_is_right))
206 | stack.pop()
207 | stack.pop()
208 | stack.append(node.parent())
209 |
210 | elif stack[0].parent() is None and len(stack) == 1:
211 | stack.pop()
212 | continue
213 | return tags, max_stack_len
214 |
215 | def tags_to_tree(self, tags: [str], input_seq: [str]) -> PTree:
216 | created_node_stack = []
217 | node = None
218 |
219 | if len(tags) == 1: # base case
220 | assert tags[0].startswith('s')
221 | prefix = self.create_shift_label(tags[0])
222 | return PTree(prefix + input_seq[0][1], [input_seq[0][0]])
223 | for tag in tags:
224 | if tag.startswith('s'):
225 | prefix = self.create_shift_label(tag)
226 | created_node_stack.append(PTree(prefix + input_seq[0][1], [input_seq[0][0]]))
227 | input_seq.pop(0)
228 | else:
229 | last_node = created_node_stack.pop()
230 | last_2_node = created_node_stack.pop()
231 | node = PTree(self._create_reduce_label(tag), [last_2_node, last_node])
232 | created_node_stack.append(node)
233 |
234 | if len(input_seq) != 0:
235 |             raise ValueError("Not all of the input sequence was used")
236 | return node
237 |
238 | def logits_to_ids(self, logits: [], mask, max_depth, keep_per_depth, crf_transitions=None,
239 | is_greedy=False) -> [int]:
240 | if is_greedy:
241 | searcher = GreedySearch(
242 | self.decode_moderator,
243 | initial_stack_depth=0,
244 | crf_transitions=crf_transitions,
245 | max_depth=max_depth,
246 | keep_per_depth=keep_per_depth,
247 | )
248 | else:
249 |
250 | searcher = BeamSearch(
251 | self.decode_moderator,
252 | initial_stack_depth=0,
253 | crf_transitions=crf_transitions,
254 | max_depth=max_depth,
255 | keep_per_depth=keep_per_depth,
256 | )
257 |
258 | last_t = None
259 | seq_len = sum(mask)
260 | idx = 1
261 | for t in range(logits.shape[0]):
262 | if mask is not None and not mask[t]:
263 | continue
264 | if last_t is not None:
265 | searcher.advance(
266 | logits[last_t, :-len(self.tag_vocab)]
267 | )
268 | if idx == seq_len:
269 | searcher.advance(logits[t, -len(self.tag_vocab):], is_last=True)
270 | else:
271 | searcher.advance(logits[t, -len(self.tag_vocab):])
272 |             idx += 1; last_t = t  # count unmasked tokens so is_last fires on the final one
273 |
274 | score, best_tag_ids = searcher.get_path()
275 | return best_tag_ids
276 |
277 |
278 | class SRTaggerTopDown(SRTagger):
279 | def __init__(self, trees=None, tag_vocab=None, add_remove_top=False):
280 | super().__init__(trees, tag_vocab, add_remove_top)
281 | self.decode_moderator = TDSRTagDecodeModerator(self.tag_vocab)
282 |
283 | def tree_to_tags(self, root: PTree) -> ([str], int):
284 | stack: [PTree] = [root]
285 | max_stack_len = 1
286 | tags = []
287 |
288 | while len(stack) > 0:
289 | node = stack[-1]
290 | max_stack_len = max(max_stack_len, len(stack))
291 |
292 | if find_node_type(node) == NodeType.NT:
293 | stack.pop()
294 | logging.debug("REDUCE[ {0} --> {1} {2}]".format(
295 | *(node.label(), node[0].label(), node[1].label())))
296 | is_right_node = node.left_sibling() is not None
297 | tags.append(self.create_reduce_tag(node.label(), is_right_node))
298 | stack.append(node[1])
299 | stack.append(node[0])
300 |
301 | else:
302 | logging.debug("-->\tSHIFT[ {0} ]".format(node.label()))
303 | is_right_node = node.left_sibling() is not None
304 | tags.append(self.create_shift_tag(node.label(), is_right_node))
305 | stack.pop()
306 |
307 | return tags, max_stack_len
308 |
309 | def tags_to_tree(self, tags: [str], input_seq: [str]) -> PTree:
310 | if len(tags) == 1: # base case
311 | assert tags[0].startswith('s')
312 | prefix = self.create_shift_label(tags[0])
313 | return PTree(prefix + input_seq[0][1], [input_seq[0][0]])
314 |
315 | assert tags[0].startswith('r')
316 | node = PTree(self._create_reduce_label(tags[0]), [])
317 | created_node_stack: [PTree] = [node]
318 |
319 | for tag in tags[1:]:
320 | parent: PTree = created_node_stack[-1]
321 | if tag.startswith('s'):
322 | prefix = self.create_shift_label(tag)
323 | new_node = PTree(prefix + input_seq[0][1], [input_seq[0][0]])
324 | input_seq.pop(0)
325 | else:
326 | label = self._create_reduce_label(tag)
327 | new_node = PTree(label, [])
328 |
329 | if len(parent) == 0:
330 | parent.insert(0, new_node)
331 | elif len(parent) == 1:
332 | parent.insert(1, new_node)
333 | created_node_stack.pop()
334 |
335 | if tag.startswith('r'):
336 | created_node_stack.append(new_node)
337 |
338 | if len(input_seq) != 0:
339 |             raise ValueError("Not all of the input sequence was used")
340 | return node
341 |
342 | def logits_to_ids(self, logits: [], mask, max_depth, keep_per_depth, crf_transitions=None,
343 | is_greedy=False) -> [int]:
344 | if is_greedy:
345 | searcher = GreedySearch(
346 | self.decode_moderator,
347 | initial_stack_depth=1,
348 | crf_transitions=crf_transitions,
349 | max_depth=max_depth,
350 | min_depth=0,
351 | keep_per_depth=keep_per_depth,
352 | )
353 | else:
354 | searcher = BeamSearch(
355 | self.decode_moderator,
356 | initial_stack_depth=1,
357 | crf_transitions=crf_transitions,
358 | max_depth=max_depth,
359 | min_depth=0,
360 | keep_per_depth=keep_per_depth,
361 | )
362 |
363 | last_t = None
364 | seq_len = sum(mask)
365 | idx = 1
366 | is_last = False
367 | for t in range(logits.shape[0]):
368 | if mask is not None and not mask[t]:
369 | continue
370 | if last_t is not None:
371 | searcher.advance(
372 | logits[last_t, :-len(self.tag_vocab)]
373 | )
374 | if idx == seq_len:
375 | is_last = True
376 | if last_t is None:
377 | searcher.advance(
378 | logits[t, -len(self.tag_vocab):] + self.decode_moderator.reduce_tags_only,
379 | is_last=is_last)
380 | else:
381 | searcher.advance(logits[t, -len(self.tag_vocab):], is_last=is_last)
382 |             idx += 1; last_t = t  # count unmasked tokens so is_last fires on the final one
383 |
384 | score, best_tag_ids = searcher.get_path(required_stack_depth=0)
385 | return best_tag_ids
386 |
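# --- Illustrative usage sketch (not part of the original file) ---
# A minimal round trip through the bottom-up shift--reduce tagger, mirroring the
# examples in tagging/test.py; assumes nltk is installed and uses a made-up sentence.
if __name__ == "__main__":
    from nltk import Tree

    example = Tree.fromstring(
        "(TOP (S (NP (PRP She)) (VP (VBZ enjoys) (S (VP (VBG playing) (NP (NN tennis))))) (. .)))")
    tagger = SRTaggerBottomUp(add_remove_top=True)
    tags, max_stack_len = tagger.tree_to_tags_pipeline(example)
    print(tags, max_stack_len)
    # the rebuilt tree is expected to equal the original (cf. TestSRTagger)
    print(tagger.tags_to_tree_pipeline(tags, example.pos()) == example)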
--------------------------------------------------------------------------------
/tagging/tagger.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | from nltk import ParentedTree as PTree
4 | from nltk import Tree
5 | from tqdm import tqdm as tq
6 |
7 | from tagging.tree_tools import add_plus_to_tree, remove_plus_from_tree
8 |
9 |
10 | class TagDecodeModerator(ABC):
11 | def __init__(self, tag_vocab):
12 | self.vocab_len = len(tag_vocab)
13 | self.stack_depth_change_by_id = None
14 | self.stack_depth_change_by_id_l2 = None
15 |
16 |
17 | class Tagger(ABC):
18 | def __init__(self, trees=None, tag_vocab=None, add_remove_top=False):
19 | self.tag_vocab = set()
20 | self.add_remove_top = add_remove_top
21 |
22 | if tag_vocab is not None:
23 | self.tag_vocab = tag_vocab
24 | elif trees is not None:
25 | self.add_trees_to_vocab(trees)
26 |
27 | self.decode_moderator = None
28 |
29 | def add_trees_to_vocab(self, trees: []) -> None:
30 | for tree in tq(trees):
31 | tags = self.tree_to_tags_pipeline(tree)[0]
32 | for tag in tags:
33 | self.tag_vocab.add(tag)
34 | self.tag_vocab = sorted(self.tag_vocab)
35 |
36 | def tree_to_tags(self, root: PTree) -> [str]:
37 | raise NotImplementedError("tree to tags is not implemented")
38 |
39 | def tags_to_tree(self, tags: [str], input_seq: [str]) -> PTree:
40 | raise NotImplementedError("tags to tree is not implemented")
41 |
42 | def tree_to_tags_pipeline(self, tree: Tree) -> ([str], int):
43 | ptree = self.preprocess(tree)
44 | return self.tree_to_tags(ptree)
45 |
46 | def tree_to_ids_pipeline(self, tree: Tree) -> [int]:
47 | tags = self.tree_to_tags_pipeline(tree)[0]
48 | res = []
49 | for tag in tags:
50 | if tag in self.tag_vocab:
51 | res.append(self.tag_vocab.index(tag))
52 | elif tag[0:tag.find("/")] in self.tag_vocab:
53 | res.append(self.tag_vocab.index(tag[0:tag.find("/")]))
54 | else:
55 | res.append(0)
56 |
57 | return res
58 |
59 | def tags_to_tree_pipeline(self, tags: [str], input_seq: []) -> Tree:
60 | filtered_input_seq = []
61 | for token, pos in input_seq:
62 | filtered_input_seq.append((token, pos.replace("+", "@")))
63 | ptree = self.tags_to_tree(tags, filtered_input_seq)
64 | return self.postprocess(ptree)
65 |
66 | def ids_to_tree_pipeline(self, ids: [int], input_seq: []) -> Tree:
67 | tags = [self.tag_vocab[idx] for idx in ids]
68 | return self.tags_to_tree_pipeline(tags, input_seq)
69 |
70 | def logits_to_ids(self, logits: [], mask, max_depth, keep_per_depth, is_greedy=False) -> [int]:
71 | raise NotImplementedError("logits to ids is not implemented")
72 |
73 | def logits_to_tree(self, logits: [], leave_nodes: [], mask=None, max_depth=5, keep_per_depth=1, is_greedy=False) -> Tree:
74 | ids = self.logits_to_ids(logits, mask, max_depth, keep_per_depth, is_greedy=is_greedy)
75 | return self.ids_to_tree_pipeline(ids, leave_nodes)
76 |
77 | def preprocess(self, original_tree: Tree) -> PTree:
78 | tree = original_tree.copy(deep=True)
79 | if self.add_remove_top:
80 | cut_off_tree = tree[0]
81 | else:
82 | cut_off_tree = tree
83 |
84 | remove_plus_from_tree(cut_off_tree)
85 | cut_off_tree.collapse_unary(collapsePOS=True, collapseRoot=True)
86 | cut_off_tree.chomsky_normal_form()
87 | ptree = PTree.convert(cut_off_tree)
88 | return ptree
89 |
90 | def postprocess(self, tree: PTree) -> Tree:
91 | tree = Tree.convert(tree)
92 | tree.un_chomsky_normal_form()
93 | add_plus_to_tree(tree)
94 | if self.add_remove_top:
95 | return Tree("TOP", [tree])
96 | else:
97 | return tree
98 |
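# --- Illustrative note (not part of the original file) ---
# A rough sketch of how the pieces above compose for a concrete Tagger subclass
# (e.g. the SR or tetra taggers elsewhere in this package):
#
#   encoding:  Tree --preprocess--> PTree (unaries collapsed, CNF)
#                   --tree_to_tags--> [tag strings] --tag_vocab.index--> [ids]
#   decoding:  logits --logits_to_ids--> [ids] --tag_vocab[id]--> [tag strings]
#                     --tags_to_tree--> PTree --postprocess--> Tree (CNF undone)
#
# "+" in the original labels is temporarily rewritten to "@" so that nltk's
# un_chomsky_normal_form can use "+" to re-expand the collapsed unary chains.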
--------------------------------------------------------------------------------
/tagging/test.py:
--------------------------------------------------------------------------------
1 | # import logging
2 | import unittest
3 |
4 | import numpy as np
5 | from nltk import ParentedTree
6 | from nltk import Tree
7 |
8 | from learning.evaluate import evalb
9 | from tetratagger import BottomUpTetratagger, TopDownTetratagger
10 | from tagging.srtagger import SRTaggerBottomUp, SRTaggerTopDown
11 | from transform import LeftCornerTransformer, RightCornerTransformer
12 | from tree_tools import random_tree, is_topo_equal, create_dummy_tree
13 |
14 | from original_tetratagger import TetraTagSequence, TetraTagSystem
15 | from nltk.corpus.reader.bracket_parse import BracketParseCorpusReader
16 | from tqdm import tqdm as tq
17 |
18 | # logging.getLogger().setLevel(logging.DEBUG)
19 |
20 | np.random.seed(0)
21 |
22 |
23 | class TestTransforms(unittest.TestCase):
24 | def test_transform(self):
25 | tree = ParentedTree.fromstring("(S (NP (det the) (N dog)) (VP (V ran) (Adv fast)))")
26 | tree.pretty_print()
27 | new_tree_lc = ParentedTree("S", [])
28 | LeftCornerTransformer.transform(new_tree_lc, tree, tree)
29 | new_tree_lc.pretty_print()
30 |
31 | new_tree_rc = ParentedTree("S", [])
32 | RightCornerTransformer.transform(new_tree_rc, tree, tree)
33 | new_tree_rc.pretty_print()
34 |
35 | def test_rev_rc_transform(self, trials=100):
36 | for _ in range(trials):
37 | t = ParentedTree("ROOT", [])
38 | random_tree(t, depth=0, cutoff=5)
39 | new_tree_rc = ParentedTree("S", [])
40 | RightCornerTransformer.transform(new_tree_rc, t, t)
41 | tree_back = ParentedTree("X", ["", ""])
42 | tree_back = RightCornerTransformer.rev_transform(tree_back, new_tree_rc)
43 | self.assertEqual(tree_back, t)
44 |
45 | def test_rev_lc_transform(self, trials=100):
46 | for _ in range(trials):
47 | t = ParentedTree("ROOT", [])
48 | random_tree(t, depth=0, cutoff=5)
49 | new_tree_lc = ParentedTree("S", [])
50 | LeftCornerTransformer.transform(new_tree_lc, t, t)
51 | tree_back = ParentedTree("X", ["", ""])
52 | tree_back = LeftCornerTransformer.rev_transform(tree_back, new_tree_lc)
53 | self.assertEqual(tree_back, t)
54 |
55 |
56 | class TestTagging(unittest.TestCase):
57 |     def test_bottom_up(self):
58 | tree = ParentedTree.fromstring("(S (NP (det the) (N dog)) (VP (V ran) (Adv fast)))")
59 | tree.pretty_print()
60 | tree_rc = ParentedTree("S", [])
61 | RightCornerTransformer.transform(tree_rc, tree, tree)
62 | tree_rc.pretty_print()
63 | tagger = BottomUpTetratagger()
64 | tags, _ = tagger.tree_to_tags(tree_rc)
65 |
66 | for tag in tagger.tetra_visualize(tags):
67 | print(tag)
68 | print("--" * 20)
69 |
70 |     def test_bottom_up_alternate(self, trials=100):
71 | for _ in range(trials):
72 | t = ParentedTree("ROOT", [])
73 | random_tree(t, depth=0, cutoff=5)
74 | t_rc = ParentedTree("S", [])
75 | RightCornerTransformer.transform(t_rc, t, t)
76 | tagger = BottomUpTetratagger()
77 | tags, _ = tagger.tree_to_tags(t_rc)
78 | self.assertTrue(tagger.is_alternating(tags))
79 | self.assertTrue((2 * len(t.leaves()) - 1) == len(tags))
80 |
81 |     def test_round_trip_bottom_up(self, trials=100):
82 | for _ in range(trials):
83 | tree = ParentedTree("ROOT", [])
84 | random_tree(tree, depth=0, cutoff=5)
85 | tree_rc = ParentedTree("S", [])
86 | RightCornerTransformer.transform(tree_rc, tree, tree)
87 | tree.pretty_print()
88 | tree_rc.pretty_print()
89 | tagger = BottomUpTetratagger()
90 | tags, _ = tagger.tree_to_tags(tree_rc)
91 | root_from_tags = tagger.tags_to_tree(tags, tree.leaves())
92 | tree_back = ParentedTree("X", ["", ""])
93 | tree_back = RightCornerTransformer.rev_transform(tree_back, root_from_tags,
94 | pick_up_labels=False)
95 | self.assertTrue(is_topo_equal(tree, tree_back))
96 |
97 | def test_top_down(self):
98 | tree = ParentedTree.fromstring("(S (NP (det the) (N dog)) (VP (V ran) (Adv fast)))")
99 | tree.pretty_print()
100 | tree_lc = ParentedTree("S", [])
101 | LeftCornerTransformer.transform(tree_lc, tree, tree)
102 | tree_lc.pretty_print()
103 | tagger = TopDownTetratagger()
104 | tags, _ = tagger.tree_to_tags(tree_lc)
105 |
106 | for tag in tagger.tetra_visualize(tags):
107 | print(tag)
108 | print("--" * 20)
109 |
110 | def test_top_down_alternate(self, trials=100):
111 | for _ in range(trials):
112 | t = ParentedTree("ROOT", [])
113 | random_tree(t, depth=0, cutoff=5)
114 | t_lc = ParentedTree("S", [])
115 | LeftCornerTransformer.transform(t_lc, t, t)
116 | tagger = TopDownTetratagger()
117 |             tags, _ = tagger.tree_to_tags(t_lc)
118 | self.assertTrue(tagger.is_alternating(tags))
119 | self.assertTrue((2 * len(t.leaves()) - 1) == len(tags))
120 |
121 | def round_trip_test_top_down(self, trials=100):
122 | for _ in range(trials):
123 | tree = ParentedTree("ROOT", [])
124 | random_tree(tree, depth=0, cutoff=5)
125 | tree_lc = ParentedTree("S", [])
126 | LeftCornerTransformer.transform(tree_lc, tree, tree)
127 | tagger = TopDownTetratagger()
128 |             tags, _ = tagger.tree_to_tags(tree_lc)
129 | root_from_tags = tagger.tags_to_tree(tags, tree.leaves())
130 | tree_back = ParentedTree("X", ["", ""])
131 | tree_back = LeftCornerTransformer.rev_transform(tree_back, root_from_tags,
132 | pick_up_labels=False)
133 | self.assertTrue(is_topo_equal(tree, tree_back))
134 |
135 |
136 | class TestPipeline(unittest.TestCase):
137 | def test_example_colab(self):
138 | example_tree = Tree.fromstring(
139 | "(S (NP (PRP She)) (VP (VBZ enjoys) (S (VP (VBG playing) (NP (NN tennis))))) (. .))")
140 | tagger = BottomUpTetratagger()
141 | tags = tagger.tree_to_tags_pipeline(example_tree)[0]
142 | print(tags)
143 | for tag in tagger.tetra_visualize(tags):
144 | print(tag)
145 |
146 | def test_dummy_tree(self):
147 | example_tree = Tree.fromstring(
148 | "(S (NP (PRP She)) (VP (VBZ enjoys) (S (VP (VBG playing) (NP (NN tennis))))) (. .))")
149 | print(example_tree.pos())
150 | dummy = create_dummy_tree(example_tree.pos())
151 | dummy.pretty_print()
152 | example_tree.pretty_print()
153 | print(evalb("../EVALB/", [example_tree], [dummy]))
154 |
155 | def test_tree_linearizations(self):
156 | READER = BracketParseCorpusReader('../data/spmrl/',
157 | ['English.train', 'English.dev', 'English.test'])
158 | trees = READER.parsed_sents('English.test')
159 | for tree in trees:
160 | print(tree)
161 | print(" ".join(tree.leaves()))
162 |
163 | def test_compare_to_original_tetratagger(self):
164 | # import pickle
165 | # with open("../data/tetra.pkl", 'rb') as f:
166 | # tag_vocab = pickle.load(f)
167 |
168 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test'])
169 | trees = READER.parsed_sents('English.test')
170 | tagger = BottomUpTetratagger(add_remove_top=True)
171 | tetratagger = TetraTagSystem(trees=trees)
172 | for tree in tq(trees):
173 | original_tree = tree.copy(deep=True)
174 | # original_tags = TetraTagSequence.from_tree(original_tree)
175 | tags = tagger.tree_to_tags_pipeline(tree)[0]
176 | tetratags = tetratagger.tags_from_tree(tree)
177 | # ids = tagger.tree_to_ids_pipeline(tree)
178 | # ids = tagger.ids_from_tree(tree)
179 | # tree_back = tagger.tags_to_tree_pipeline(tags, tree.pos())
180 | # self.assertEqual(original_tree, tree_back)
181 | self.assertEqual(list(tetratags), list(tags))
182 |
183 | def test_example_colab_lc(self):
184 | example_tree = Tree.fromstring(
185 | "(S (NP (PRP She)) (VP (VBZ enjoys) (S (VP (VBG playing) (NP (NN tennis))))) (. .))")
186 | original_tree = example_tree.copy(deep=True)
187 | tagger = TopDownTetratagger()
188 | tags = tagger.tree_to_tags_pipeline(example_tree)[0]
189 | tree_back = tagger.tags_to_tree_pipeline(tags, example_tree.pos())
190 | tree_back.pretty_print()
191 | self.assertEqual(original_tree, tree_back)
192 |
193 | def test_top_down_tetratagger(self):
194 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test'])
195 | trees = READER.parsed_sents('English.test')
196 | tagger = TopDownTetratagger(add_remove_top=True)
197 | for tree in tq(trees):
198 | original_tree = tree.copy(deep=True)
199 | tags = tagger.tree_to_tags_pipeline(tree)[0]
200 | tree_back = tagger.tags_to_tree_pipeline(tags, tree.pos())
201 | self.assertEqual(original_tree, tree_back)
202 |
203 | def test_tag_ids_top_down(self):
204 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test'])
205 | trees = READER.parsed_sents('English.test')
206 | tagger = TopDownTetratagger(trees, add_remove_top=True)
207 | for tree in tq(trees):
208 | original_tree = tree.copy(deep=True)
209 | ids = tagger.tree_to_ids_pipeline(tree)
210 | tree_back = tagger.ids_to_tree_pipeline(ids, tree.pos())
211 | self.assertEqual(original_tree, tree_back)
212 |
213 | def test_tag_ids_bottom_up(self):
214 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test'])
215 | trees = READER.parsed_sents('English.test')
216 | tagger = BottomUpTetratagger(trees, add_remove_top=True)
217 | for tree in tq(trees):
218 | original_tree = tree.copy(deep=True)
219 | ids = tagger.tree_to_ids_pipeline(tree)
220 | tree_back = tagger.ids_to_tree_pipeline(ids, tree.pos())
221 | self.assertEqual(original_tree, tree_back)
222 |
223 | def test_decoder_edges(self):
224 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test'])
225 | trees = READER.parsed_sents('English.test')
226 | tagger_bu = SRTaggerBottomUp(add_remove_top=True)
227 | tagger_td = SRTaggerTopDown(add_remove_top=True)
228 | unique_tags = dict()
229 | counter = 0
230 | for tree in tq(trees):
231 | counter += 1
232 | tags = tagger_td.tree_to_tags_pipeline(tree)[0]
233 | for i, tag in enumerate(tags):
234 | tag_s = tag.split("/")[0]
235 | if i < len(tags) - 1:
236 | tag_next = tags[i+1].split("/")[0]
237 | else:
238 | tag_next = None
239 | if tag_s not in unique_tags and tag_next is not None:
240 | unique_tags[tag_s] = {tag_next}
241 | elif tag_next is not None:
242 | unique_tags[tag_s].add(tag_next)
243 | print(unique_tags)
244 |
245 |
246 | class TestSRTagger(unittest.TestCase):
247 | def test_tag_sequence_example(self):
248 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test'])
249 | trees = READER.parsed_sents('English.train')
250 | tagger = SRTaggerBottomUp(add_remove_top=True)
251 | for tree in tq(trees):
252 | original_tree = tree.copy(deep=True)
253 | tags = tagger.tree_to_tags_pipeline(tree)[0]
254 | for tag in tags:
255 | if tag.find('/X/') != -1:
256 | print(tag)
257 | original_tree.pretty_print()
258 | tree_back = tagger.tags_to_tree_pipeline(tags, tree.pos())
259 | self.assertEqual(original_tree, tree_back)
260 |
261 | def test_example_both_version(self):
262 | import nltk
263 | example_tree = nltk.Tree.fromstring(
264 | "(TOP (S (NP (PRP She)) (VP (VBZ enjoys) (S (VP (VBG playing) (NP (NN tennis))))) (. .)))")
265 | example_tree.pretty_print()
266 | td_tagger = SRTaggerTopDown(add_remove_top=True)
267 | bu_tagger = SRTaggerBottomUp(add_remove_top=True)
268 | t1 = example_tree.copy(deep=True)
269 | t2 = example_tree.copy(deep=True)
270 | td_tags = td_tagger.tree_to_tags_pipeline(t1)[0]
271 | bu_tags = bu_tagger.tree_to_tags_pipeline(t2)[0]
272 | self.assertEqual(set(td_tags), set(bu_tags))
273 | print(list(td_tags))
274 | print(list(bu_tags))
275 |
276 | def test_bu_binarize(self):
277 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test'])
278 | trees = READER.parsed_sents('English.dev')
279 | tagger = SRTaggerBottomUp(add_remove_top=True)
280 | for tree in tq(trees):
281 | tags = tagger.tree_to_tags_pipeline(tree)[0]
282 | for idx, tag in enumerate(tags):
283 | if tag.startswith("rr"):
284 | if (idx + 1) < len(tags):
285 | self.assertTrue(tags[idx + 1].startswith('r'))
286 | elif tag.startswith("r"):
287 | if (idx + 1) < len(tags):
288 | self.assertTrue(tags[idx + 1].startswith('s'))
289 |
290 | def test_td_binarize(self):
291 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test'])
292 | trees = READER.parsed_sents('English.dev')
293 | tagger = SRTaggerTopDown(add_remove_top=True)
294 | for tree in tq(trees):
295 | tags = tagger.tree_to_tags_pipeline(tree)[0]
296 | for idx, tag in enumerate(tags):
297 | if tag.startswith("s"):
298 | if (idx + 1) < len(tags):
299 | self.assertTrue(tags[idx + 1].startswith('rr') or tags[idx + 1].startswith('s'))
300 | elif tag.startswith("r"):
301 | if (idx + 1) < len(tags):
302 | self.assertTrue(not tags[idx + 1].startswith('rr'))
303 |
304 | def test_td_bu(self):
305 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test'])
306 | trees = READER.parsed_sents('English.dev')
307 | td_tagger = SRTaggerTopDown(add_remove_top=True)
308 | bu_tagger = SRTaggerBottomUp(add_remove_top=True)
309 |
310 | for tree in tq(trees):
311 | td_tags = td_tagger.tree_to_tags_pipeline(tree)[0]
312 | bu_tags = bu_tagger.tree_to_tags_pipeline(tree)[0]
313 | self.assertEqual(set(td_tags), set(bu_tags))
314 |
315 | def test_tag_ids(self):
316 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test'])
317 | trees = READER.parsed_sents('English.train')
318 | tagger = SRTaggerBottomUp(trees, add_remove_top=False)
319 | for tree in tq(trees):
320 | ids = tagger.tree_to_ids_pipeline(tree)
321 | tree_back = tagger.ids_to_tree_pipeline(ids, tree.pos())
322 | self.assertEqual(tree, tree_back)
323 |
324 | def test_max_length(self):
325 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test'])
326 | trees = READER.parsed_sents('English.train')
327 | tagger = SRTaggerBottomUp(trees, add_remove_top=True)
328 | print(len(tagger.tag_vocab))
329 |
330 |
331 | class TestSPMRL(unittest.TestCase):
332 | def test_reading_trees(self):
333 | langs = ['Korean']
334 | for l in langs:
335 | READER = BracketParseCorpusReader('../data/spmrl/', [l+'.train', l+'.dev', l+'.test'])
336 | trees = READER.parsed_sents(l+'.test')
337 | trees[0].pretty_print()
338 |
339 | def test_tagging_bu_sr(self):
340 | langs = ['Basque', 'French', 'German', 'Hebrew', 'Hungarian', 'Korean', 'Polish',
341 | 'Swedish']
342 | for l in langs:
343 | READER = BracketParseCorpusReader('../data/spmrl/',
344 | [l + '.train', l + '.dev', l + '.test'])
345 | trees = READER.parsed_sents(l + '.test')
346 | trees[0].pretty_print()
347 | tagger = SRTaggerBottomUp(trees, add_remove_top=True)
348 | print(tagger.tree_to_tags_pipeline(trees[0]))
349 | print(tagger.tree_to_ids_pipeline(trees[0]))
350 |
351 | def test_tagging_td_sr(self):
352 | langs = ['Basque', 'French', 'German', 'Hebrew', 'Hungarian', 'Korean', 'Polish',
353 | 'Swedish']
354 | for l in langs:
355 | READER = BracketParseCorpusReader('../data/spmrl/',
356 | [l + '.train', l + '.dev', l + '.test'])
357 | trees = READER.parsed_sents(l + '.test')
358 | trees[0].pretty_print()
359 | tagger = SRTaggerBottomUp(trees, add_remove_top=True)
360 | print(tagger.tree_to_tags_pipeline(trees[0]))
361 | print(tagger.tree_to_ids_pipeline(trees[0]))
362 |
363 | def test_tagging_tetra(self):
364 | langs = ['Basque', 'French', 'German', 'Hebrew', 'Hungarian', 'Korean', 'Polish',
365 | 'Swedish']
366 | for l in langs:
367 | READER = BracketParseCorpusReader('../data/spmrl/',
368 | [l + '.train', l + '.dev', l + '.test'])
369 | trees = READER.parsed_sents(l + '.test')
370 | trees[0].pretty_print()
371 | tagger = BottomUpTetratagger(trees, add_remove_top=True)
372 | print(tagger.tree_to_tags_pipeline(trees[0]))
373 | print(tagger.tree_to_ids_pipeline(trees[0]))
374 |
375 | def test_taggers(self):
376 | tagger = SRTaggerBottomUp(add_remove_top=False)
377 | langs = ['Basque', 'French', 'German', 'Hebrew', 'Hungarian', 'Korean', 'Polish',
378 | 'Swedish']
379 | for l in tq(langs):
380 | READER = BracketParseCorpusReader('../data/spmrl/',
381 | [l + '.train', l + '.dev', l + '.test'])
382 | trees = READER.parsed_sents(l + '.test')
383 | for tree in tq(trees):
384 | tags, _ = tagger.tree_to_tags_pipeline(tree)
385 | tree_back = tagger.tags_to_tree_pipeline(tags, tree.pos())
386 | self.assertEqual(tree, tree_back)
387 |
388 | def test_korean(self):
389 | import pickle
390 | with open("../data/vocab/Korean-bu-sr.pkl", 'rb') as f:
391 | tag_vocab = pickle.load(f)
392 | tagger = SRTaggerBottomUp(tag_vocab=tag_vocab, add_remove_top=False)
393 | l = "Korean"
394 | READER = BracketParseCorpusReader('../data/spmrl/',
395 | [l + '.train', l + '.dev', l + '.test'])
396 | trees = READER.parsed_sents(l + '.test')
397 | for tree in tq(trees):
398 | tags = tagger.tree_to_tags_pipeline(tree)[0]
399 | tree_back = tagger.tags_to_tree_pipeline(tags, tree.pos())
400 | self.assertEqual(tree, tree_back)
401 |
402 |
403 | class TestStackSize(unittest.TestCase):
404 | def test_english_tetra(self):
405 | l = "English"
406 | READER = BracketParseCorpusReader('../data/spmrl', [l+'.train', l+'.dev', l+'.test'])
407 | trees = READER.parsed_sents([l+'.train', l+'.dev', l+'.test'])
408 | tagger = BottomUpTetratagger(add_remove_top=True)
409 | stack_size_list = []
410 | for tree in tq(trees):
411 | tags, max_depth = tagger.tree_to_tags_pipeline(tree)
412 | stack_size_list.append(max_depth)
413 |
414 | print(stack_size_list)
415 |
416 |
417 | if __name__ == '__main__':
418 | unittest.main()
419 |
--------------------------------------------------------------------------------
/tagging/tetratagger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from abc import ABC
3 |
4 | import numpy as np
5 | from nltk import ParentedTree as PTree
6 | from nltk import Tree
7 |
8 | from learning.decode import BeamSearch, GreedySearch
9 | from tagging.tagger import Tagger, TagDecodeModerator
10 | from tagging.transform import LeftCornerTransformer, RightCornerTransformer
11 | from tagging.tree_tools import find_node_type, is_node_epsilon, NodeType
12 |
13 |
14 | class TetraTagDecodeModerator(TagDecodeModerator):
15 | def __init__(self, tag_vocab):
16 | super().__init__(tag_vocab)
17 | self.internal_tag_vocab_size = len(
18 | [tag for tag in tag_vocab if tag[0] in "LR"]
19 | )
20 | self.leaf_tag_vocab_size = len(
21 | [tag for tag in tag_vocab if tag[0] in "lr"]
22 | )
23 |
24 | is_leaf_mask = np.concatenate(
25 | [
26 | np.zeros(self.internal_tag_vocab_size),
27 | np.ones(self.leaf_tag_vocab_size),
28 | ]
29 | )
30 | self.internal_tags_only = np.asarray(-1e9 * is_leaf_mask, dtype=float)
31 | self.leaf_tags_only = np.asarray(
32 | -1e9 * (1 - is_leaf_mask), dtype=float
33 | )
34 |
35 | stack_depth_change_by_id = [None] * len(tag_vocab)
36 | for i, tag in enumerate(tag_vocab):
37 | if tag.startswith("l"):
38 | stack_depth_change_by_id[i] = +1
39 | elif tag.startswith("R"):
40 | stack_depth_change_by_id[i] = -1
41 | else:
42 | stack_depth_change_by_id[i] = 0
43 | assert None not in stack_depth_change_by_id
44 | self.stack_depth_change_by_id = np.array(
45 | stack_depth_change_by_id, dtype=np.int32
46 | )
47 | self.mask_binarize = False
48 |
49 |
50 | class TetraTagger(Tagger, ABC):
51 | def __init__(self, trees=None, tag_vocab=None, add_remove_top=False):
52 | super().__init__(trees, tag_vocab, add_remove_top)
53 | self.decode_moderator = TetraTagDecodeModerator(self.tag_vocab)
54 |
55 | def expand_tags(self, tags: [str]) -> [str]:
56 | raise NotImplementedError("expand tags is not implemented")
57 |
58 | @staticmethod
59 | def tetra_visualize(tags: [str]):
60 | for tag in tags:
61 | if tag.startswith('r'):
62 | yield "-->"
63 | if tag.startswith('l'):
64 | yield "<--"
65 | if tag.startswith('R'):
66 | yield "==>"
67 | if tag.startswith('L'):
68 | yield "<=="
69 |
70 | @staticmethod
71 | def create_shift_tag(label: str, left_or_right: str) -> str:
72 | if label.find("+") != -1:
73 | return left_or_right + "/" + "/".join(label.split("+")[:-1])
74 | else:
75 | return left_or_right
76 |
77 | @staticmethod
78 | def _create_bi_reduce_tag(label: str, left_or_right: str) -> str:
79 | label = label.split("\\")[1]
80 | if label.find("|") != -1: # drop extra node labels created after binarization
81 | return left_or_right
82 | else:
83 | return left_or_right + "/" + label.replace("+", "/")
84 |
85 | @staticmethod
86 | def _create_unary_reduce_tag(label: str, left_or_right: str) -> str:
87 | label = label.split("\\")[0]
88 | if label.find("|") != -1: # drop extra node labels created after binarization
89 | return left_or_right
90 | else:
91 | return left_or_right + "/" + label.replace("+", "/")
92 |
93 | @staticmethod
94 | def create_merge_shift_tag(label: str, left_or_right: str) -> str:
95 | if label.find("/") != -1:
96 | return left_or_right + "/" + "/".join(label.split("/")[1:])
97 | else:
98 | return left_or_right
99 |
100 | @staticmethod
101 | def _create_pre_terminal_label(tag: str, default="X") -> str:
102 | idx = tag.find("/")
103 | if idx != -1:
104 | label = tag[idx + 1:].replace("/", "+")
105 | if default == "":
106 | return label + "+"
107 | else:
108 | return label
109 | else:
110 | return default
111 |
112 | @staticmethod
113 | def _create_unary_reduce_label(tag: str) -> str:
114 | idx = tag.find("/")
115 | if idx == -1:
116 | return "X|"
117 | return tag[idx + 1:].replace("/", "+")
118 |
119 | @staticmethod
120 | def _create_reduce_label(tag: str) -> str:
121 | idx = tag.find("/")
122 | if idx == -1:
123 |             label = "X\\|"  # to mark the second part as an extra node created via binarization
124 | else:
125 | label = "X\\" + tag[idx + 1:].replace("/", "+")
126 | return label
127 |
128 | @staticmethod
129 | def is_alternating(tags: [str]) -> bool:
130 | prev_state = True # true means reduce
131 | for tag in tags:
132 | if tag.startswith('r') or tag.startswith('l'):
133 | state = False
134 | else:
135 | state = True
136 | if state == prev_state:
137 | return False
138 | prev_state = state
139 | return True
140 |
141 | def logits_to_ids(self, logits: [], mask, max_depth, keep_per_depth, is_greedy=False) -> [int]:
142 | if is_greedy:
143 | searcher = GreedySearch(
144 | self.decode_moderator,
145 | initial_stack_depth=0,
146 | max_depth=max_depth,
147 | keep_per_depth=keep_per_depth,
148 | )
149 | else:
150 | searcher = BeamSearch(
151 | self.decode_moderator,
152 | initial_stack_depth=0,
153 | max_depth=max_depth,
154 | keep_per_depth=keep_per_depth,
155 | )
156 |
157 | last_t = None
158 | for t in range(logits.shape[0]):
159 | if mask is not None and not mask[t]:
160 | continue
161 | if last_t is not None:
162 | searcher.advance(
163 | logits[last_t, :] + self.decode_moderator.internal_tags_only
164 | )
165 | searcher.advance(logits[t, :] + self.decode_moderator.leaf_tags_only)
166 | last_t = t
167 |
168 | score, best_tag_ids = searcher.get_path()
169 | return best_tag_ids
170 |
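# --- Illustrative note (not part of the original file) ---
# A rough sketch of the tag strings the helpers above emit, assuming the label
# conventions used in this file ("+"-joined unary chains, "\\"-separated corner
# labels, "|" marking binarization nodes); the arrows are the ones produced by
# tetra_visualize:
#
#   create_shift_tag("NP+PRP", "l")             -> "l/NP"  ("<--", shift)
#   create_merge_shift_tag("l/NP", "r")         -> "r/NP"  ("-->", merged shift)
#   _create_bi_reduce_tag("S\\NP", "R")         -> "R/NP"  ("==>", binary reduce)
#   _create_unary_reduce_tag("S|<..>\\X", "L")  -> "L"     ("<==", unary reduce)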
171 |
172 | class BottomUpTetratagger(TetraTagger):
173 | """ Kitaev and Klein (2020)"""
174 |
175 | def expand_tags(self, tags: [str]) -> [str]:
176 | new_tags = []
177 | for tag in tags:
178 | if tag.startswith('r'):
179 | new_tags.append("l" + tag[1:])
180 | new_tags.append("R")
181 | else:
182 | new_tags.append(tag)
183 | return new_tags
184 |
185 | def preprocess(self, tree: Tree) -> PTree:
186 | ptree: PTree = super().preprocess(tree)
187 | root_label = ptree.label()
188 | tree_rc = PTree(root_label, [])
189 | RightCornerTransformer.transform(tree_rc, ptree, ptree)
190 | return tree_rc
191 |
192 | def tree_to_tags(self, root: PTree) -> ([], int):
193 | tags = []
194 | lc = LeftCornerTransformer.extract_left_corner_no_eps(root)
195 | tags.append(self.create_shift_tag(lc.label(), "l"))
196 |
197 | logging.debug("SHIFT {}".format(lc.label()))
198 | stack = [lc]
199 | max_stack_len = 1
200 |
201 | while len(stack) > 0:
202 | max_stack_len = max(max_stack_len, len(stack))
203 | node = stack[-1]
204 | if find_node_type(
205 | node) == NodeType.NT: # special case: merge the reduce and last shift
206 | last_tag = tags.pop()
207 | last_two_tag = tags.pop()
208 | if not last_tag.startswith('R') or not last_two_tag.startswith('l'):
209 | raise ValueError(
210 | "When reaching NT the right PT should already be shifted")
211 | # merged shift
212 | tags.append(self.create_merge_shift_tag(last_two_tag, "r"))
213 |
214 | if node.left_sibling() is None and node.right_sibling() is not None:
215 | lc = LeftCornerTransformer.extract_left_corner_no_eps(node.right_sibling())
216 | stack.append(lc)
217 | logging.debug("<-- \t SHIFT {}".format(lc.label()))
218 | # normal shift
219 | tags.append(self.create_shift_tag(lc.label(), "l"))
220 |
221 | elif len(stack) >= 2 and (
222 | node.right_sibling() == stack[-2] or node.left_sibling() == stack[-2]):
223 | prev_node = stack[-2]
224 | logging.debug("==> \t REDUCE[ {0} {1} --> {2} ]".format(
225 | *(prev_node.label(), node.label(), node.parent().label())))
226 |
227 | tags.append(
228 | self._create_bi_reduce_tag(prev_node.label(), "R")) # normal reduce
229 | stack.pop()
230 | stack.pop()
231 | stack.append(node.parent())
232 |
233 | elif find_node_type(node) != NodeType.NT_NT:
234 | if stack[0].parent() is None and len(stack) == 1:
235 | stack.pop()
236 | continue
237 | logging.debug(
238 | "<== \t REDUCE[ {0} --> {1} ]".format(
239 | *(node.label(), node.parent().label())))
240 | tags.append(self._create_unary_reduce_tag(
241 | node.parent().label(), "L")) # unary reduce
242 | stack.pop()
243 | stack.append(node.parent())
244 | else:
245 | logging.error("ERROR: Undefined stack state")
246 | return
247 | logging.debug("=" * 20)
248 | return tags, max_stack_len
249 |
250 | def _unary_reduce(self, node, last_node, tag):
251 | label = self._create_unary_reduce_label(tag)
252 | node.insert(0, PTree(label + "\\" + label, ["EPS"]))
253 | node.insert(1, last_node)
254 | return node
255 |
256 | def _reduce(self, node, last_node, last_2_node, tag):
257 | label = self._create_reduce_label(tag)
258 | last_2_node.set_label(label)
259 | node.insert(0, last_2_node)
260 | node.insert(1, last_node)
261 | return node
262 |
263 | def postprocess(self, transformed_tree: PTree) -> Tree:
264 | tree = PTree("X", ["", ""])
265 | tree = RightCornerTransformer.rev_transform(tree, transformed_tree)
266 | return super().postprocess(tree)
267 |
268 | def tags_to_tree(self, tags: [str], input_seq: [str]) -> PTree:
269 | created_node_stack = []
270 | node = None
271 | expanded_tags = self.expand_tags(tags)
272 | if len(expanded_tags) == 1: # base case
273 | assert expanded_tags[0].startswith('l')
274 | prefix = self._create_pre_terminal_label(expanded_tags[0], "")
275 | return PTree(prefix + input_seq[0][1], [input_seq[0][0]])
276 | for tag in expanded_tags:
277 | if tag.startswith('l'): # shift
278 | prefix = self._create_pre_terminal_label(tag, "")
279 | created_node_stack.append(PTree(prefix + input_seq[0][1], [input_seq[0][0]]))
280 | input_seq.pop(0)
281 | else:
282 | node = PTree("X", [])
283 | if tag.startswith('R'): # normal reduce
284 | last_node = created_node_stack.pop()
285 | last_2_node = created_node_stack.pop()
286 | created_node_stack.append(self._reduce(node, last_node, last_2_node, tag))
287 | elif tag.startswith('L'): # unary reduce
288 | created_node_stack.append(
289 | self._unary_reduce(node, created_node_stack.pop(), tag))
290 | if len(input_seq) != 0:
291 |             raise ValueError("Not all of the input sequence was used")
292 | return node
293 |
294 |
295 | class TopDownTetratagger(TetraTagger):
296 |
297 | @staticmethod
298 | def create_merge_shift_tag(label: str, left_or_right: str) -> str:
299 | if label.find("+") != -1:
300 | return left_or_right + "/" + "/".join(label.split("+")[:-1])
301 | else:
302 | return left_or_right
303 |
304 | def expand_tags(self, tags: [str]) -> [str]:
305 | new_tags = []
306 | for tag in tags:
307 | if tag.startswith('l'):
308 | new_tags.append("L")
309 | new_tags.append("r" + tag[1:])
310 | else:
311 | new_tags.append(tag)
312 | return new_tags
313 |
314 | def preprocess(self, tree: Tree) -> PTree:
315 | ptree = super(TopDownTetratagger, self).preprocess(tree)
316 | root_label = ptree.label()
317 | tree_lc = PTree(root_label, [])
318 | LeftCornerTransformer.transform(tree_lc, ptree, ptree)
319 | return tree_lc
320 |
321 | def tree_to_tags(self, root: PTree) -> ([str], int):
322 |         """ convert a left-corner-transformed tree to shift and reduce tags """
323 | stack: [PTree] = [root]
324 | max_stack_len = 1
325 | logging.debug("SHIFT {}".format(root.label()))
326 | tags = []
327 | while len(stack) > 0:
328 | max_stack_len = max(max_stack_len, len(stack))
329 | node = stack[-1]
330 | if find_node_type(node) == NodeType.NT or find_node_type(node) == NodeType.NT_NT:
331 | stack.pop()
332 | logging.debug("REDUCE[ {0} --> {1} {2}]".format(
333 | *(node.label(), node[0].label(), node[1].label())))
334 | if find_node_type(node) == NodeType.NT:
335 | if find_node_type(node[0]) != NodeType.PT:
336 | raise ValueError("Left child of NT should be a PT")
337 | stack.append(node[1])
338 | tags.append(
339 | self.create_merge_shift_tag(node[0].label(), "l")) # merged shift
340 | else:
341 | if not is_node_epsilon(node[1]):
342 | stack.append(node[1])
343 | tags.append(self._create_bi_reduce_tag(node[1].label(), "L"))
344 | # normal reduce
345 | else:
346 | tags.append(self._create_unary_reduce_tag(node[1].label(), "R"))
347 | # unary reduce
348 | stack.append(node[0])
349 |
350 | elif find_node_type(node) == NodeType.PT:
351 | tags.append(self.create_shift_tag(node.label(), "r")) # normal shift
352 | logging.debug("-->\tSHIFT[ {0} ]".format(node.label()))
353 | stack.pop()
354 | return tags, max_stack_len
355 |
356 | def postprocess(self, transformed_tree: PTree) -> Tree:
357 | ptree = PTree("X", ["", ""])
358 | ptree = LeftCornerTransformer.rev_transform(ptree, transformed_tree)
359 | return super().postprocess(ptree)
360 |
361 | def tags_to_tree(self, tags: [str], input_seq: [str]) -> PTree:
362 | expanded_tags = self.expand_tags(tags)
363 | root = PTree("X", [])
364 | created_node_stack = [root]
365 | if len(expanded_tags) == 1: # base case
366 | assert expanded_tags[0].startswith('r')
367 | prefix = self._create_pre_terminal_label(expanded_tags[0], "")
368 | return PTree(prefix + input_seq[0][1], [input_seq[0][0]])
369 | for tag in expanded_tags:
370 | if tag.startswith('r'): # shift
371 | node = created_node_stack.pop()
372 | prefix = self._create_pre_terminal_label(tag, "")
373 | node.set_label(prefix + input_seq[0][1])
374 | node.insert(0, input_seq[0][0])
375 | input_seq.pop(0)
376 | elif tag.startswith('R') or tag.startswith('L'):
377 | parent = created_node_stack.pop()
378 | if tag.startswith('L'): # normal reduce
379 | label = self._create_reduce_label(tag)
380 | r_node = PTree(label, [])
381 | created_node_stack.append(r_node)
382 | else:
383 | label = self._create_unary_reduce_label(tag)
384 | r_node = PTree(label + "\\" + label, ["EPS"])
385 |
386 | l_node_label = self._create_reduce_label(tag)
387 | l_node = PTree(l_node_label, [])
388 | created_node_stack.append(l_node)
389 | parent.insert(0, l_node)
390 | parent.insert(1, r_node)
391 | else:
392 | raise ValueError("Invalid tag type")
393 | if len(input_seq) != 0:
394 |             raise ValueError("Not all of the input sequence was used")
395 | return root
396 |
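# --- Illustrative usage sketch (not part of the original file) ---
# A minimal round trip through the bottom-up tetratagger, mirroring the examples
# in tagging/test.py; assumes nltk is installed and uses a made-up sentence.
if __name__ == "__main__":
    example = Tree.fromstring(
        "(TOP (S (NP (PRP She)) (VP (VBZ enjoys) (S (VP (VBG playing) (NP (NN tennis))))) (. .)))")
    tagger = BottomUpTetratagger(add_remove_top=True)
    tags, _ = tagger.tree_to_tags_pipeline(example)
    print(list(tagger.tetra_visualize(tags)))
    # the rebuilt tree is expected to equal the original (cf. TestPipeline)
    print(tagger.tags_to_tree_pipeline(tags, example.pos()) == example)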
--------------------------------------------------------------------------------
/tagging/transform.py:
--------------------------------------------------------------------------------
1 | from nltk import ParentedTree as PTree
2 | from tagging.tree_tools import find_node_type, is_node_epsilon, NodeType
3 |
4 |
5 | class Transformer:
6 | @classmethod
7 | def expand_nt(cls, node: PTree, ref_node: PTree) -> (PTree, PTree, PTree, PTree):
8 | raise NotImplementedError("expand non-terminal is not implemented")
9 |
10 | @classmethod
11 | def expand_nt_nt(cls, node: PTree, ref_node1: PTree, ref_node2: PTree) -> (
12 | PTree, PTree, PTree, PTree):
13 |         raise NotImplementedError("expand paired non-terminal is not implemented")
14 |
15 | @classmethod
16 | def extract_right_corner(cls, node: PTree) -> PTree:
17 | while type(node[0]) != str:
18 | if len(node) > 1:
19 | node = node[1]
20 | else: # unary rules
21 | node = node[0]
22 | return node
23 |
24 | @classmethod
25 | def extract_left_corner(cls, node: PTree) -> PTree:
26 | while len(node) > 1:
27 | node = node[0]
28 | return node
29 |
30 | @classmethod
31 | def transform(cls, node: PTree, ref_node1: PTree, ref_node2: PTree) -> None:
32 | if node is None:
33 | return
34 |         node_type = find_node_type(node)
35 |         if node_type == NodeType.NT:
36 | left_ref1, left_ref2, right_ref1, right_ref2, is_base_case = cls.expand_nt(node,
37 | ref_node1)
38 |         elif node_type == NodeType.NT_NT:
39 | is_base_case = False
40 | left_ref1, left_ref2, right_ref1, right_ref2 = cls.expand_nt_nt(
41 | node, ref_node1,
42 | ref_node2)
43 | else:
44 | return
45 | if is_base_case:
46 | return
47 | cls.transform(node[0], left_ref1, left_ref2)
48 | cls.transform(node[1], right_ref1, right_ref2)
49 |
50 |
51 | class LeftCornerTransformer(Transformer):
52 |
53 | @classmethod
54 | def extract_left_corner_no_eps(cls, node: PTree) -> PTree:
55 | while len(node) > 1:
56 | if not is_node_epsilon(node[0]):
57 | node = node[0]
58 | else:
59 | node = node[1]
60 | return node
61 |
62 | @classmethod
63 | def expand_nt(cls, node: PTree, ref_node: PTree) -> (PTree, PTree, PTree, PTree, bool):
64 | leftcorner_node = cls.extract_left_corner(ref_node)
65 |         if leftcorner_node == ref_node:  # this only happens if the tree consists of a single terminal rule
66 | node.insert(0, ref_node.leaves()[0])
67 | return None, None, None, None, True
68 | new_right_node = PTree(node.label() + "\\" + leftcorner_node.label(), [])
69 | new_left_node = PTree(leftcorner_node.label(), leftcorner_node.leaves())
70 |
71 | node.insert(0, new_left_node)
72 | node.insert(1, new_right_node)
73 | return leftcorner_node, leftcorner_node, ref_node, leftcorner_node, False
74 |
75 | @classmethod
76 | def expand_nt_nt(cls, node: PTree, ref_node1: PTree, ref_node2: PTree) -> (
77 | PTree, PTree, PTree, PTree):
78 | parent_node = ref_node2.parent()
79 | if ref_node1 == parent_node:
80 | new_right_node = PTree(node.label().split("\\")[0] + "\\" + parent_node.label(),
81 | ["EPS"])
82 | else:
83 | new_right_node = PTree(node.label().split("\\")[0] + "\\" + parent_node.label(),
84 | [])
85 |
86 | sibling_node = ref_node2.right_sibling()
87 | if len(sibling_node) == 1:
88 | new_left_node = PTree(sibling_node.label(), sibling_node.leaves())
89 | else:
90 | new_left_node = PTree(sibling_node.label(), [])
91 |
92 | node.insert(0, new_left_node)
93 | node.insert(1, new_right_node)
94 | return sibling_node, sibling_node, ref_node1, parent_node
95 |
96 | @classmethod
97 | def rev_transform(cls, node: PTree, ref_node: PTree, pick_up_labels=True) -> PTree:
98 | if find_node_type(ref_node) == NodeType.NT_NT and pick_up_labels:
99 | node.set_label(ref_node[1].label().split("\\")[1])
100 | if len(ref_node) == 1 and find_node_type(ref_node) == NodeType.PT: # base case
101 | return ref_node
102 | if find_node_type(ref_node[0]) == NodeType.PT and not is_node_epsilon(ref_node[1]):
103 | # X -> word X
104 | pt_node = PTree(ref_node[0].label(), ref_node[0].leaves())
105 | if node[0] == "":
106 | node[0] = pt_node
107 | else:
108 | node[1] = pt_node
109 | par_node = PTree("X", [node, ""])
110 | node = par_node
111 | return cls.rev_transform(node, ref_node[1], pick_up_labels)
112 | elif find_node_type(ref_node[0]) != NodeType.PT and is_node_epsilon(ref_node[1]):
113 | # X -> X X-X
114 | if node[0] == "":
115 | raise ValueError(
116 | "When reaching the root the left branch should already exist")
117 | node[1] = cls.rev_transform(PTree("X", ["", ""]), ref_node[0], pick_up_labels)
118 | return node
119 | elif find_node_type(ref_node[0]) == NodeType.PT and is_node_epsilon(ref_node[1]):
120 | # X -> word X-X
121 | if node[0] == "":
122 | raise ValueError(
123 | "When reaching the end of the chain the left branch should already exist")
124 | node[1] = PTree(ref_node[0].label(), ref_node[0].leaves())
125 | return node
126 | elif find_node_type(ref_node[0]) != NodeType.PT and find_node_type(
127 | ref_node[1]) != NodeType.PT:
128 | # X -> X X
129 | node[1] = cls.rev_transform(PTree("X", ["", ""]), ref_node[0], pick_up_labels)
130 | par_node = PTree("X", [node, ""])
131 | return cls.rev_transform(par_node, ref_node[1], pick_up_labels)
132 |
133 |
134 | class RightCornerTransformer(Transformer):
135 | @classmethod
136 | def expand_nt(cls, node: PTree, ref_node: PTree) -> (PTree, PTree, PTree, PTree, bool):
137 | rightcorner_node = cls.extract_right_corner(ref_node)
138 | if rightcorner_node == ref_node:
139 | node.insert(0, ref_node.leaves()[0])
140 | return None, None, None, None, True
141 | new_left_node = PTree(node.label() + "\\" + rightcorner_node.label(), [])
142 | new_right_node = PTree(rightcorner_node.label(), rightcorner_node.leaves())
143 |
144 | node.insert(0, new_left_node)
145 | node.insert(1, new_right_node)
146 | return ref_node, rightcorner_node, rightcorner_node, rightcorner_node, False
147 |
148 | @classmethod
149 | def expand_nt_nt(cls, node: PTree, ref_node1: PTree, ref_node2: PTree) -> (
150 | PTree, PTree, PTree, PTree):
151 | parent_node = ref_node2.parent()
152 | if ref_node1 == parent_node:
153 | new_left_node = PTree(node.label().split("\\")[0] + "\\" + parent_node.label(),
154 | ["EPS"])
155 | else:
156 | new_left_node = PTree(node.label().split("\\")[0] + "\\" + parent_node.label(), [])
157 |
158 | sibling_node = ref_node2.left_sibling()
159 | if len(sibling_node) == 1:
160 | new_right_node = PTree(sibling_node.label(), sibling_node.leaves())
161 | else:
162 | new_right_node = PTree(sibling_node.label(), [])
163 |
164 | node.insert(0, new_left_node)
165 | node.insert(1, new_right_node)
166 | return ref_node1, parent_node, sibling_node, sibling_node
167 |
168 | @classmethod
169 | def rev_transform(cls, node: PTree, ref_node: PTree, pick_up_labels=True) -> PTree:
170 | if find_node_type(ref_node) == NodeType.NT_NT and pick_up_labels:
171 | node.set_label(ref_node[0].label().split("\\")[1])
172 | if len(ref_node) == 1 and find_node_type(ref_node) == NodeType.PT: # base case
173 | return ref_node
174 | if find_node_type(ref_node[1]) == NodeType.PT and not is_node_epsilon(ref_node[0]):
175 | # X -> X word
176 | pt_node = PTree(ref_node[1].label(), ref_node[1].leaves())
177 | if node[1] == "":
178 | node[1] = pt_node
179 | else:
180 | node[0] = pt_node
181 | par_node = PTree("X", ["", node])
182 | node = par_node
183 | return cls.rev_transform(node, ref_node[0], pick_up_labels)
184 | elif find_node_type(ref_node[1]) != NodeType.PT and is_node_epsilon(ref_node[0]):
185 | # X -> X-X X
186 | if node[1] == "":
187 | raise ValueError(
188 | "When reaching the root the right branch should already exist")
189 | node[0] = cls.rev_transform(PTree("X", ["", ""]), ref_node[1], pick_up_labels)
190 | return node
191 | elif find_node_type(ref_node[1]) == NodeType.PT and is_node_epsilon(ref_node[0]):
192 | # X -> X-X word
193 | if node[1] == "":
194 | raise ValueError(
195 | "When reaching the end of the chain the right branch should already exist")
196 | node[0] = PTree(ref_node[1].label(), ref_node[1].leaves())
197 | return node
198 | elif find_node_type(ref_node[1]) != NodeType.PT and find_node_type(
199 | ref_node[0]) != NodeType.PT:
200 | # X -> X X
201 | node[0] = cls.rev_transform(PTree("X", ["", ""]), ref_node[1], pick_up_labels)
202 | par_node = PTree("X", ["", node])
203 | return cls.rev_transform(par_node, ref_node[0], pick_up_labels)
204 |
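# --- Illustrative usage sketch (not part of the original file) ---
# A minimal left-corner transform and its inverse on a toy tree, mirroring
# TestTransforms in tagging/test.py; assumes nltk is installed.
if __name__ == "__main__":
    tree = PTree.fromstring("(S (NP (det the) (N dog)) (VP (V ran) (Adv fast)))")
    tree_lc = PTree("S", [])
    LeftCornerTransformer.transform(tree_lc, tree, tree)
    tree_lc.pretty_print()
    # rev_transform is expected to rebuild the original tree from the transformed one
    back = LeftCornerTransformer.rev_transform(PTree("X", ["", ""]), tree_lc)
    print(back == tree)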
--------------------------------------------------------------------------------
/tagging/tree_tools.py:
--------------------------------------------------------------------------------
1 | import random
2 | import string
3 | from enum import Enum
4 |
5 | import numpy as np
6 | from nltk import ParentedTree
7 | from nltk import Tree
8 |
9 | from const import DUMMY_LABEL
10 |
11 |
12 | class NodeType(Enum):
13 | NT = 0
14 | NT_NT = 1
15 | PT = 2
16 |
17 |
18 | def find_node_type(node: ParentedTree) -> NodeType:
19 | if len(node) == 1:
20 | return NodeType.PT
21 | elif node.label().find("\\") != -1:
22 | return NodeType.NT_NT
23 | else:
24 | return NodeType.NT
25 |
26 |
27 | def is_node_epsilon(node: ParentedTree) -> bool:
28 | node_leaves = node.leaves()
29 | if len(node_leaves) == 1 and node_leaves[0] == "EPS":
30 | return True
31 | return False
32 |
33 |
34 | def is_topo_equal(first: ParentedTree, second: ParentedTree) -> bool:
35 | if len(first) == 1 and len(second) != 1:
36 | return False
37 |     if len(first) != 1 and len(second) == 1:
38 | return False
39 | if len(first) == 1 and len(second) == 1:
40 | return True
41 | return is_topo_equal(first[0], second[0]) and is_topo_equal(first[1], second[1])
42 |
43 |
44 | def random_tree(node: ParentedTree, depth=0, p=.75, cutoff=7) -> None:
45 |     """ sample a random binary tree in place under `node`,
46 |     expanding each child with probability p until depth reaches cutoff
47 |     """
48 | if np.random.binomial(1, p) == 1 and depth < cutoff:
49 | # add the left child tree
50 | left_label = "X/" + str(depth)
51 | left = ParentedTree(left_label, [])
52 | node.insert(0, left)
53 | random_tree(left, depth=depth + 1, p=p, cutoff=cutoff)
54 | else:
55 | label = "X/" + str(depth)
56 | left = ParentedTree(label, [random.choice(string.ascii_letters)])
57 | node.insert(0, left)
58 |
59 | if np.random.binomial(1, p) == 1 and depth < cutoff:
60 | # add the right child tree
61 | right_label = "X/" + str(depth)
62 | right = ParentedTree(right_label, [])
63 | node.insert(1, right)
64 | random_tree(right, depth=depth + 1, p=p, cutoff=cutoff)
65 | else:
66 | label = "X/" + str(depth)
67 | right = ParentedTree(label, [random.choice(string.ascii_letters)])
68 | node.insert(1, right)
69 |
70 |
71 | def create_dummy_tree(leaves):
72 | dummy_tree = Tree("S", [])
73 | idx = 0
74 | for token, pos in leaves:
75 | dummy_tree.insert(idx, Tree(pos, [token]))
76 | idx += 1
77 |
78 | return dummy_tree
79 |
80 |
81 | def remove_plus_from_tree(tree):
82 | if type(tree) == str:
83 | return
84 | label = tree.label()
85 | new_label = label.replace("+", "@")
86 | tree.set_label(new_label)
87 | for child in tree:
88 | remove_plus_from_tree(child)
89 |
90 |
91 | def add_plus_to_tree(tree):
92 | if type(tree) == str:
93 | return
94 | label = tree.label()
95 | new_label = label.replace("@", "+")
96 | tree.set_label(new_label)
97 | for child in tree:
98 | add_plus_to_tree(child)
99 |
100 |
101 | def process_label(label, node_type):
102 | # label format is NT^^^HEAD_IDX
103 | suffix = label[label.rindex("^^^"):]
104 | head_idx = int(label[label.rindex("^^^") + 3:])
105 | return node_type + suffix, head_idx
106 |
107 |
108 | def extract_head_idx(node):
109 | label = node.label()
110 | return int(label[label.rindex("^^^") + 3:])
111 |
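# --- Illustrative note (not part of the original file) ---
# The head-annotated label format handled by the lexical-tree helpers in this
# file is NT^^^HEAD_IDX, e.g. (hypothetical labels):
#   process_label("NP^^^1", "X")          -> ("X^^^1", 1)
#   extract_head_idx(Tree("VP^^^0", []))  -> 0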
112 |
113 | def attach_to_tree(ref_node, parent_node, stack):
114 | if len(ref_node) == 1:
115 | new_node = Tree(ref_node.label(), [ref_node[0]])
116 | parent_node.insert(len(parent_node), new_node)
117 | return stack
118 |
119 | new_node = Tree(ref_node.label(), [])
120 | stack.append((ref_node, new_node))
121 | parent_node.insert(len(parent_node), new_node)
122 | return stack
123 |
124 |
125 | def relabel_tree(root):
126 | stack = [root]
127 | while len(stack) != 0:
128 | cur_node = stack.pop()
129 | if type(cur_node) == str:
130 | continue
131 | cur_node.set_label(f"X^^^{extract_head_idx(cur_node)}")
132 | for child in cur_node:
133 | stack.append(child)
134 |
135 |
136 | def debinarize_lex_tree(root, new_root):
137 | stack = [(root, new_root)] # (node in binarized tree, node in debinarized tree)
138 | while len(stack) != 0:
139 | ref_node, new_node = stack.pop()
140 | head_idx = extract_head_idx(ref_node)
141 |
142 | # attach the left child
143 | stack = attach_to_tree(ref_node[0], new_node, stack)
144 |
145 | cur_node = ref_node[1]
146 | while cur_node.label().startswith(DUMMY_LABEL):
147 | right_idx = extract_head_idx(cur_node)
148 | head_idx += right_idx
149 | # attach the left child
150 | stack = attach_to_tree(cur_node[0], new_node, stack)
151 | cur_node = cur_node[1]
152 |
153 | # attach the right child
154 | stack = attach_to_tree(cur_node, new_node, stack)
155 | # update the label
156 | new_node.set_label(f"X^^^{head_idx}")
157 |
158 |
159 | def binarize_lex_tree(children, node, node_type):
160 |     # node_type is "X" for a normal node and DUMMY_LABEL ("Y|X") for a dummy node introduced by binarization
161 |
162 | if len(children) == 1:
163 | node.insert(0, children[0])
164 | return node
165 | new_label, head_idx = process_label(node.label(), node_type)
166 | node.set_label(new_label)
167 | if len(children) == 2:
168 | left_node = Tree(children[0].label(), [])
169 | left_node = binarize_lex_tree(children[0], left_node, "X")
170 | right_node = Tree(children[1].label(), [])
171 | right_node = binarize_lex_tree(children[1], right_node, "X")
172 | node.insert(0, left_node)
173 | node.insert(1, right_node)
174 | return node
175 | elif len(children) > 2:
176 | if head_idx > 1:
177 | node.set_label(f"{node_type}^^^1")
178 | if head_idx >= 1:
179 | head_idx -= 1
180 | left_node = Tree(children[0].label(), [])
181 | left_node = binarize_lex_tree(children[0], left_node, "X")
182 | right_node = Tree(f"{DUMMY_LABEL}^^^{head_idx}", [])
183 | right_node = binarize_lex_tree(children[1:], right_node, DUMMY_LABEL)
184 | node.insert(0, left_node)
185 | node.insert(1, right_node)
186 | return node
187 | else:
188 | raise ValueError("node has zero children!")
189 |
190 |
191 | def expand_unary(tree):
192 | if len(tree) == 1:
193 | label = tree.label().split("+")
194 | pos_label = label[1]
195 | tree.set_label(label[0])
196 | tree[0] = Tree(pos_label, [tree[0]])
197 | return
198 | for child in tree:
199 | expand_unary(child)
200 |
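# --- Illustrative note (not part of the original file) ---
# expand_unary mutates a tree in place, splitting "+"-collapsed preterminal
# labels back into a unary chain, e.g. (hypothetical input):
#   Tree("NP+NN", ["dog"])  becomes  Tree("NP", [Tree("NN", ["dog"])])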
--------------------------------------------------------------------------------