├── README.md ├── const.py ├── data ├── .gitignore ├── bht │ ├── Chinese.bht.dev │ ├── Chinese.bht.test │ ├── Chinese.bht.train │ ├── Chinese.lex.dev │ ├── Chinese.lex.test │ ├── Chinese.lex.train │ ├── English.bht.dev │ ├── English.bht.test │ ├── English.bht.train │ ├── English.lex.dev │ ├── English.lex.test │ └── English.lex.train ├── dep2bht.py ├── pos │ ├── pos.chinese.json │ └── pos.english.json └── vocab │ └── .gitignore ├── header-hexa.png ├── header.jpg ├── learning ├── __init__.py ├── crf.py ├── dataset.py ├── decode.py ├── evaluate.py ├── learn.py └── util.py ├── load_dataset.sh ├── pat-models └── .gitignore ├── requirements.txt ├── run.py └── tagging ├── __init__.py ├── hexatagger.py ├── srtagger.py ├── tagger.py ├── test.py ├── tetratagger.py ├── transform.py └── tree_tools.py /README.md: -------------------------------------------------------------------------------- 1 | # Parsing as Tagging 2 |
6 | This repository contains code for training and evaluation of two papers: 7 | 8 | - On Parsing as Tagging 9 | - Hexatagging: Projective Dependency Parsing as Tagging 10 | 11 | ## Setting Up The Environment 12 | Set up a virtual environment and install the dependencies: 13 | ```bash 14 | pip install -r requirements.txt 15 | ``` 16 | 17 | ## Getting The Data 18 | ### Constituency Parsing 19 | Follow the instructions in this [repo](https://github.com/nikitakit/self-attentive-parser/tree/master/data) to do the initial preprocessing on the English WSJ and SPMRL datasets. The default data path is the `data/spmrl` folder, where each file is named in the `[LANGUAGE].[train/dev/test]` format. 20 | ### Dependency Parsing with Hexatagger 21 | 1. Convert CoNLL to Binary Headed Trees: 22 | ```bash 23 | python data/dep2bht.py 24 | ``` 25 | This will generate the phrase-structured BHT trees in the `data/bht` directory (a quick sanity-check sketch of the format appears just before the Train commands below). 26 | We have already placed the processed files under the `data/bht` directory. 27 | 28 | ## Building The Tagging Vocab 29 | In order to use the taggers, we need to build the vocabulary of tags for the in-order, pre-order, and post-order linearizations. You can cache these vocabularies using: 30 | ```bash 31 | python run.py vocab --lang [LANGUAGE] --tagger [TAGGER] 32 | ``` 33 | The tagger can be `td-sr` for the top-down (pre-order) shift-reduce linearization, `bu-sr` for the bottom-up (post-order) shift-reduce linearization, `tetra` for the in-order linearization, and `hexa` for the hexatagging linearization. 34 | 35 | ## Training 36 | Train the model and store the best checkpoint. 37 | ```bash 38 | python run.py train --batch-size [BATCH_SIZE] --tagger [TAGGER] --lang [LANGUAGE] --model [MODEL] --epochs [EPOCHS] --lr [LR] --model-path [MODEL_PATH] --output-path [PATH] --max-depth [DEPTH] --keep-per-depth [KPD] [--use-tensorboard] 39 | ``` 40 | - batch size: use 32 to reproduce the results 41 | - tagger: `td-sr`, `bu-sr`, `tetra`, or `hexa` 42 | - lang: language, one of the nine languages reported in the paper 43 | - model: `bert`, `bert+crf`, `bert+lstm` 44 | - model path: path where the pretrained model is saved 45 | - output path: path to save the best trained model 46 | - max depth: maximum depth to keep in the decoding lattice 47 | - keep per depth: number of elements to keep track of in the decoding step 48 | - use-tensorboard: whether or not to store the logs in TensorBoard (true or false) 49 | 50 | ## Evaluation 51 | Calculate evaluation metrics: F-score, precision, recall, loss. 52 | ```bash 53 | python run.py evaluate --lang [LANGUAGE] --model-name [MODEL] --model-path [MODEL_PATH] --bert-model-path [BERT_PATH] --max-depth [DEPTH] --keep-per-depth [KPD] [--is-greedy] 54 | ``` 55 | - lang: language, one of the nine languages reported in the paper 56 | - model name: name of the checkpoint 57 | - model path: path to the checkpoint 58 | - bert model path: path to the pretrained model 59 | - max depth: maximum depth to keep in the decoding lattice 60 | - keep per depth: number of elements to keep track of in the decoding step 61 | - is greedy: whether or not to use greedy decoding; default is false 62 | 63 | # Exact Commands for Hexatagging 64 | The above commands can be used with different taggers, models, and languages. To reproduce our Hexatagging results, here are the exact commands we used to train and evaluate the Hexatagger. 
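Before running the commands below, it can help to sanity-check the BHT files they consume. The following minimal sketch is not part of the original instructions; it assumes the default `data/bht` layout and uses `nltk`, which `data/dep2bht.py` already requires, to read the first tree of `data/bht/English.bht.test`:

```python
from nltk.tree import Tree

# Each line of a *.bht.* file is one flat, bracketed binary-headed tree (BHT)
# as written by data/dep2bht.py.
with open("data/bht/English.bht.test") as f:
    first_line = f.readline().strip()

tree = Tree.fromstring(first_line)

# Internal labels look like "X^^^0" / "X^^^1": the number after the "^^^"
# delimiter is the index of the head child. The label right above each POS
# pre-terminal instead carries the dependency relation (e.g. "X^^^nsubj").
print(tree.label())
print(tree.pos()[:5])  # (word, POS) pairs for the first five tokens
```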
65 | ## Train 66 | ### PTB (English) 67 | ```bash 68 | CUDA_VISIBLE_DEVICES=0 python run.py train --lang English --max-depth 6 --tagger hexa --model bert --epochs 50 --batch-size 32 --lr 2e-5 --model-path xlnet-large-cased --output-path ./checkpoints/ --use-tensorboard True 69 | # model saved at ./checkpoints/English-hexa-bert-3e-05-50 70 | ``` 71 | ### CTB (Chinese) 72 | ```bash 73 | CUDA_VISIBLE_DEVICES=0 python run.py train --lang Chinese --max-depth 6 --tagger hexa --model bert --epochs 50 --batch-size 32 --lr 2e-5 --model-path hfl/chinese-xlnet-mid --output-path ./checkpoints/ --use-tensorboard True 74 | # model saved at ./checkpoints/Chinese-hexa-bert-2e-05-50 75 | ``` 76 | 77 | ### UD 78 | ```bash 79 | CUDA_VISIBLE_DEVICES=0 python run.py train --lang bg --max-depth 6 --tagger hexa --model bert --epochs 50 --batch-size 32 --lr 2e-5 --model-path bert-base-multilingual-cased --output-path ./checkpoints/ --use-tensorboard True 80 | ``` 81 | 82 | ## Evaluate 83 | ### PTB 84 | ```bash 85 | python run.py evaluate --lang English --max-depth 10 --tagger hexa --bert-model-path xlnet-large-cased --model-name English-hexa-bert-3e-05-50 --batch-size 64 --model-path ./checkpoints/ 86 | ``` 87 | 88 | ### CTB 89 | ```bash 90 | python run.py evaluate --lang Chinese --max-depth 10 --tagger hexa --bert-model-path bert-base-chinese --model-name Chinese-hexa-bert-3e-05-50 --batch-size 64 --model-path ./checkpoints/ 91 | ``` 92 | ### UD 93 | ```bash 94 | python run.py evaluate --lang bg --max-depth 10 --tagger hexa --bert-model-path bert-base-multilingual-cased --model-name bg-hexa-bert-1e-05-50 --batch-size 64 --model-path ./checkpoints/ 95 | ``` 96 | 97 | 98 | ## Predict 99 | ### PTB 100 | ```bash 101 | python run.py predict --lang English --max-depth 10 --tagger hexa --bert-model-path xlnet-large-cased --model-name English-hexa-bert-3e-05-50 --batch-size 64 --model-path ./checkpoints/ 102 | ``` 103 | 104 | # Citation 105 | If you find this repository useful, please cite our papers: 106 | ```bibtex 107 | @inproceedings{amini-cotterell-2022-parsing, 108 | title = "On Parsing as Tagging", 109 | author = "Amini, Afra and 110 | Cotterell, Ryan", 111 | booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing", 112 | month = dec, 113 | year = "2022", 114 | address = "Abu Dhabi, United Arab Emirates", 115 | publisher = "Association for Computational Linguistics", 116 | url = "https://aclanthology.org/2022.emnlp-main.607", 117 | pages = "8884--8900", 118 | } 119 | ``` 120 | 121 | ```bibtex 122 | @inproceedings{amini-etal-2023-hexatagging, 123 | title = "Hexatagging: Projective Dependency Parsing as Tagging", 124 | author = "Amini, Afra and 125 | Liu, Tianyu and 126 | Cotterell, Ryan", 127 | booktitle = "Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)", 128 | month = jul, 129 | year = "2023", 130 | address = "Toronto, Canada", 131 | publisher = "Association for Computational Linguistics", 132 | url = "https://aclanthology.org/2023.acl-short.124", 133 | pages = "1453--1464", 134 | } 135 | ``` 136 | 137 | -------------------------------------------------------------------------------- /const.py: -------------------------------------------------------------------------------- 1 | DUMMY_LABEL = "Y|X" 2 | 3 | DATA_PATH = "data/spmrl/" 4 | DEP_DATA_PATH = "data/bht/" 5 | 6 | 7 | TETRATAGGER = "tetra" 8 | HEXATAGGER = "hexa" 9 | 10 | TD_SR = "td-sr" 11 | BU_SR = "bu-sr" 12 | BERT = ["bert", "roberta", 
"robertaL"] 13 | BERTCRF = ["bert+crf", "roberta+crf", "robertaL+crf"] 14 | BERTLSTM = ["bert+lstm", "roberta+lstm", "robertaL+lstm"] 15 | 16 | BAQ = "Basque" 17 | CHN = "Chinese" 18 | CHN09 = "Chinese-conll09" 19 | FRE = "French" 20 | GER = "German" 21 | HEB = "Hebrew" 22 | HUN = "Hungarian" 23 | KOR = "Korean" 24 | POL = "Polish" 25 | SWE = "swedish" 26 | ENG = "English" 27 | LANG = [BAQ, CHN, CHN09, FRE, GER, HEB, HUN, KOR, POL, SWE, ENG, 28 | "bg","ca","cs","de","en","es", "fr","it","nl","no","ro","ru"] 29 | 30 | UD_LANG_TO_DIR = { 31 | "bg": "/UD_Bulgarian-BTB/bg_btb-ud-{split}.conllu", 32 | "ca": "/UD_Catalan-AnCora/ca_ancora-ud-{split}.conllu", 33 | "cs": "/UD_Czech-PDT/cs_pdt-ud-{split}.conllu", 34 | "de": "/UD_German-GSD/de_gsd-ud-{split}.conllu", 35 | "en": "/UD_English-EWT/en_ewt-ud-{split}.conllu", 36 | "es": "/UD_Spanish-AnCora/es_ancora-ud-{split}.conllu", 37 | "fr": "/UD_French-GSD/fr_gsd-ud-{split}.conllu", 38 | "it": "/UD_Italian-ISDT/it_isdt-ud-{split}.conllu", 39 | "nl": "/UD_Dutch-Alpino/nl_alpino-ud-{split}.conllu", 40 | "no": "/UD_Norwegian-Bokmaal/no_bokmaal-ud-{split}.conllu", 41 | "ro": "/UD_Romanian-RRT/ro_rrt-ud-{split}.conllu", 42 | "ru": "/UD_Russian-SynTagRus/ru_syntagrus-ud-{split}.conllu", 43 | } 44 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | !bht/* 6 | !pos/* 7 | !vocab/* 8 | 9 | -------------------------------------------------------------------------------- /data/dep2bht.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from nltk.corpus.reader import DependencyCorpusReader 4 | from nltk.corpus.reader.util import * 5 | from nltk.tree import Tree 6 | from tqdm import tqdm as tq 7 | import sys 8 | 9 | LANG_TO_DIR = { 10 | "bg": "/UD_Bulgarian-BTB/bg_btb-ud-{split}.conllu", 11 | "ca": "/UD_Catalan-AnCora/ca_ancora-ud-{split}.conllu", 12 | "cs": "/UD_Czech-PDT/cs_pdt-ud-{split}.conllu", 13 | "de": "/UD_German-GSD/de_gsd-ud-{split}.conllu", 14 | "en": "/UD_English-EWT/en_ewt-ud-{split}.conllu", 15 | "es": "/UD_Spanish-AnCora/es_ancora-ud-{split}.conllu", 16 | "fr": "/UD_French-GSD/fr_gsd-ud-{split}.conllu", 17 | "it": "/UD_Italian-ISDT/it_isdt-ud-{split}.conllu", 18 | "nl": "/UD_Dutch-Alpino/nl_alpino-ud-{split}.conllu", 19 | "no": "/UD_Norwegian-Bokmaal/no_bokmaal-ud-{split}.conllu", 20 | "ro": "/UD_Romanian-RRT/ro_rrt-ud-{split}.conllu", 21 | "ru": "/UD_Russian-SynTagRus/ru_syntagrus-ud-{split}.conllu", 22 | } 23 | 24 | 25 | # using ^^^ as delimiter 26 | # since ^ never appears 27 | 28 | def augment_constituent_tree(const_tree, dep_tree): 29 | # augment constituent tree leaves into dicts 30 | 31 | assert len(const_tree.leaves()) == len(dep_tree.nodes) - 1 32 | 33 | leaf_nodes = list(const_tree.treepositions('leaves')) 34 | for i, pos in enumerate(leaf_nodes): 35 | x = dep_tree.nodes[1 + i] 36 | y = const_tree[pos].replace("\\", "") 37 | assert (x['word'] == y), (const_tree, dep_tree) 38 | 39 | # expanding leaves with dependency info 40 | const_tree[pos] = { 41 | "word": dep_tree.nodes[1 + i]["word"], 42 | "head": dep_tree.nodes[1 + i]["head"], 43 | "rel": dep_tree.nodes[1 + i]["rel"] 44 | } 45 | 46 | return const_tree 47 | 48 | 49 | def get_bht(root, offset=0): 50 | # Return: 51 | # The index of the head in this tree 52 | # and: 53 | # The dependant that this 
tree points to 54 | if type(root) is dict: 55 | # leaf node already, return its head 56 | return 0, root['head'] 57 | 58 | # word offset in the current tree 59 | words_seen = 0 60 | root_projection = (offset, offset + len(root.leaves())) 61 | # init the return values to be None 62 | head, root_points_to = None, None 63 | 64 | # traverse the consituent tree 65 | for idx, child in enumerate(root): 66 | head_of_child, child_points_to = get_bht(child, offset + words_seen) 67 | if type(child) == type(root): 68 | words_seen += len(child.leaves()) 69 | else: 70 | # leaf node visited 71 | words_seen += 1 72 | 73 | if child_points_to < root_projection[0] or child_points_to >= root_projection[1]: 74 | # pointing to outside of the current tree 75 | if head is not None: 76 | print("error! Non-projectivity detected.", root_projection, idx) 77 | continue # choose the first child as head 78 | head = idx 79 | root_points_to = child_points_to 80 | 81 | if root_points_to is None: 82 | # self contained sub-sentence 83 | print("multiple roots detected", root) 84 | root_points_to = 0 85 | 86 | original_label = root.label() 87 | root.set_label(f"{original_label}^^^{head}") 88 | 89 | return head, root_points_to 90 | 91 | 92 | def dep2lex(dep_tree, language="English"): 93 | # left-first attachment 94 | def dfs(node_idx): 95 | dependencies = [] 96 | for rel in dep_tree.nodes[node_idx]["deps"]: 97 | for dependent in dep_tree.nodes[node_idx]["deps"][rel]: 98 | dependencies.append((dependent, rel)) 99 | 100 | dependencies.append((node_idx, "SELF")) 101 | if len(dependencies) == 1: 102 | # no dependent at all, leaf node 103 | return Tree( 104 | f"X^^^{dep_tree.nodes[node_idx]['rel']}", 105 | [ 106 | Tree( 107 | f"{dep_tree.nodes[node_idx]['ctag']}", 108 | [ 109 | f"{dep_tree.nodes[node_idx]['word']}" 110 | ] 111 | ) 112 | ] 113 | ) 114 | # Now, len(dependencies) >= 2, sort dependents 115 | dependencies = sorted(dependencies) 116 | 117 | lex_tree_root = Tree(f"X^^^{0}", []) 118 | empty_slot = lex_tree_root # place to fill in the next subtree 119 | for idx, dependency in enumerate(dependencies): 120 | if dependency[0] < node_idx: 121 | # waiting for a head in the right child 122 | lex_tree_root.set_label(f"X^^^{1}") 123 | if len(lex_tree_root) == 0: 124 | # the first non-head child 125 | lex_tree_root.insert(0, dfs(dependency[0])) 126 | else: 127 | # not the first non-head child 128 | # insert a sub tree: \ 129 | # X^^^1 130 | # / \ 131 | # word [empty_slot] 132 | empty_slot.insert(1, Tree(f"X^^^{1}", [dfs(dependency[0])])) 133 | empty_slot = empty_slot[1] 134 | elif dependency[0] == node_idx: 135 | tree_piece = Tree( 136 | f"X^^^{dep_tree.nodes[dependency[0]]['rel']}", 137 | [ 138 | Tree( 139 | f"{dep_tree.nodes[dependency[0]]['ctag']}", 140 | [ 141 | f"{dep_tree.nodes[dependency[0]]['word']}" 142 | ] 143 | ) 144 | ] 145 | ) 146 | if len(empty_slot) == 1: 147 | # This is the head 148 | empty_slot.insert(1, tree_piece) 149 | else: # len(empty_slot) == 0 150 | lex_tree_root = tree_piece 151 | pass 152 | else: 153 | # moving on to the right of the head 154 | lex_tree_root = Tree(f"X^^^{0}", [lex_tree_root, dfs(dependency[0])]) 155 | return lex_tree_root 156 | 157 | return dfs( 158 | dep_tree.nodes[0]["deps"]["root"][0] if "root" in dep_tree.nodes[0]["deps"] else 159 | dep_tree.nodes[0]["deps"]["ROOT"][0] 160 | ) 161 | 162 | 163 | def dep2lex_right_first(dep_tree, language="English"): 164 | # right-first attachment 165 | 166 | def dfs(node_idx): 167 | dependencies = [] 168 | for rel in dep_tree.nodes[node_idx]["deps"]: 169 
| for dependent in dep_tree.nodes[node_idx]["deps"][rel]: 170 | dependencies.append((dependent, rel)) 171 | 172 | dependencies.append((node_idx, "SELF")) 173 | if len(dependencies) == 1: 174 | # no dependent at all, leaf node 175 | return Tree( 176 | f"X^^^{dep_tree.nodes[node_idx]['rel']}", 177 | [ 178 | Tree( 179 | f"{dep_tree.nodes[node_idx]['ctag']}", 180 | [ 181 | f"{dep_tree.nodes[node_idx]['word']}" 182 | ] 183 | ) 184 | ] 185 | ) 186 | # Now, len(dependencies) >= 2, sort dependents 187 | dependencies = sorted(dependencies) 188 | 189 | lex_tree_root = Tree(f"X^^^{0}", []) 190 | empty_slot = lex_tree_root # place to fill in the next subtree 191 | for idx, dependency in enumerate(dependencies): 192 | if dependency[0] < node_idx: 193 | # waiting for a head in the right child 194 | lex_tree_root.set_label(f"X^^^{1}") 195 | if len(lex_tree_root) == 0: 196 | # the first non-head child 197 | lex_tree_root.insert( 198 | 0, dfs(dependency[0]) 199 | ) 200 | else: 201 | # not the first non-head child 202 | # insert a sub tree: \ 203 | # X^^^1 204 | # / \ 205 | # word [empty_slot] 206 | empty_slot.insert( 207 | 1, 208 | Tree(f"X^^^{1}", [ 209 | dfs(dependency[0]) 210 | ]) 211 | ) 212 | empty_slot = empty_slot[1] 213 | elif dependency[0] == node_idx: 214 | # This is the head 215 | right_branch_root = Tree( 216 | f"X^^^{dep_tree.nodes[dependency[0]]['rel']}", 217 | [ 218 | Tree( 219 | f"{dep_tree.nodes[dependency[0]]['ctag']}", 220 | [ 221 | f"{dep_tree.nodes[dependency[0]]['word']}" 222 | ] 223 | ) 224 | ] 225 | ) 226 | else: 227 | # moving on to the right of the head 228 | right_branch_root = Tree( 229 | f"X^^^{0}", 230 | [ 231 | right_branch_root, 232 | dfs(dependency[0]) 233 | ] 234 | ) 235 | 236 | if len(empty_slot) == 0: 237 | lex_tree_root = right_branch_root 238 | else: 239 | empty_slot.insert(1, right_branch_root) 240 | 241 | return lex_tree_root 242 | 243 | return dfs( 244 | dep_tree.nodes[0]["deps"]["root"][0] if "root" in dep_tree.nodes[0]["deps"] else 245 | dep_tree.nodes[0]["deps"]["ROOT"][0] 246 | ) 247 | 248 | 249 | if __name__ == "__main__": 250 | repo_directory = os.path.abspath(__file__) 251 | 252 | if len(sys.argv) > 1: 253 | path = sys.argv[1] 254 | reader = DependencyCorpusReader( 255 | os.path.dirname(repo_directory), 256 | [path], 257 | ) 258 | dep_trees = reader.parsed_sents(path) 259 | 260 | bhts = [] 261 | for dep_tree in tq(dep_trees): 262 | lex_tree = dep2lex(dep_tree, language="input") 263 | bhts.append(lex_tree) 264 | lex_tree_leaves = tuple(lex_tree.leaves()) 265 | dep_tree_leaves = tuple( 266 | [str(node["word"]) for _, node in sorted(dep_tree.nodes.items())]) 267 | 268 | dep_tree_leaves = dep_tree_leaves[1:] 269 | 270 | print(f"Writing BHTs to {os.path.dirname(repo_directory)}/input.bht.test") 271 | with open(os.path.dirname(repo_directory) + f"/bht/input.bht.test", 272 | "w") as fout: 273 | for lex_tree in bhts: 274 | fout.write(lex_tree._pformat_flat("", "()", False) + "\n") 275 | 276 | exit(0) 277 | 278 | for language in [ 279 | "English", # PTB 280 | "Chinese", # CTB 281 | "bg","ca","cs","de","en","es","fr","it","nl","no","ro","ru" # UD2.2 282 | ]: 283 | print(f"Processing {language}...") 284 | if language == "English": 285 | path = os.path.dirname(repo_directory) + "/ptb/ptb_{split}_3.3.0.sd.clean" 286 | paths = [path.format(split=split) for split in ["train", "dev", "test"]] 287 | elif language == "Chinese": 288 | path = os.path.dirname(repo_directory) + "/ctb/{split}.ctb.conll" 289 | paths = [path.format(split=split) for split in ["train", "dev", "test"]] 
290 | elif language in ["bg", "ca","cs","de","en","es","fr","it","nl","no","ro","ru"]: 291 | path = os.path.dirname(repo_directory)+f"/ctb_ptb_ud22/ud2.2/{LANG_TO_DIR[language]}" 292 | paths = [] 293 | groups = re.match(r'(\w+)_\w+-ud-(\w+)\.conllu', os.path.split(path)[-1]) 294 | 295 | for split in ["train", "dev", "test"]: 296 | conll_path = path.format(split=split) 297 | command = f"cd ../malt/maltparser-1.9.2/; java -jar maltparser-1.9.2.jar -c {language}_{split} -m proj" \ 298 | f" -i {conll_path} -o {conll_path}.proj -pp head" 299 | os.system(command) 300 | paths.append(conll_path + ".proj") 301 | 302 | reader = DependencyCorpusReader( 303 | os.path.dirname(repo_directory), 304 | paths, 305 | ) 306 | 307 | for path, split in zip(paths, ["train", "dev", "test"]): 308 | print(f"Converting {path} to lexicalized tree") 309 | 310 | with tempfile.NamedTemporaryFile(mode="w") as tmp_file: 311 | with open(path, "r", encoding='utf-8-sig') as fin: 312 | lines = [line for line in fin.readlines() if not line.startswith("#")] 313 | for idl, line in enumerate(lines): 314 | if line.strip() == "": 315 | continue 316 | assert len(line.split("\t")) == 10, line 317 | if '\xad' in line: 318 | lines[idl] = lines[idl].replace('\xad', '') 319 | if '\ufeff' in line: 320 | lines[idl] = lines[idl].replace('\ufeff', '') 321 | if " " in line: 322 | lines[idl] = lines[idl].replace(" ", ",") 323 | if ")" in line: 324 | lines[idl] = lines[idl].replace(")", "-RRB-") 325 | if "(" in line: 326 | lines[idl] = lines[idl].replace("(", "-LRB-") 327 | 328 | tmp_file.writelines(lines) 329 | tmp_file.flush() 330 | tmp_file.seek(0) 331 | 332 | dep_trees = reader.parsed_sents(tmp_file.name) 333 | 334 | bhts = [] 335 | for dep_tree in tq(dep_trees): 336 | lex_tree = dep2lex(dep_tree, language=language) 337 | bhts.append(lex_tree) 338 | lex_tree_leaves = tuple(lex_tree.leaves()) 339 | dep_tree_leaves = tuple( 340 | [str(node["word"]) for _, node in sorted(dep_tree.nodes.items())]) 341 | 342 | dep_tree_leaves = dep_tree_leaves[1:] 343 | 344 | print(f"Writing BHTs to {os.path.dirname(repo_directory)}/{language}.bht.{split}") 345 | with open(os.path.dirname(repo_directory) + f"/bht/{language}.bht.{split}", 346 | "w") as fout: 347 | for lex_tree in bhts: 348 | fout.write(lex_tree._pformat_flat("", "()", False) + "\n") 349 | -------------------------------------------------------------------------------- /data/pos/pos.chinese.json: -------------------------------------------------------------------------------- 1 | {"TOP": 1, "NN": 2, "NR": 3, "VV": 4, "PU": 5, "LC": 6, "JJ": 7, "AD": 8, "CD": 9, "VA": 10, "OD": 11, "DEG": 12, "VE": 13, "NT": 14, "PN": 15, "DEC": 16, "CC": 17, "DT": 18, "P": 19, "M": 20, "VC": 21, "BA": 22, "SP": 23, "AS": 24, "SB": 25, "FW": 26, "CS": 27, "X": 28, "ETC": 29, "DER": 30, "DEV": 31, "LB": 32, "MSP": 33, "IJ": 34} -------------------------------------------------------------------------------- /data/pos/pos.english.json: -------------------------------------------------------------------------------- 1 | {"IN": 1, "DT": 2, "NNP": 3, "CD": 4, "NN": 5, "``": 6, "''": 7, "POS": 8, "-LRB-": 9, "JJ": 10, "NNS": 11, "VBP": 12, ",": 13, "CC": 14, "-RRB-": 15, "VBN": 16, "VBD": 17, "RB": 18, "TO": 19, ".": 20, "VBZ": 21, "NNPS": 22, "PRP": 23, "PRP$": 24, "VB": 25, "MD": 26, "VBG": 27, "RBR": 28, ":": 29, "WP": 30, "WDT": 31, "JJR": 32, "PDT": 33, "RBS": 34, "JJS": 35, "WRB": 36, "$": 37, "RP": 38, "FW": 39, "EX": 40, "#": 41, "WP$": 42, "UH": 43, "SYM": 44, "LS": 45} 
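These two JSON files are the POS-tag vocabularies that `learning/dataset.py` uses to turn POS tags into integer ids (id 0 is left for unknown tags and is the padding index of the POS embedding in `learning/learn.py`). A minimal sketch of that lookup; the example tag sequence below is illustrative only:

```python
import json

# Load the cached English POS-tag vocabulary (built from the training trees).
with open("data/pos/pos.english.json") as f:
    pos_dict = json.load(f)

# Map a POS sequence to ids the way TaggingDataset.__getitem__ does:
# unknown tags fall back to id 0.
pos_tags = ["DT", "NN", "VBZ", "RB", "."]
pos_ids = [pos_dict.get(tag, 0) for tag in pos_tags]
print(pos_ids)  # [2, 5, 21, 18, 20] with the shipped English vocabulary
```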
-------------------------------------------------------------------------------- /data/vocab/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | 6 | -------------------------------------------------------------------------------- /header-hexa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rycolab/parsing-as-tagging/e7a0be2d92dc44c8c5920a33db10dcfab8573dc1/header-hexa.png -------------------------------------------------------------------------------- /header.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rycolab/parsing-as-tagging/e7a0be2d92dc44c8c5920a33db10dcfab8573dc1/header.jpg -------------------------------------------------------------------------------- /learning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rycolab/parsing-as-tagging/e7a0be2d92dc44c8c5920a33db10dcfab8573dc1/learning/__init__.py -------------------------------------------------------------------------------- /learning/crf.py: -------------------------------------------------------------------------------- 1 | # Code is inspired from https://github.com/mtreviso/linear-chain-crf 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class CRF(nn.Module): 7 | """ 8 | Linear-chain Conditional Random Field (CRF). 9 | Args: 10 | nb_labels (int): number of labels in your tagset, including special symbols. 11 | device: cpu or gpu, 12 | batch_first (bool): Whether the first dimension represents the batch dimension. 13 | """ 14 | 15 | def __init__( 16 | self, nb_labels, device=None, batch_first=True 17 | ): 18 | super(CRF, self).__init__() 19 | 20 | self.nb_labels = nb_labels 21 | self.batch_first = batch_first 22 | self.device = device 23 | 24 | self.transitions = nn.Parameter(torch.empty(self.nb_labels, self.nb_labels)) 25 | self.start_transitions = nn.Parameter(torch.empty(self.nb_labels)) 26 | self.end_transitions = nn.Parameter(torch.empty(self.nb_labels)) 27 | 28 | self.init_weights() 29 | 30 | def init_weights(self): 31 | nn.init.uniform_(self.start_transitions, -0.1, 0.1) 32 | nn.init.uniform_(self.end_transitions, -0.1, 0.1) 33 | nn.init.uniform_(self.transitions, -0.1, 0.1) 34 | 35 | def forward(self, emissions, tags, mask=None): 36 | """Compute the negative log-likelihood. See `log_likelihood` method.""" 37 | nll = -self.log_likelihood(emissions, tags, mask=mask) 38 | return nll 39 | 40 | def log_likelihood(self, emissions, tags, mask=None): 41 | """Compute the probability of a sequence of tags given a sequence of 42 | emissions scores. 43 | Args: 44 | emissions (torch.Tensor): Sequence of emissions for each label. 45 | Shape of (batch_size, seq_len, nb_labels) if batch_first is True, 46 | (seq_len, batch_size, nb_labels) otherwise. 47 | tags (torch.LongTensor): Sequence of labels. 48 | Shape of (batch_size, seq_len) if batch_first is True, 49 | (seq_len, batch_size) otherwise. 50 | mask (torch.FloatTensor, optional): Tensor representing valid positions. 51 | If None, all positions are considered valid. 52 | Shape of (batch_size, seq_len) if batch_first is True, 53 | (seq_len, batch_size) otherwise. 54 | Returns: 55 | torch.Tensor: the log-likelihoods for each sequence in the batch. 
56 | Shape of (batch_size,) 57 | """ 58 | 59 | # fix tensors order by setting batch as the first dimension 60 | if not self.batch_first: 61 | emissions = emissions.transpose(0, 1) 62 | tags = tags.transpose(0, 1) 63 | 64 | if mask is None: 65 | mask = torch.ones(emissions.shape[:2], dtype=torch.float).to(self.device) 66 | 67 | scores = self._compute_scores(emissions, tags, mask=mask) 68 | partition = self._compute_log_partition(emissions, mask=mask) 69 | return torch.sum(scores - partition) 70 | 71 | def _compute_scores(self, emissions, tags, mask, last_idx=None): 72 | """Compute the scores for a given batch of emissions with their tags. 73 | Args: 74 | emissions (torch.Tensor): (batch_size, seq_len, nb_labels) 75 | tags (Torch.LongTensor): (batch_size, seq_len) 76 | mask (Torch.FloatTensor): (batch_size, seq_len) 77 | Returns: 78 | torch.Tensor: Scores for each batch. 79 | Shape of (batch_size,) 80 | """ 81 | batch_size, seq_length = tags.shape 82 | scores = torch.zeros(batch_size).to(self.device) 83 | 84 | alpha_mask = torch.zeros((batch_size,), dtype=int).to(self.device) 85 | previous_tags = torch.zeros((batch_size,), dtype=int).to(self.device) 86 | 87 | for i in range(0, seq_length): 88 | is_valid = mask[:, i] 89 | 90 | current_tags = tags[:, i] 91 | 92 | e_scores = emissions[:, i].gather(1, current_tags.unsqueeze(1)).squeeze() 93 | 94 | first_t_scores = self.start_transitions[current_tags] 95 | t_scores = self.transitions[previous_tags, current_tags] 96 | t_scores = (1 - alpha_mask) * first_t_scores + alpha_mask * t_scores 97 | alpha_mask = is_valid + (1 - is_valid) * alpha_mask 98 | 99 | e_scores = e_scores * is_valid 100 | t_scores = t_scores * is_valid 101 | 102 | scores += e_scores + t_scores 103 | 104 | previous_tags = current_tags * is_valid + previous_tags * (1 - is_valid) 105 | 106 | scores += self.end_transitions[previous_tags] 107 | 108 | return scores 109 | 110 | def _compute_log_partition(self, emissions, mask): 111 | """Compute the partition function in log-space using the forward-algorithm. 112 | Args: 113 | emissions (torch.Tensor): (batch_size, seq_len, nb_labels) 114 | mask (Torch.FloatTensor): (batch_size, seq_len) 115 | Returns: 116 | torch.Tensor: the partition scores for each batch. 
117 | Shape of (batch_size,) 118 | """ 119 | batch_size, seq_length, nb_labels = emissions.shape 120 | 121 | alphas = torch.zeros((batch_size, nb_labels)).to(self.device) 122 | alpha_mask = torch.zeros((batch_size, 1), dtype=int).to(self.device) 123 | 124 | for i in range(0, seq_length): 125 | is_valid = mask[:, i].unsqueeze(-1) 126 | 127 | first_alphas = self.start_transitions + emissions[:, i] 128 | 129 | # (bs, nb_labels) -> (bs, 1, nb_labels) 130 | e_scores = emissions[:, i].unsqueeze(1) 131 | # (nb_labels, nb_labels) -> (bs, nb_labels, nb_labels) 132 | t_scores = self.transitions.unsqueeze(0) 133 | # (bs, nb_labels) -> (bs, nb_labels, 1) 134 | a_scores = alphas.unsqueeze(2) 135 | scores = e_scores + t_scores + a_scores 136 | 137 | new_alphas = torch.logsumexp(scores, dim=1) 138 | 139 | new_alphas = (1 - alpha_mask) * first_alphas + alpha_mask * new_alphas 140 | alpha_mask = is_valid + (1 - is_valid) * alpha_mask 141 | 142 | alphas = is_valid * new_alphas + (1 - is_valid) * alphas 143 | 144 | end_scores = alphas + self.end_transitions 145 | 146 | return torch.logsumexp(end_scores, dim=1) 147 | -------------------------------------------------------------------------------- /learning/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.nn.utils.rnn import pad_sequence 8 | 9 | BERT_TOKEN_MAPPING = { 10 | "-LRB-": "(", 11 | "-RRB-": ")", 12 | "-LCB-": "{", 13 | "-RCB-": "}", 14 | "-LSB-": "[", 15 | "-RSB-": "]", 16 | "``": '"', 17 | "''": '"', 18 | "`": "'", 19 | '«': '"', 20 | '»': '"', 21 | '‘': "'", 22 | '’': "'", 23 | '“': '"', 24 | '”': '"', 25 | '„': '"', 26 | '‹': "'", 27 | '›': "'", 28 | "\u2013": "--", # en dash 29 | "\u2014": "--", # em dash 30 | } 31 | 32 | 33 | def ptb_unescape(sent): 34 | cleaned_words = [] 35 | for word in sent: 36 | word = BERT_TOKEN_MAPPING.get(word, word) 37 | word = word.replace('\\/', '/').replace('\\*', '*') 38 | # Mid-token punctuation occurs in biomedical text 39 | word = word.replace('-LSB-', '[').replace('-RSB-', ']') 40 | word = word.replace('-LRB-', '(').replace('-RRB-', ')') 41 | if word == "n't" and cleaned_words: 42 | cleaned_words[-1] = cleaned_words[-1] + "n" 43 | word = "'t" 44 | cleaned_words.append(word) 45 | return cleaned_words 46 | 47 | 48 | class TaggingDataset(torch.utils.data.Dataset): 49 | def __init__(self, split, tokenizer, tag_system, reader, device, is_tetratags=False, 50 | language="english", pad_to_len=None, 51 | max_train_len=350): 52 | self.reader = reader 53 | self.split = split 54 | self.trees = self.reader.parsed_sents(split) 55 | self.tokenizer = tokenizer 56 | self.language = language 57 | self.tag_system = tag_system 58 | self.pad_token_id = self.tokenizer.pad_token_id 59 | self.pad_to_len = pad_to_len 60 | self.device = device 61 | self.is_tetratags = is_tetratags 62 | 63 | if "train" in split and max_train_len is not None: 64 | # To speed up training, we only train on short sentences. 65 | print(len(self.trees), "trees before filtering") 66 | self.trees = [ 67 | tree for tree in self.trees if 68 | (len(tree.leaves()) <= max_train_len and len(tree.leaves()) >= 2)] 69 | print(len(self.trees), "trees after filtering") 70 | else: 71 | # speed up! 
72 | self.trees = [ 73 | tree for tree in self.trees if len(tree.leaves()) <= max_train_len] 74 | 75 | if not os.path.exists( 76 | f"./data/pos.{language.lower()}.json") and "train" in split: 77 | self.pos_dict = self.get_pos_dict() 78 | with open(f"./data/pos.{language.lower()}.json", 'w') as fp: 79 | json.dump(self.pos_dict, fp) 80 | else: 81 | with open(f"./data/pos.{language.lower()}.json", 'r') as fp: 82 | self.pos_dict = json.load(fp) 83 | 84 | def get_pos_dict(self): 85 | pos_dict = {} 86 | for t in self.trees: 87 | for _, x in t.pos(): 88 | pos_dict[x] = pos_dict.get(x, 1 + len(pos_dict)) 89 | return pos_dict 90 | 91 | def __len__(self): 92 | return len(self.trees) 93 | 94 | def __getitem__(self, index): 95 | tree = self.trees[index] 96 | words = ptb_unescape(tree.leaves()) 97 | pos_tags = [self.pos_dict.get(x[1], 0) for x in tree.pos()] 98 | 99 | if 'albert' in str(type(self.tokenizer)): 100 | # albert is case insensitive! 101 | words = [w.lower() for w in words] 102 | 103 | encoded = self.tokenizer.encode_plus(' '.join(words)) 104 | word_end_positions = [ 105 | encoded.char_to_token(i) 106 | for i in np.cumsum([len(word) + 1 for word in words]) - 2] 107 | word_start_positions = [ 108 | encoded.char_to_token(i) 109 | for i in np.cumsum([0] + [len(word) + 1 for word in words])[:-1]] 110 | 111 | input_ids = torch.tensor(encoded['input_ids'], dtype=torch.long) 112 | pair_ids = torch.zeros_like(input_ids) 113 | end_of_word = torch.zeros_like(input_ids) 114 | pos_ids = torch.zeros_like(input_ids) 115 | 116 | tag_ids = self.tag_system.tree_to_ids_pipeline(tree) 117 | 118 | # Pack both leaf and internal tag ids into a single "label" field. 119 | # (The huggingface API isn't flexible enough to use multiple label fields) 120 | tag_ids = [tag_id + 1 for tag_id in tag_ids] + [0] 121 | tag_ids = torch.tensor(tag_ids, dtype=torch.long) 122 | labels = torch.zeros_like(input_ids) 123 | 124 | odd_labels = tag_ids[1::2] 125 | if self.is_tetratags: 126 | even_labels = tag_ids[ 127 | ::2] - self.tag_system.decode_moderator.internal_tag_vocab_size 128 | labels[word_end_positions] = ( 129 | odd_labels * ( 130 | self.tag_system.decode_moderator.leaf_tag_vocab_size + 1) + even_labels) 131 | pos_ids[word_end_positions] = torch.tensor(pos_tags, dtype=torch.long) 132 | pair_ids[word_end_positions] = torch.as_tensor(word_start_positions) 133 | 134 | end_of_word[word_end_positions] = 1 135 | end_of_word[word_end_positions[-1]] = 2 # last word 136 | else: 137 | even_labels = tag_ids[::2] 138 | labels[word_end_positions] = ( 139 | odd_labels * (len(self.tag_system.tag_vocab) + 1) + even_labels) 140 | 141 | if self.pad_to_len is not None: 142 | pad_amount = self.pad_to_len - input_ids.shape[0] 143 | if pad_amount >= 0: 144 | input_ids = F.pad(input_ids, [0, pad_amount], value=self.pad_token_id) 145 | pos_ids = F.pad(pos_ids, [0, pad_amount], value=0) 146 | labels = F.pad(labels, [0, pad_amount], value=0) 147 | 148 | return { 149 | 'input_ids': input_ids, 150 | 'pos_ids': pos_ids, 151 | 'pair_ids': pair_ids, 152 | 'end_of_word': end_of_word, 153 | 'labels': labels 154 | } 155 | 156 | def collate(self, batch): 157 | # for GPT-2, self.pad_token_id is None 158 | pad_token_id = self.pad_token_id if self.pad_token_id is not None else -100 159 | input_ids = pad_sequence( 160 | [item['input_ids'] for item in batch], 161 | batch_first=True, padding_value=pad_token_id) 162 | 163 | attention_mask = (input_ids != pad_token_id).float() 164 | # for GPT-2, change -100 back into 0 165 | input_ids = torch.where( 166 | input_ids 
== -100, 167 | 0, 168 | input_ids 169 | ) 170 | 171 | pair_ids = pad_sequence( 172 | [item['pair_ids'] for item in batch], 173 | batch_first=True, padding_value=pad_token_id) 174 | end_of_word = pad_sequence( 175 | [item['end_of_word'] for item in batch], 176 | batch_first=True, padding_value=0) 177 | pos_ids = pad_sequence( 178 | [item['pos_ids'] for item in batch], 179 | batch_first=True, padding_value=0) 180 | 181 | labels = pad_sequence( 182 | [item['labels'] for item in batch], 183 | batch_first=True, padding_value=0) 184 | 185 | return { 186 | 'input_ids': input_ids, 187 | 'pos_ids': pos_ids, 188 | 'pair_ids': pair_ids, 189 | 'end_of_word': end_of_word, 190 | 'attention_mask': attention_mask, 191 | 'labels': labels, 192 | } 193 | -------------------------------------------------------------------------------- /learning/decode.py: -------------------------------------------------------------------------------- 1 | ## code adopted from: https://github.com/nikitakit/tetra-tagging 2 | 3 | import numpy as np 4 | 5 | 6 | class Beam: 7 | def __init__(self, scores, stack_depths, prev, backptrs, labels): 8 | self.scores = scores 9 | self.stack_depths = stack_depths 10 | self.prev = prev 11 | self.backptrs = backptrs 12 | self.labels = labels 13 | 14 | 15 | class BeamSearch: 16 | def __init__( 17 | self, 18 | tag_moderator, 19 | initial_stack_depth, 20 | max_depth=5, 21 | min_depth=1, 22 | keep_per_depth=1, 23 | crf_transitions=None, 24 | initial_label=None, 25 | ): 26 | # Save parameters 27 | self.tag_moderator = tag_moderator 28 | self.valid_depths = np.arange(min_depth, max_depth) 29 | self.keep_per_depth = keep_per_depth 30 | self.max_depth = max_depth 31 | self.min_depth = min_depth 32 | self.crf_transitions = crf_transitions 33 | 34 | # Initialize the beam 35 | scores = np.zeros(1, dtype=float) 36 | stack_depths = np.full(1, initial_stack_depth) 37 | prev = backptrs = labels = None 38 | if initial_label is not None: 39 | labels = np.full(1, initial_label) 40 | self.beam = Beam(scores, stack_depths, prev, backptrs, labels) 41 | 42 | def compute_new_scores(self, label_log_probs, is_last): 43 | if self.crf_transitions is None: 44 | return self.beam.scores[:, None] + label_log_probs 45 | else: 46 | if self.beam.labels is not None: 47 | all_new_scores = self.beam.scores[:, None] + label_log_probs + \ 48 | self.crf_transitions["transitions"][self.beam.labels] 49 | else: 50 | all_new_scores = self.beam.scores[:, None] + label_log_probs + \ 51 | self.crf_transitions["start_transitions"] 52 | if is_last: 53 | all_new_scores += self.crf_transitions["end_transitions"] 54 | return all_new_scores 55 | 56 | # This extra mask layer takes care of invalid reduce actions when there is not an empty 57 | # slot in the tree, which is needed in the top-down shift reduce tagging schema 58 | def extra_mask_layer(self, all_new_scores, all_new_stack_depths): 59 | depth_mask = np.zeros(all_new_stack_depths.shape) 60 | depth_mask[all_new_stack_depths < 0] = -np.inf 61 | depth_mask[all_new_stack_depths > self.max_depth] = -np.inf 62 | all_new_scores = all_new_scores + depth_mask 63 | 64 | all_new_stack_depths = ( 65 | all_new_stack_depths 66 | + self.tag_moderator.stack_depth_change_by_id 67 | ) 68 | return all_new_scores, all_new_stack_depths 69 | 70 | def advance(self, label_logits, is_last=False): 71 | label_log_probs = label_logits 72 | 73 | all_new_scores = self.compute_new_scores(label_log_probs, is_last) 74 | if self.tag_moderator.mask_binarize and self.beam.labels is not None: 75 | labels = self.beam.labels 
76 | all_new_scores = self.tag_moderator.mask_scores_for_binarization(labels, 77 | all_new_scores) 78 | 79 | if self.tag_moderator.stack_depth_change_by_id_l2 is not None: 80 | all_new_stack_depths = ( 81 | self.beam.stack_depths[:, None] 82 | + self.tag_moderator.stack_depth_change_by_id_l2[None, :] 83 | ) 84 | all_new_scores, all_new_stack_depths = self.extra_mask_layer(all_new_scores, 85 | all_new_stack_depths) 86 | else: 87 | all_new_stack_depths = ( 88 | self.beam.stack_depths[:, None] 89 | + self.tag_moderator.stack_depth_change_by_id[None, :] 90 | ) 91 | 92 | masked_scores = all_new_scores[None, :, :] + np.where( 93 | all_new_stack_depths[None, :, :] == self.valid_depths[:, None, None], 94 | 0.0, -np.inf, 95 | ) 96 | masked_scores = masked_scores.reshape(self.valid_depths.shape[0], -1) 97 | idxs = np.argsort(-masked_scores)[:, : self.keep_per_depth].flatten() 98 | backptrs, labels = np.unravel_index(idxs, all_new_scores.shape) 99 | 100 | transition_valid = all_new_stack_depths[ 101 | backptrs, labels 102 | ] == self.valid_depths.repeat(self.keep_per_depth) 103 | 104 | backptrs = backptrs[transition_valid] 105 | labels = labels[transition_valid] 106 | 107 | self.beam = Beam( 108 | all_new_scores[backptrs, labels], 109 | all_new_stack_depths[backptrs, labels], 110 | self.beam, 111 | backptrs, 112 | labels, 113 | ) 114 | 115 | def get_path(self, idx=0, required_stack_depth=1): 116 | if required_stack_depth is not None: 117 | assert self.beam.stack_depths[idx] == required_stack_depth 118 | score = self.beam.scores[idx] 119 | assert score > -np.inf 120 | 121 | beam = self.beam 122 | label_idxs = [] 123 | while beam.prev is not None: 124 | label_idxs.insert(0, beam.labels[idx]) 125 | idx = beam.backptrs[idx] 126 | beam = beam.prev 127 | 128 | return score, label_idxs 129 | 130 | 131 | class GreedySearch(BeamSearch): 132 | def advance(self, label_logits, is_last=False): 133 | label_log_probs = label_logits 134 | 135 | all_new_scores = self.compute_new_scores(label_log_probs, is_last) 136 | if self.tag_moderator.mask_binarize and self.beam.labels is not None: 137 | labels = self.beam.labels 138 | all_new_scores = self.tag_moderator.mask_scores_for_binarization(labels, 139 | all_new_scores) 140 | 141 | if self.tag_moderator.stack_depth_change_by_id_l2 is not None: 142 | all_new_stack_depths = ( 143 | self.beam.stack_depths[:, None] 144 | + self.tag_moderator.stack_depth_change_by_id_l2[None, :] 145 | ) 146 | 147 | all_new_scores, all_new_stack_depths = self.extra_mask_layer(all_new_scores, 148 | all_new_stack_depths) 149 | else: 150 | all_new_stack_depths = ( 151 | self.beam.stack_depths[:, None] 152 | + self.tag_moderator.stack_depth_change_by_id[None, :] 153 | ) 154 | 155 | masked_scores = all_new_scores + np.where((all_new_stack_depths >= self.min_depth) 156 | & (all_new_stack_depths <= self.max_depth), 157 | 0.0, 158 | -np.inf) 159 | 160 | masked_scores = masked_scores.reshape(-1) 161 | idxs = np.argsort(-masked_scores)[:self.keep_per_depth].flatten() 162 | 163 | backptrs, labels = np.unravel_index(idxs, all_new_scores.shape) 164 | 165 | transition_valid = (all_new_stack_depths[ 166 | backptrs, labels 167 | ] >= self.min_depth) & (all_new_stack_depths[ 168 | backptrs, labels 169 | ] <= self.max_depth) 170 | 171 | backptrs = backptrs[transition_valid] 172 | labels = labels[transition_valid] 173 | 174 | self.beam = Beam( 175 | all_new_scores[backptrs, labels], 176 | all_new_stack_depths[backptrs, labels], 177 | self.beam, 178 | backptrs, 179 | labels, 180 | ) 181 | 
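To make the search interface above concrete, here is a minimal, self-contained sketch that drives `BeamSearch` with a stub tag moderator. The stub, the four-tag toy inventory, and its stack-depth changes are invented for illustration only; in the real pipeline the moderator comes from the tag system in `tagging/` and the per-step scores come from the trained model:

```python
import numpy as np
from types import SimpleNamespace

from learning.decode import BeamSearch

# Stub moderator: 4 toy tags; tags 0-1 push the stack (+1), tags 2-3 pop it (-1).
# Only the attributes that BeamSearch reads are provided.
moderator = SimpleNamespace(
    stack_depth_change_by_id=np.array([+1, +1, -1, -1]),
    stack_depth_change_by_id_l2=None,  # no second-level depth changes in this toy setup
    mask_binarize=False,
)

searcher = BeamSearch(
    moderator,
    initial_stack_depth=0,  # empty stack before the first tag
    max_depth=5,            # mirrors --max-depth in run.py
    keep_per_depth=1,       # mirrors --keep-per-depth in run.py
)

# Three decoding steps with made-up log-probabilities over the 4 toy tags.
step_logits = [
    np.log([0.7, 0.1, 0.1, 0.1]),
    np.log([0.1, 0.6, 0.2, 0.1]),
    np.log([0.1, 0.1, 0.7, 0.1]),
]
for i, logits in enumerate(step_logits):
    searcher.advance(logits, is_last=(i == len(step_logits) - 1))

# Hypothesis 0 ends with stack depth 1 in this toy run; get_path follows the
# back-pointers and returns the accumulated score and the chosen tag sequence.
score, tag_ids = searcher.get_path(idx=0, required_stack_depth=1)
print(score, tag_ids)  # here: log(0.7 * 0.6 * 0.7) and the tag sequence [0, 1, 2]
```

At every step the beam keeps only the top `keep_per_depth` hypotheses per stack depth between `min_depth` and `max_depth`, which is exactly the pruning that the `--max-depth` and `--keep-per-depth` options expose.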
-------------------------------------------------------------------------------- /learning/evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os.path 4 | import re 5 | import subprocess 6 | import tempfile 7 | from copy import deepcopy 8 | from typing import Tuple 9 | import tempfile 10 | 11 | import numpy as np 12 | import torch 13 | from sklearn.metrics import precision_recall_fscore_support 14 | from tqdm import tqdm as tq 15 | 16 | from tagging.tree_tools import create_dummy_tree 17 | 18 | repo_directory = os.path.abspath(__file__) 19 | 20 | class ParseMetrics(object): 21 | def __init__(self, recall, precision, fscore, complete_match, tagging_accuracy=100): 22 | self.recall = recall 23 | self.precision = precision 24 | self.fscore = fscore 25 | self.complete_match = complete_match 26 | self.tagging_accuracy = tagging_accuracy 27 | 28 | def __str__(self): 29 | if self.tagging_accuracy < 100: 30 | return "(Recall={:.4f}, Precision={:.4f}, ParseMetrics={:.4f}, CompleteMatch={:.4f}, TaggingAccuracy={:.4f})".format( 31 | self.recall, self.precision, self.fscore, self.complete_match, 32 | self.tagging_accuracy) 33 | else: 34 | return "(Recall={:.4f}, Precision={:.4f}, ParseMetrics={:.4f}, CompleteMatch={:.4f})".format( 35 | self.recall, self.precision, self.fscore, self.complete_match) 36 | 37 | 38 | def report_eval_loss(model, eval_dataloader, device, n_iter, writer) -> np.ndarray: 39 | loss = [] 40 | for batch in eval_dataloader: 41 | batch = {k: v.to(device) for k, v in batch.items()} 42 | with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): 43 | outputs = model(**batch) 44 | loss.append(torch.mean(outputs[0]).cpu()) 45 | 46 | mean_loss = np.mean(loss) 47 | logging.info("Eval Loss: {}".format(mean_loss)) 48 | if writer is not None: 49 | writer.add_scalar('eval_loss', mean_loss, n_iter) 50 | return mean_loss 51 | 52 | 53 | def predict( 54 | model, eval_dataloader, dataset_size, num_tags, batch_size, device 55 | ) -> Tuple[np.array, np.array]: 56 | model.eval() 57 | predictions = [] 58 | eval_labels = [] 59 | max_len = 0 60 | idx = 0 61 | 62 | for batch in tq(eval_dataloader, disable=True): 63 | batch = {k: v.to(device) for k, v in batch.items()} 64 | 65 | with torch.no_grad(), torch.cuda.amp.autocast( 66 | enabled=True, dtype=torch.bfloat16 67 | ): 68 | outputs = model(**batch) 69 | 70 | logits = outputs[1].float().cpu().numpy() 71 | max_len = max(max_len, logits.shape[1]) 72 | predictions.append(logits) 73 | labels = batch['labels'].int().cpu().numpy() 74 | eval_labels.append(labels) 75 | idx += 1 76 | 77 | predictions = np.concatenate([np.pad(logits, 78 | ((0, 0), (0, max_len - logits.shape[1]), (0, 0)), 79 | 'constant', constant_values=0) for logits in 80 | predictions], axis=0) 81 | eval_labels = np.concatenate([np.pad(labels, ((0, 0), (0, max_len - labels.shape[1])), 82 | 'constant', constant_values=0) for labels in 83 | eval_labels], axis=0) 84 | 85 | return predictions, eval_labels 86 | 87 | 88 | def calc_tag_accuracy( 89 | predictions, eval_labels, num_leaf_labels, writer, use_tensorboard 90 | ) -> Tuple[float, float]: 91 | even_predictions = predictions[..., -num_leaf_labels:] 92 | odd_predictions = predictions[..., :-num_leaf_labels] 93 | even_labels = eval_labels % (num_leaf_labels + 1) - 1 94 | odd_labels = eval_labels // (num_leaf_labels + 1) - 1 95 | 96 | odd_predictions = odd_predictions[odd_labels != -1].argmax(-1) 97 | even_predictions = 
even_predictions[even_labels != -1].argmax(-1) 98 | 99 | odd_labels = odd_labels[odd_labels != -1] 100 | even_labels = even_labels[even_labels != -1] 101 | 102 | odd_acc = (odd_predictions == odd_labels).mean() 103 | even_acc = (even_predictions == even_labels).mean() 104 | 105 | logging.info('odd_tags_accuracy: {}'.format(odd_acc)) 106 | logging.info('even_tags_accuracy: {}'.format(even_acc)) 107 | 108 | if use_tensorboard: 109 | writer.add_pr_curve('odd_tags_pr_curve', odd_labels, odd_predictions, 0) 110 | writer.add_pr_curve('even_tags_pr_curve', even_labels, even_predictions, 1) 111 | return even_acc, odd_acc 112 | 113 | 114 | def get_dependency_from_lexicalized_tree(lex_tree, triple_dict, offset=0): 115 | # this recursion assumes projectivity 116 | # Input: 117 | # root of lex-tree 118 | # Output: 119 | # the global index of the dependency root 120 | if type(lex_tree) not in {str, dict} and len(lex_tree) == 1: 121 | # unary rule 122 | # returning the global index of the head 123 | return offset 124 | 125 | head_branch_index = int(lex_tree.label().split("^^^")[1]) 126 | head_global_index = None 127 | branch_to_global_dict = {} 128 | 129 | for branch_id_child, child in enumerate(lex_tree): 130 | global_id_child = get_dependency_from_lexicalized_tree( 131 | child, triple_dict, offset=offset 132 | ) 133 | offset = offset + len(child.leaves()) 134 | branch_to_global_dict[branch_id_child] = global_id_child 135 | if branch_id_child == head_branch_index: 136 | head_global_index = global_id_child 137 | 138 | for branch_id_child, child in enumerate(lex_tree): 139 | if branch_id_child != head_branch_index: 140 | triple_dict[branch_to_global_dict[branch_id_child]] = head_global_index 141 | 142 | return head_global_index 143 | 144 | 145 | def is_punctuation(pos): 146 | punct_set = '.' '``' "''" ':' ',' 147 | return (pos in punct_set) or (pos.lower() in ['pu', 'punct']) # for CTB & UD 148 | 149 | 150 | def tree_to_dep_triples(lex_tree): 151 | triple_dict = {} 152 | dep_triples = [] 153 | sent_root = get_dependency_from_lexicalized_tree( 154 | lex_tree, triple_dict 155 | ) 156 | # the root of the whole sentence should refer to ROOT 157 | assert sent_root not in triple_dict 158 | # the root of the sentence 159 | triple_dict[sent_root] = -1 160 | for head, tail in sorted(triple_dict.items()): 161 | dep_triples.append(( 162 | head, tail, 163 | lex_tree.pos()[head][1].split("^^^")[1].split("+")[0], 164 | lex_tree.pos()[head][1].split("^^^")[1].split("+")[1] 165 | )) 166 | return dep_triples 167 | 168 | 169 | 170 | def dependency_eval( 171 | predictions, eval_labels, eval_dataset, tag_system, output_path, 172 | model_name, max_depth, keep_per_depth, is_greedy 173 | ) -> ParseMetrics: 174 | ud_flag = eval_dataset.language not in {'English', 'Chinese'} 175 | 176 | # This can be parallelized! 
177 | predicted_dev_triples, predicted_dev_triples_unlabeled = [], [] 178 | gold_dev_triples, gold_dev_triples_unlabeled = [], [] 179 | c_err = 0 180 | 181 | gt_triple_data, pred_triple_data = [], [] 182 | 183 | for i in tq(range(predictions.shape[0]), disable=True): 184 | logits = predictions[i] 185 | is_word = (eval_labels[i] != 0) 186 | 187 | original_tree = deepcopy(eval_dataset.trees[i]) 188 | original_tree.collapse_unary(collapsePOS=True, collapseRoot=True) 189 | 190 | try: # ignore the ones that failed in unchomsky_normal_form 191 | tree = tag_system.logits_to_tree( 192 | logits, original_tree.pos(), 193 | mask=is_word, 194 | max_depth=max_depth, 195 | keep_per_depth=keep_per_depth, 196 | is_greedy=is_greedy 197 | ) 198 | tree.collapse_unary(collapsePOS=True, collapseRoot=True) 199 | except Exception as ex: 200 | template = "An exception of type {0} occurred. Arguments:\n{1!r}" 201 | message = template.format(type(ex).__name__, ex.args) 202 | c_err += 1 203 | predicted_dev_triples.append(create_dummy_tree(original_tree.pos())) 204 | continue 205 | if tree.leaves() != original_tree.leaves(): 206 | c_err += 1 207 | predicted_dev_triples.append(create_dummy_tree(original_tree.pos())) 208 | continue 209 | 210 | gt_triples = tree_to_dep_triples(original_tree) 211 | pred_triples = tree_to_dep_triples(tree) 212 | 213 | gt_triple_data.append(gt_triples) 214 | pred_triple_data.append(pred_triples) 215 | 216 | assert len(gt_triples) == len( 217 | pred_triples), f"wrong length {len(gt_triples)} vs. {len(pred_triples)}!" 218 | 219 | for x, y in zip(sorted(gt_triples), sorted(pred_triples)): 220 | if is_punctuation(x[3]) and not ud_flag: 221 | # ignoring punctuations for evaluation 222 | continue 223 | assert x[0] == y[0], f"wrong tree {gt_triples} vs. {pred_triples}!" 224 | 225 | gold_dev_triples.append(f"{x[0]}-{x[1]}-{x[2].split(':')[0]}") 226 | gold_dev_triples_unlabeled.append(f"{x[0]}-{x[1]}") 227 | 228 | predicted_dev_triples.append(f"{y[0]}-{y[1]}-{y[2].split(':')[0]}") 229 | predicted_dev_triples_unlabeled.append(f"{y[0]}-{y[1]}") 230 | 231 | if ud_flag: 232 | # UD 233 | predicted_dev_triples, predicted_dev_triples_unlabeled = [], [] 234 | gold_dev_triples, gold_dev_triples_unlabeled = [], [] 235 | 236 | language, split = eval_dataset.language, eval_dataset.split.split(".")[-1] 237 | 238 | gold_temp_out, pred_temp_out = tempfile.mktemp(dir=os.path.dirname(repo_directory)), \ 239 | tempfile.mktemp(dir=os.path.dirname(repo_directory)) 240 | gold_temp_in, pred_temp_in = gold_temp_out + ".deproj", pred_temp_out + ".deproj" 241 | 242 | 243 | save_triplets(gt_triple_data, gold_temp_out) 244 | save_triplets(pred_triple_data, pred_temp_out) 245 | 246 | for filename, tgt_filename in zip([gold_temp_out, pred_temp_out], [gold_temp_in, pred_temp_in]): 247 | command = f"cd ./malt/maltparser-1.9.2/; java -jar maltparser-1.9.2.jar -c {language}_{split} -m deproj" \ 248 | f" -i {filename} -o {tgt_filename} ; cd ../../" 249 | os.system(command) 250 | 251 | loaded_gold_dev_triples = load_triplets(gold_temp_in) 252 | loaded_pred_dev_triples = load_triplets(pred_temp_in) 253 | 254 | for gt_triples, pred_triples in zip(loaded_gold_dev_triples, loaded_pred_dev_triples): 255 | for x, y in zip(sorted(gt_triples), sorted(pred_triples)): 256 | if is_punctuation(x[3]): 257 | # ignoring punctuations for evaluation 258 | continue 259 | assert x[0] == y[0], f"wrong tree {gt_triples} vs. {pred_triples}!" 
260 | 261 | gold_dev_triples.append(f"{x[0]}-{x[1]}-{x[2].split(':')[0]}") 262 | gold_dev_triples_unlabeled.append(f"{x[0]}-{x[1]}") 263 | 264 | predicted_dev_triples.append(f"{y[0]}-{y[1]}-{y[2].split(':')[0]}") 265 | predicted_dev_triples_unlabeled.append(f"{y[0]}-{y[1]}") 266 | 267 | 268 | logging.warning("Number of binarization error: {}\n".format(c_err)) 269 | las_recall, las_precision, las_fscore, _ = precision_recall_fscore_support( 270 | gold_dev_triples, predicted_dev_triples, average='micro' 271 | ) 272 | uas_recall, uas_precision, uas_fscore, _ = precision_recall_fscore_support( 273 | gold_dev_triples_unlabeled, predicted_dev_triples_unlabeled, average='micro' 274 | ) 275 | 276 | return (ParseMetrics(las_recall, las_precision, las_fscore, complete_match=1), 277 | ParseMetrics(uas_recall, uas_precision, uas_fscore, complete_match=1)) 278 | 279 | 280 | def save_triplets(triplet_data, file_path): 281 | # save triplets to file in conll format 282 | with open(file_path, 'w') as f: 283 | for triplets in triplet_data: 284 | for triplet in triplets: 285 | # 8 Витоша витоша PROPN Npfsi Definite=Ind|Gender=Fem|Number=Sing 6 nmod _ _ 286 | head, tail, label, pos = triplet 287 | f.write(f"{head+1}\t-\t-\t{pos}\t-\t-\t{tail+1}\t{label}\t_\t_\n") 288 | f.write('\n') 289 | 290 | return 291 | 292 | 293 | def load_triplets(file_path): 294 | # load triplets from file in conll format 295 | triplet_data = [] 296 | with open(file_path, 'r') as f: 297 | triplets = [] 298 | for line in f.readlines(): 299 | if line.startswith('#') or line == '\n': 300 | if triplets: 301 | triplet_data.append(triplets) 302 | triplets = [] 303 | continue 304 | line_list = line.strip().split('\t') 305 | head, tail, label, pos = line_list[0], line_list[6], line_list[7], line_list[3] 306 | triplets.append((head, tail, label, pos)) 307 | if triplets: 308 | triplet_data.append(triplets) 309 | return triplet_data 310 | 311 | 312 | def calc_parse_eval(predictions, eval_labels, eval_dataset, tag_system, output_path, 313 | model_name, max_depth, keep_per_depth, is_greedy) -> ParseMetrics: 314 | predicted_dev_trees = [] 315 | gold_dev_trees = [] 316 | c_err = 0 317 | for i in tq(range(predictions.shape[0])): 318 | logits = predictions[i] 319 | is_word = eval_labels[i] != 0 320 | original_tree = eval_dataset.trees[i] 321 | gold_dev_trees.append(original_tree) 322 | try: # ignore the ones that failed in unchomsky_normal_form 323 | tree = tag_system.logits_to_tree(logits, original_tree.pos(), mask=is_word, 324 | max_depth=max_depth, 325 | keep_per_depth=keep_per_depth, 326 | is_greedy=is_greedy) 327 | except Exception as ex: 328 | template = "An exception of type {0} occurred. 
Arguments:\n{1!r}" 329 | message = template.format(type(ex).__name__, ex.args) 330 | # print(message) 331 | c_err += 1 332 | predicted_dev_trees.append(create_dummy_tree(original_tree.pos())) 333 | continue 334 | if tree.leaves() != original_tree.leaves(): 335 | c_err += 1 336 | predicted_dev_trees.append(create_dummy_tree(original_tree.pos())) 337 | continue 338 | predicted_dev_trees.append(tree) 339 | 340 | logging.warning("Number of binarization error: {}".format(c_err)) 341 | 342 | return evalb("EVALB_SPMRL/", gold_dev_trees, predicted_dev_trees) 343 | 344 | 345 | def save_predictions(predicted_trees, file_path): 346 | with open(file_path, 'w') as f: 347 | for tree in predicted_trees: 348 | f.write(' '.join(str(tree).split()) + '\n') 349 | 350 | 351 | def evalb(evalb_dir, gold_trees, predicted_trees, ref_gold_path=None) -> ParseMetrics: 352 | # Code from: https://github.com/nikitakit/self-attentive-parser/blob/master/src/evaluate.py 353 | assert os.path.exists(evalb_dir) 354 | evalb_program_path = os.path.join(evalb_dir, "evalb") 355 | evalb_spmrl_program_path = os.path.join(evalb_dir, "evalb_spmrl") 356 | assert os.path.exists(evalb_program_path) or os.path.exists(evalb_spmrl_program_path) 357 | 358 | if os.path.exists(evalb_program_path): 359 | evalb_param_path = os.path.join(evalb_dir, "nk.prm") 360 | else: 361 | evalb_program_path = evalb_spmrl_program_path 362 | evalb_param_path = os.path.join(evalb_dir, "spmrl.prm") 363 | 364 | assert os.path.exists(evalb_program_path) 365 | assert os.path.exists(evalb_param_path) 366 | 367 | assert len(gold_trees) == len(predicted_trees) 368 | for gold_tree, predicted_tree in zip(gold_trees, predicted_trees): 369 | gold_leaves = list(gold_tree.leaves()) 370 | predicted_leaves = list(predicted_tree.leaves()) 371 | 372 | temp_dir = tempfile.TemporaryDirectory(prefix="evalb-") 373 | gold_path = os.path.join(temp_dir.name, "gold.txt") 374 | predicted_path = os.path.join(temp_dir.name, "predicted.txt") 375 | output_path = os.path.join(temp_dir.name, "output.txt") 376 | 377 | with open(gold_path, "w") as outfile: 378 | if ref_gold_path is None: 379 | for tree in gold_trees: 380 | outfile.write(' '.join(str(tree).split()) + '\n') 381 | else: 382 | # For the SPMRL dataset our data loader performs some modifications 383 | # (like stripping morphological features), so we compare to the 384 | # raw gold file to be certain that we haven't spoiled the evaluation 385 | # in some way. 
386 | with open(ref_gold_path) as goldfile: 387 | outfile.write(goldfile.read()) 388 | 389 | with open(predicted_path, "w") as outfile: 390 | for tree in predicted_trees: 391 | outfile.write(' '.join(str(tree).split()) + '\n') 392 | 393 | command = "{} -p {} {} {} > {}".format( 394 | evalb_program_path, 395 | evalb_param_path, 396 | gold_path, 397 | predicted_path, 398 | output_path, 399 | ) 400 | subprocess.run(command, shell=True) 401 | 402 | fscore = ParseMetrics(math.nan, math.nan, math.nan, math.nan) 403 | with open(output_path) as infile: 404 | for line in infile: 405 | match = re.match(r"Bracketing Recall\s+=\s+(\d+\.\d+)", line) 406 | if match: 407 | fscore.recall = float(match.group(1)) 408 | match = re.match(r"Bracketing Precision\s+=\s+(\d+\.\d+)", line) 409 | if match: 410 | fscore.precision = float(match.group(1)) 411 | match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line) 412 | if match: 413 | fscore.fscore = float(match.group(1)) 414 | match = re.match(r"Complete match\s+=\s+(\d+\.\d+)", line) 415 | if match: 416 | fscore.complete_match = float(match.group(1)) 417 | match = re.match(r"Tagging accuracy\s+=\s+(\d+\.\d+)", line) 418 | if match: 419 | fscore.tagging_accuracy = float(match.group(1)) 420 | break 421 | 422 | success = ( 423 | not math.isnan(fscore.fscore) or 424 | fscore.recall == 0.0 or 425 | fscore.precision == 0.0) 426 | 427 | if success: 428 | temp_dir.cleanup() 429 | else: 430 | print("Error reading EVALB results.") 431 | print("Gold path: {}".format(gold_path)) 432 | print("Predicted path: {}".format(predicted_path)) 433 | print("Output path: {}".format(output_path)) 434 | 435 | return fscore 436 | 437 | 438 | 439 | 440 | def dependency_decoding( 441 | predictions, eval_labels, eval_dataset, tag_system, output_path, 442 | model_name, max_depth, keep_per_depth, is_greedy 443 | ) -> ParseMetrics: 444 | ud_flag = eval_dataset.language not in {'English', 'Chinese'} 445 | 446 | # This can be parallelized! 447 | predicted_dev_triples, predicted_dev_triples_unlabeled = [], [] 448 | gold_dev_triples, gold_dev_triples_unlabeled = [], [] 449 | pred_hexa_tags = [] 450 | c_err = 0 451 | 452 | gt_triple_data, pred_triple_data = [], [] 453 | 454 | for i in tq(range(predictions.shape[0]), disable=True): 455 | logits = predictions[i] 456 | is_word = (eval_labels[i] != 0) 457 | 458 | original_tree = deepcopy(eval_dataset.trees[i]) 459 | original_tree.collapse_unary(collapsePOS=True, collapseRoot=True) 460 | 461 | try: # ignore the ones that failed in unchomsky_normal_form 462 | tree = tag_system.logits_to_tree( 463 | logits, original_tree.pos(), 464 | mask=is_word, 465 | max_depth=max_depth, 466 | keep_per_depth=keep_per_depth, 467 | is_greedy=is_greedy 468 | ) 469 | hexa_ids = tag_system.logits_to_ids( 470 | logits, is_word, max_depth, keep_per_depth, is_greedy=is_greedy 471 | ) 472 | pred_hexa_tags.append(hexa_ids) 473 | 474 | tree.collapse_unary(collapsePOS=True, collapseRoot=True) 475 | except Exception as ex: 476 | template = "An exception of type {0} occurred. 
Arguments:\n{1!r}" 477 | message = template.format(type(ex).__name__, ex.args) 478 | c_err += 1 479 | predicted_dev_triples.append(create_dummy_tree(original_tree.pos())) 480 | continue 481 | if tree.leaves() != original_tree.leaves(): 482 | c_err += 1 483 | predicted_dev_triples.append(create_dummy_tree(original_tree.pos())) 484 | continue 485 | 486 | gt_triples = tree_to_dep_triples(original_tree) 487 | pred_triples = tree_to_dep_triples(tree) 488 | 489 | gt_triple_data.append(gt_triples) 490 | pred_triple_data.append(pred_triples) 491 | 492 | assert len(gt_triples) == len( 493 | pred_triples), f"wrong length {len(gt_triples)} vs. {len(pred_triples)}!" 494 | 495 | for x, y in zip(sorted(gt_triples), sorted(pred_triples)): 496 | if is_punctuation(x[3]) and not ud_flag: 497 | # ignoring punctuations for evaluation 498 | continue 499 | assert x[0] == y[0], f"wrong tree {gt_triples} vs. {pred_triples}!" 500 | 501 | gold_dev_triples.append(f"{x[0]}-{x[1]}-{x[2].split(':')[0]}") 502 | gold_dev_triples_unlabeled.append(f"{x[0]}-{x[1]}") 503 | 504 | predicted_dev_triples.append(f"{y[0]}-{y[1]}-{y[2].split(':')[0]}") 505 | predicted_dev_triples_unlabeled.append(f"{y[0]}-{y[1]}") 506 | 507 | if ud_flag: 508 | # UD 509 | predicted_dev_triples, predicted_dev_triples_unlabeled = [], [] 510 | gold_dev_triples, gold_dev_triples_unlabeled = [], [] 511 | 512 | language, split = eval_dataset.language, eval_dataset.split.split(".")[-1] 513 | 514 | gold_temp_out, pred_temp_out = tempfile.mktemp(dir=os.path.dirname(repo_directory)), \ 515 | tempfile.mktemp(dir=os.path.dirname(repo_directory)) 516 | gold_temp_in, pred_temp_in = gold_temp_out + ".deproj", pred_temp_out + ".deproj" 517 | 518 | 519 | save_triplets(gt_triple_data, gold_temp_out) 520 | save_triplets(pred_triple_data, pred_temp_out) 521 | 522 | for filename, tgt_filename in zip([gold_temp_out, pred_temp_out], [gold_temp_in, pred_temp_in]): 523 | command = f"cd ./malt/maltparser-1.9.2/; java -jar maltparser-1.9.2.jar -c {language}_{split} -m deproj" \ 524 | f" -i {filename} -o {tgt_filename} ; cd ../../" 525 | os.system(command) 526 | 527 | loaded_gold_dev_triples = load_triplets(gold_temp_in) 528 | loaded_pred_dev_triples = load_triplets(pred_temp_in) 529 | 530 | for gt_triples, pred_triples in zip(loaded_gold_dev_triples, loaded_pred_dev_triples): 531 | for x, y in zip(sorted(gt_triples), sorted(pred_triples)): 532 | if is_punctuation(x[3]): 533 | # ignoring punctuations for evaluation 534 | continue 535 | assert x[0] == y[0], f"wrong tree {gt_triples} vs. {pred_triples}!" 
536 | 537 | gold_dev_triples.append(f"{x[0]}-{x[1]}-{x[2].split(':')[0]}") 538 | gold_dev_triples_unlabeled.append(f"{x[0]}-{x[1]}") 539 | 540 | predicted_dev_triples.append(f"{y[0]}-{y[1]}-{y[2].split(':')[0]}") 541 | predicted_dev_triples_unlabeled.append(f"{y[0]}-{y[1]}") 542 | 543 | return { 544 | "predicted_dev_triples": predicted_dev_triples, 545 | "predicted_hexa_tags": pred_hexa_tags 546 | } -------------------------------------------------------------------------------- /learning/learn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 5 | 6 | from transformers import AutoModelForTokenClassification, AutoModel 7 | 8 | 9 | def calc_loss_helper(logits, labels, attention_mask, num_even_tags, num_odd_tags): 10 | # shape: (batch_size, seq_len, num_tags) -> (batch_size, num_tags, seq_len) 11 | logits = torch.movedim(logits, -1, 1) 12 | odd_logits, even_logits = torch.split(logits, [num_odd_tags, num_even_tags], dim=1) 13 | odd_labels = (labels // (num_even_tags + 1)) - 1 14 | even_labels = (labels % (num_even_tags + 1)) - 1 15 | # The last word will have only even label 16 | 17 | # Only keep active parts of the loss 18 | active_even_labels = torch.where( 19 | attention_mask, even_labels, -1 20 | ) 21 | active_odd_labels = torch.where( 22 | attention_mask, odd_labels, -1 23 | ) 24 | loss = (F.cross_entropy(even_logits, active_even_labels, ignore_index=-1) 25 | + F.cross_entropy(odd_logits, active_odd_labels, ignore_index=-1)) 26 | 27 | return loss 28 | 29 | 30 | class ModelForTetratagging(nn.Module): 31 | def __init__(self, config): 32 | super().__init__() 33 | self.num_even_tags = config.task_specific_params['num_even_tags'] 34 | self.num_odd_tags = config.task_specific_params['num_odd_tags'] 35 | self.model_path = config.task_specific_params['model_path'] 36 | self.use_pos = config.task_specific_params.get('use_pos', False) 37 | self.num_pos_tags = config.task_specific_params.get('num_pos_tags', 50) 38 | 39 | self.pos_emb_dim = config.task_specific_params['pos_emb_dim'] 40 | self.dropout_rate = config.task_specific_params['dropout'] 41 | 42 | self.bert = AutoModel.from_pretrained(self.model_path, config=config) 43 | if self.use_pos: 44 | self.pos_encoder = nn.Sequential( 45 | nn.Embedding(self.num_pos_tags, self.pos_emb_dim, padding_idx=0) 46 | ) 47 | 48 | self.endofword_embedding = nn.Embedding(2, self.pos_emb_dim) 49 | self.lstm = nn.LSTM( 50 | 2 * config.hidden_size + self.pos_emb_dim * (1 + self.use_pos), 51 | config.hidden_size, 52 | config.task_specific_params['lstm_layers'], 53 | dropout=self.dropout_rate, 54 | batch_first=True, bidirectional=True 55 | ) 56 | self.projection = nn.Sequential( 57 | nn.Linear(2 * config.hidden_size, config.num_labels) 58 | ) 59 | 60 | def forward( 61 | self, 62 | input_ids=None, 63 | pair_ids=None, 64 | pos_ids=None, 65 | end_of_word=None, 66 | attention_mask=None, 67 | head_mask=None, 68 | inputs_embeds=None, 69 | labels=None, 70 | output_attentions=None, 71 | output_hidden_states=None, 72 | ): 73 | outputs = self.bert( 74 | input_ids, 75 | attention_mask=attention_mask, 76 | head_mask=head_mask, 77 | inputs_embeds=inputs_embeds, 78 | output_attentions=output_attentions, 79 | output_hidden_states=output_hidden_states, 80 | ) 81 | if self.use_pos: 82 | pos_encodings = self.pos_encoder(pos_ids) 83 | token_repr = torch.cat([outputs[0], pos_encodings], dim=-1) 84 | 
else: 85 | token_repr = outputs[0] 86 | 87 | start_repr = outputs[0].take_along_dim(pair_ids.unsqueeze(-1), dim=1) 88 | token_repr = torch.cat([token_repr, start_repr], dim=-1) 89 | token_repr = torch.cat([token_repr, self.endofword_embedding((pos_ids != 0).long())], 90 | dim=-1) 91 | 92 | lens = attention_mask.sum(dim=-1).cpu() 93 | token_repr = pack_padded_sequence(token_repr, lens, batch_first=True, 94 | enforce_sorted=False) 95 | token_repr = self.lstm(token_repr)[0] 96 | token_repr, _ = pad_packed_sequence(token_repr, batch_first=True) 97 | 98 | tag_logits = self.projection(token_repr) 99 | 100 | loss = None 101 | if labels is not None and self.training: 102 | loss = calc_loss_helper( 103 | tag_logits, labels, attention_mask.bool(), 104 | self.num_even_tags, self.num_odd_tags 105 | ) 106 | return loss, tag_logits 107 | else: 108 | return loss, tag_logits 109 | -------------------------------------------------------------------------------- /learning/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | from typing import Optional, Tuple, Any, Dict, Iterable, Union 6 | import math 7 | 8 | 9 | class TakeLSTMOutput(nn.Module): 10 | # Take the last hidden state from the output of the LSTM 11 | def forward(self, x): 12 | tensor, _ = x 13 | return tensor 14 | 15 | 16 | def hexa_loss(logits, labels, attention_mask, num_even_tags, num_odd_tags): 17 | # shape: (batch_size, seq_len, num_tags) -> (batch_size, num_tags, seq_len) 18 | logits = torch.movedim(logits, -1, 1) 19 | odd_logits, even_logits = torch.split(logits, [num_odd_tags, num_even_tags], dim=1) 20 | 21 | odd_labels = (labels // (num_even_tags + 1)) - 1 22 | even_labels = (labels % (num_even_tags + 1)) - 1 23 | # The last word will have only even label 24 | 25 | # Only keep active parts of the loss 26 | active_even_labels = torch.where(attention_mask, even_labels, -1) 27 | active_odd_labels = torch.where(attention_mask, odd_labels, -1) 28 | loss = (F.cross_entropy(even_logits, active_even_labels, ignore_index=-1) 29 | + F.cross_entropy(odd_logits, active_odd_labels, ignore_index=-1)) 30 | 31 | ground_truth_likelihood = 0. 32 | 33 | return loss, ground_truth_likelihood 34 | 35 | 36 | def calc_loss_helper(logits, labels, attention_mask): 37 | # Only keep active parts of the loss 38 | 39 | # logits: shape: (batch_size, *, num_tags) 40 | # active_logits: shape: (batch_size, num_tags, *) 41 | active_logits = torch.movedim(logits, -1, 1) 42 | if active_logits.dim() == 4: # for permute logits 43 | attention_mask = attention_mask.unsqueeze(2) 44 | 45 | # shape: (batch_size, seq_len, ...) 46 | active_labels = torch.where( 47 | (attention_mask == 1), labels, -1 48 | ) 49 | loss = F.cross_entropy(active_logits, active_labels, ignore_index=-1) 50 | ground_truth_likelihood = 0. 51 | 52 | return loss, ground_truth_likelihood 53 | 54 | 55 | def interval_sort( 56 | values, # shape: (batch_size, seq_len) 57 | gt_values, # shape: (batch_size, seq_len) 58 | mask, # shape: (batch_size, seq_len) 59 | intervals # shape: (batch_size, num_intervals, 4) 60 | ): 61 | # Args: 62 | # value: (batch_size, seq_len) 63 | # intervals: (batch_size, num_intervals, 4) 64 | # tuples of (start, end, split, self) intervals. end is inclusive. 
65 | # Returns: 66 | # interval_stats: (batch_size, num_intervals, 3) 67 | # tuples of (min, median, max) values in the interval 68 | 69 | # shape: (batch_size, seq_len) 70 | batch_size, seq_len = values.size() 71 | num_intervals = intervals.size(1) 72 | 73 | _, sorted_indices = torch.where( 74 | mask, gt_values, float('inf') # ascending 75 | ).sort(dim=1) 76 | sorted_values = values.gather(dim=1, index=sorted_indices) 77 | 78 | range_vec = torch.arange( # shape: (1, seq_len, 1) 79 | 0, values.size(1), device=intervals.device, dtype=torch.long 80 | )[None, :, None] 81 | # shape: (batch_size, num_intervals) 82 | relative_indices = (intervals.select(dim=-1, index=3) - 83 | intervals.select(dim=-1, index=0)) 84 | # shape: (batch_size, seq_len, num_intervals) 85 | in_interval = ((range_vec >= intervals[:, None, :, 0]) & 86 | (range_vec <= intervals[:, None, :, 1])) 87 | # shape: (batch_size, seq_len, num_intervals) 88 | projected_indices = torch.where( 89 | in_interval, sorted_indices[:, :, None], seq_len - 1 90 | ).sort(dim=1)[0] 91 | 92 | # shape (batch_size, num_intervals) 93 | projected_indices = projected_indices.take_along_dim( 94 | relative_indices[:, None, :], dim=1 95 | ).squeeze(1) 96 | 97 | # shape: (batch_size, num_intervals) 98 | # projected_values = values.take_along_dim(projected_indices, dim=1) 99 | projected_values = sorted_values.take_along_dim(projected_indices, dim=1) 100 | 101 | # shape: (batch_size, num_intervals) 102 | split_rank = intervals.select(dim=-1, index=2).long() 103 | 104 | # shape: (batch_size, num_intervals), avoid out of range 105 | safe_split_rank = torch.where( 106 | split_rank == 0, 1, split_rank 107 | ) 108 | # shape: (batch_size, num_intervals) 109 | split_thresholds = (sorted_values.gather(dim=1, index=safe_split_rank - 1) + 110 | sorted_values.gather(dim=1, index=safe_split_rank)) / 2.0 111 | # shape: (batch_size, num_intervals) 112 | # > 0: go right, < 0: go left 113 | split_logits = projected_values - split_thresholds 114 | 115 | return split_logits 116 | 117 | 118 | def logsoftperm( 119 | input: torch.Tensor, # shape: (*, num_elements) 120 | perm: torch.Tensor, # shape: (*, num_elements) 121 | mask: Optional[torch.BoolTensor] = None # shape: (*, num_elements) 122 | ): 123 | # Args: 124 | # input: (*, num_elements) 125 | # perm: (*, num_elements) 126 | # Returns: 127 | # output: (*, num_elements) 128 | max_value = input.max().detach() + 1.0 129 | if mask is not None: 130 | input = input.masked_fill(~mask, max_value) 131 | perm = perm.masked_fill(~mask, max_value) 132 | 133 | # shape: (*, num_elements) 134 | sorted_input, _ = input.sort(dim=-1) 135 | # shape: (*, num_elements, num_elements) 136 | logits_matrix = -torch.abs(perm[:, None, :] - sorted_input[:, :, None]) 137 | # shape: (*, num_elements) 138 | log_likelihood_perm = F.log_softmax(logits_matrix, dim=-1).diagonal(dim1=-2, dim2=-1) 139 | 140 | if mask is not None: 141 | log_likelihood_perm = log_likelihood_perm.masked_fill(~mask, 0.0) 142 | 143 | return log_likelihood_perm 144 | 145 | 146 | def _batched_index_select( 147 | target: torch.Tensor, 148 | indices: torch.LongTensor, 149 | flattened_indices: Optional[torch.LongTensor] = None, 150 | ) -> torch.Tensor: 151 | # Args: 152 | # target: (..., seq_len, ...) 153 | # indices: (..., num_indices) 154 | # Returns: 155 | # selected_targets (..., num_indices, ...) 
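# Shape example (hypothetical values): a target of shape (2, 5, 8) with indices of shape
# (2, 3) yields a result of shape (2, 3, 8); for a 2-D target of shape (2, 5) the same
# indices yield (2, 3), since the dummy feature axis added below is dropped again via the
# unidim flag.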
156 | 157 | # dim is the index of the last dimension of indices 158 | dim = indices.dim() - 1 159 | unidim = False 160 | if target.dim() == 2: 161 | # (batch_size, sequence_length) -> (batch_size, sequence_length, 1) 162 | unidim = True 163 | target = target.unsqueeze(-1) 164 | 165 | target_size, indices_size = target.size(), indices.size() 166 | # flatten dimensions before dim, make a pseudo batch dimension 167 | indices = indices.view(math.prod([*indices_size[:dim]]), indices_size[dim]) 168 | 169 | if flattened_indices is None: 170 | # Shape: (batch_size * d_1 * ... * d_n) 171 | flattened_indices = flatten_and_batch_shift_indices( 172 | indices, target_size[dim] 173 | ) 174 | 175 | # Shape: (batch_size * sequence_length, embedding_size) 176 | flattened_target = target.reshape(-1, *target_size[dim + 1:]) 177 | # Shape: (batch_size * d_1 * ... * d_n, embedding_size) 178 | flattened_selected = flattened_target.index_select(0, flattened_indices) 179 | selected_shape = list(indices_size) + ([] if unidim else list(target_size[dim + 1:])) 180 | # Shape: (batch_size, d_1, ..., d_n, embedding_size) 181 | selected_targets = flattened_selected.reshape(*selected_shape) 182 | return selected_targets 183 | 184 | 185 | def flatten_and_batch_shift_indices( 186 | indices: torch.Tensor, sequence_length: int 187 | ) -> torch.Tensor: 188 | # Input: 189 | # indices: (batch_size, num_indices) 190 | # sequence_length: int, d_1*d_2*...*d_n 191 | # Returns: 192 | # offset_indices Shape: (batch_size*d_1*d_2*...*d_n) 193 | offsets = torch.arange(0, indices.size(0), dtype=torch.long, 194 | device=indices.device) * sequence_length 195 | 196 | for _ in range(indices.dim() - 1): 197 | offsets = offsets.unsqueeze(1) 198 | 199 | # Shape: (batch_size, d_1, ..., d_n) 200 | offset_indices = indices + offsets 201 | # Shape: (batch_size * d_1 * ... * d_n) 202 | offset_indices = offset_indices.view(-1) 203 | 204 | return offset_indices 205 | 206 | 207 | def onehot_with_ignore_label(labels, num_class, ignore_label): 208 | # One-hot encode the modified labels 209 | one_hot_labels = torch.nn.functional.one_hot( 210 | labels.masked_fill((labels == ignore_label), num_class), 211 | num_classes=num_class + 1 212 | ) 213 | # Remove the last row in the one-hot encoding 214 | # shape: (*, num_class+1 -> num_class) 215 | one_hot_labels = one_hot_labels[..., :-1] 216 | return one_hot_labels 217 | -------------------------------------------------------------------------------- /load_dataset.sh: -------------------------------------------------------------------------------- 1 | # Based on https://github.com/nikitakit/tetra-tagging/blob/master/examples/training.ipynb 2 | 3 | %%bash 4 | if [ ! -e self-attentive-parser ]; then 5 | git clone https://github.com/nikitakit/self-attentive-parser &> /dev/null 6 | fi 7 | rm -rf train dev test EVALB/ 8 | cp self-attentive-parser/data/02-21.10way.clean ./data/train 9 | cp self-attentive-parser/data/22.auto.clean ./data/dev 10 | cp self-attentive-parser/data/23.auto.clean ./data/test 11 | # The evalb program needs to be compiled 12 | cp -R self-attentive-parser/EVALB EVALB 13 | rm -rf self-attentive-parser 14 | cd EVALB && make &> /dev/null 15 | # To test that everything works as intended, we check that the F1 score when 16 | # comparing the dev set with itself is 100. 
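# The command below should print EVALB's first "Bracketing FMeasure" summary line,
# which is expected to read 100.00 for this self-comparison.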
17 | ./evalb -p nk.prm ../data/dev ../data/dev | grep FMeasure | head -n 1 -------------------------------------------------------------------------------- /pat-models/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.0.0 2 | appnope==0.1.3 3 | argon2-cffi==21.3.0 4 | argon2-cffi-bindings==21.2.0 5 | asttokens==2.0.5 6 | attrs==21.4.0 7 | backcall==0.2.0 8 | beautifulsoup4==4.10.0 9 | bleach==4.1.0 10 | cachetools==5.0.0 11 | certifi==2021.10.8 12 | cffi==1.15.0 13 | charset-normalizer==2.0.12 14 | click==8.1.2 15 | debugpy==1.6.0 16 | decorator==5.1.1 17 | defusedxml==0.7.1 18 | entrypoints==0.4 19 | executing==0.8.3 20 | fastjsonschema==2.15.3 21 | filelock==3.6.0 22 | google-auth==2.6.2 23 | google-auth-oauthlib==0.4.6 24 | grpcio==1.44.0 25 | huggingface-hub==0.5.1 26 | idna==3.3 27 | importlib-metadata==4.11.3 28 | importlib-resources==5.6.0 29 | ipykernel==6.12.1 30 | ipython==8.2.0 31 | ipython-genutils==0.2.0 32 | ipywidgets==7.7.0 33 | jedi==0.18.1 34 | Jinja2==3.1.1 35 | joblib==1.1.1 36 | jsonschema==4.4.0 37 | jupyter==1.0.0 38 | jupyter-client==7.2.1 39 | jupyter-console==6.4.3 40 | jupyter-core==4.9.2 41 | jupyterlab-pygments==0.1.2 42 | jupyterlab-widgets==1.1.0 43 | Markdown==3.3.6 44 | MarkupSafe==2.1.1 45 | matplotlib-inline==0.1.3 46 | mistune==0.8.4 47 | nbclient==0.5.13 48 | nbconvert==6.4.5 49 | nbformat==5.3.0 50 | nest-asyncio==1.5.5 51 | nltk==3.7 52 | notebook==6.4.10 53 | numpy==1.22.3 54 | oauthlib==3.2.0 55 | packaging==21.3 56 | pandas==1.4.2 57 | pandocfilters==1.5.0 58 | parso==0.8.3 59 | pexpect==4.8.0 60 | pickleshare==0.7.5 61 | plotly==5.7.0 62 | prometheus-client==0.14.0 63 | prompt-toolkit==3.0.29 64 | protobuf==3.20.0 65 | psutil==5.9.0 66 | ptyprocess==0.7.0 67 | pure-eval==0.2.2 68 | pyasn1==0.4.8 69 | pyasn1-modules==0.2.8 70 | pycparser==2.21 71 | Pygments==2.11.2 72 | pyparsing==3.0.7 73 | pyrsistent==0.18.1 74 | python-dateutil==2.8.2 75 | pytz==2022.1 76 | PyYAML==6.0 77 | pyzmq==22.3.0 78 | qtconsole==5.3.0 79 | QtPy==2.0.1 80 | regex==2022.3.15 81 | requests==2.27.1 82 | requests-oauthlib==1.3.1 83 | rsa==4.8 84 | sacremoses==0.0.49 85 | scipy==1.8.1 86 | Send2Trash==1.8.0 87 | six==1.16.0 88 | soupsieve==2.3.1 89 | spicy==0.16.0 90 | stack-data==0.2.0 91 | tenacity==8.0.1 92 | tensorboard==2.8.0 93 | tensorboard-data-server==0.6.1 94 | tensorboard-plugin-wit==1.8.1 95 | terminado==0.13.3 96 | testpath==0.6.0 97 | tokenizers==0.11.6 98 | torch==1.11.0 99 | tornado==6.1 100 | tqdm==4.64.0 101 | traitlets==5.1.1 102 | transformers==4.18.0 103 | typing_extensions==4.1.1 104 | urllib3==1.26.9 105 | wcwidth==0.2.5 106 | webencodings==0.5.1 107 | Werkzeug==2.1.1 108 | widgetsnbextension==3.6.0 109 | zipp==3.8.0 110 | bitsandbytes==0.39.0 111 | scikit-learn==1.2.2 112 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import logging 4 | import pickle 5 | import random 6 | import sys 7 | import json 8 | 9 | import numpy as np 10 | import torch 11 | import transformers 12 | from bitsandbytes.optim import AdamW 13 | from nltk.corpus.reader.bracket_parse import 
BracketParseCorpusReader 14 | from torch.utils.data import DataLoader 15 | from torch.utils.tensorboard import SummaryWriter 16 | from tqdm import tqdm as tq 17 | 18 | from const import * 19 | from learning.dataset import TaggingDataset 20 | from learning.evaluate import predict, dependency_eval, calc_parse_eval, calc_tag_accuracy, dependency_decoding 21 | from learning.learn import ModelForTetratagging 22 | from tagging.hexatagger import HexaTagger 23 | 24 | # Set random seed 25 | RANDOM_SEED = 1 26 | torch.manual_seed(RANDOM_SEED) 27 | random.seed(RANDOM_SEED) 28 | np.random.seed(RANDOM_SEED) 29 | print('Random seed: {}'.format(RANDOM_SEED), file=sys.stderr) 30 | 31 | logging.basicConfig( 32 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 33 | datefmt='%m/%d/%Y %H:%M:%S', 34 | level=logging.INFO 35 | ) 36 | logger = logging.getLogger(__file__) 37 | 38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 39 | 40 | parser = argparse.ArgumentParser() 41 | subparser = parser.add_subparsers(dest='command') 42 | train = subparser.add_parser('train') 43 | evaluate = subparser.add_parser('evaluate') 44 | predict_parser = subparser.add_parser('predict') 45 | vocab = subparser.add_parser('vocab') 46 | 47 | vocab.add_argument('--tagger', choices=[HEXATAGGER, TETRATAGGER, TD_SR, BU_SR], required=True, 48 | help="Tagging schema") 49 | vocab.add_argument('--lang', choices=LANG, default=ENG, help="Language") 50 | vocab.add_argument('--output-path', choices=[HEXATAGGER, TETRATAGGER, TD_SR, BU_SR], 51 | default="data/vocab/") 52 | 53 | train.add_argument('--tagger', choices=[HEXATAGGER, TETRATAGGER, TD_SR, BU_SR], required=True, 54 | help="Tagging schema") 55 | train.add_argument('--lang', choices=LANG, default=ENG, help="Language") 56 | train.add_argument('--tag-vocab-path', type=str, default="data/vocab/") 57 | train.add_argument('--model', choices=BERT, required=True, help="Model architecture") 58 | 59 | train.add_argument('--model-path', type=str, default='bertlarge', 60 | help="Bert model path or name, " 61 | "xlnet-large-cased for english, hfl/chinese-xlnet-mid for chinese") 62 | train.add_argument('--output-path', type=str, default='pat-models/', 63 | help="Path to save trained models") 64 | train.add_argument('--use-tensorboard', type=bool, default=False, 65 | help="Whether to use the tensorboard for logging the results make sure to " 66 | "add credentials to run.py if set to true") 67 | 68 | train.add_argument('--max-depth', type=int, default=10, 69 | help="Max stack depth used for decoding") 70 | train.add_argument('--keep-per-depth', type=int, default=1, 71 | help="Max elements to keep per depth") 72 | 73 | train.add_argument('--lr', type=float, default=2e-5) 74 | train.add_argument('--epochs', type=int, default=50) 75 | train.add_argument('--batch-size', type=int, default=32) 76 | train.add_argument('--num-warmup-steps', type=int, default=200) 77 | train.add_argument('--weight-decay', type=float, default=0.01) 78 | 79 | evaluate.add_argument('--model-name', type=str, required=True) 80 | evaluate.add_argument('--lang', choices=LANG, default=ENG, help="Language") 81 | evaluate.add_argument('--tagger', choices=[HEXATAGGER, TETRATAGGER, TD_SR, BU_SR], 82 | required=True, 83 | help="Tagging schema") 84 | evaluate.add_argument('--tag-vocab-path', type=str, default="data/vocab/") 85 | evaluate.add_argument('--model-path', type=str, default='pat-models/') 86 | evaluate.add_argument('--bert-model-path', type=str, default='mbert/') 87 | 
evaluate.add_argument('--output-path', type=str, default='results/') 88 | evaluate.add_argument('--batch-size', type=int, default=16) 89 | evaluate.add_argument('--max-depth', type=int, default=10, 90 | help="Max stack depth used for decoding") 91 | evaluate.add_argument('--is-greedy', type=bool, default=False, 92 | help="Whether or not to use greedy decoding") 93 | evaluate.add_argument('--keep-per-depth', type=int, default=1, 94 | help="Max elements to keep per depth") 95 | evaluate.add_argument('--use-tensorboard', type=bool, default=False, 96 | help="Whether to use the tensorboard for logging the results make sure " 97 | "to add credentials to run.py if set to true") 98 | 99 | 100 | predict_parser.add_argument('--model-name', type=str, required=True) 101 | predict_parser.add_argument('--lang', choices=LANG, default=ENG, help="Language") 102 | predict_parser.add_argument('--tagger', choices=[HEXATAGGER, TETRATAGGER, TD_SR, BU_SR], 103 | required=True, 104 | help="Tagging schema") 105 | predict_parser.add_argument('--tag-vocab-path', type=str, default="data/vocab/") 106 | predict_parser.add_argument('--model-path', type=str, default='pat-models/') 107 | predict_parser.add_argument('--bert-model-path', type=str, default='mbert/') 108 | predict_parser.add_argument('--output-path', type=str, default='results/') 109 | predict_parser.add_argument('--batch-size', type=int, default=16) 110 | predict_parser.add_argument('--max-depth', type=int, default=10, 111 | help="Max stack depth used for decoding") 112 | predict_parser.add_argument('--is-greedy', type=bool, default=False, 113 | help="Whether or not to use greedy decoding") 114 | predict_parser.add_argument('--keep-per-depth', type=int, default=1, 115 | help="Max elements to keep per depth") 116 | predict_parser.add_argument('--use-tensorboard', type=bool, default=False, 117 | help="Whether to use the tensorboard for logging the results make sure " 118 | "to add credentials to run.py if set to true") 119 | 120 | 121 | def initialize_tag_system(reader, tagging_schema, lang, tag_vocab_path="", 122 | add_remove_top=False): 123 | tag_vocab = None 124 | if tag_vocab_path != "": 125 | with open(tag_vocab_path + lang + "-" + tagging_schema + '.pkl', 'rb') as f: 126 | tag_vocab = pickle.load(f) 127 | 128 | if tagging_schema == HEXATAGGER: 129 | # for BHT 130 | tag_system = HexaTagger( 131 | trees=reader.parsed_sents(lang + '.bht.train'), 132 | tag_vocab=tag_vocab, add_remove_top=False 133 | ) 134 | else: 135 | logging.error("Please specify the tagging schema") 136 | return 137 | return tag_system 138 | 139 | 140 | def get_data_path(tagger): 141 | if tagger == HEXATAGGER: 142 | return DEP_DATA_PATH 143 | return DATA_PATH 144 | 145 | 146 | def save_vocab(args): 147 | data_path = get_data_path(args.tagger) 148 | if args.tagger == HEXATAGGER: 149 | prefix = args.lang + ".bht" 150 | else: 151 | prefix = args.lang 152 | reader = BracketParseCorpusReader( 153 | data_path, [prefix + '.train', prefix + '.dev', prefix + '.test']) 154 | tag_system = initialize_tag_system( 155 | reader, args.tagger, args.lang, add_remove_top=True) 156 | with open(args.output_path + args.lang + "-" + args.tagger + '.pkl', 'wb+') as f: 157 | pickle.dump(tag_system.tag_vocab, f) 158 | 159 | 160 | def prepare_training_data(reader, tag_system, tagging_schema, model_name, batch_size, lang): 161 | is_tetratags = True if tagging_schema == TETRATAGGER or tagging_schema == HEXATAGGER else False 162 | prefix = lang + ".bht" if tagging_schema == HEXATAGGER else lang 163 | 164 | tokenizer = 
transformers.AutoTokenizer.from_pretrained( 165 | model_name, truncation=True, use_fast=True) 166 | train_dataset = TaggingDataset(prefix + '.train', tokenizer, tag_system, reader, device, 167 | is_tetratags=is_tetratags, language=lang) 168 | eval_dataset = TaggingDataset(prefix + '.test', tokenizer, tag_system, reader, device, 169 | is_tetratags=is_tetratags, language=lang) 170 | train_dataloader = DataLoader( 171 | train_dataset, shuffle=True, batch_size=batch_size, collate_fn=train_dataset.collate, 172 | pin_memory=True 173 | ) 174 | eval_dataloader = DataLoader( 175 | eval_dataset, batch_size=batch_size, collate_fn=eval_dataset.collate, pin_memory=True 176 | ) 177 | return train_dataset, eval_dataset, train_dataloader, eval_dataloader 178 | 179 | 180 | def prepare_test_data(reader, tag_system, tagging_schema, model_name, batch_size, lang): 181 | is_tetratags = True if tagging_schema == TETRATAGGER or tagging_schema == HEXATAGGER else False 182 | prefix = lang + ".bht" if tagging_schema == HEXATAGGER else lang 183 | 184 | print(f"Evaluating {model_name}, {tagging_schema}") 185 | tokenizer = transformers.AutoTokenizer.from_pretrained( 186 | model_name, truncation=True, use_fast=True) 187 | test_dataset = TaggingDataset( 188 | prefix + '.test', tokenizer, tag_system, reader, device, 189 | is_tetratags=is_tetratags, language=lang 190 | ) 191 | test_dataloader = DataLoader( 192 | test_dataset, batch_size=batch_size, collate_fn=test_dataset.collate 193 | ) 194 | return test_dataset, test_dataloader 195 | 196 | 197 | def generate_config(model_type, tagging_schema, tag_system, model_path, is_eng): 198 | if model_type in BERTCRF or model_type in BERTLSTM: 199 | config = transformers.AutoConfig.from_pretrained( 200 | model_path, 201 | num_labels=2 * len(tag_system.tag_vocab), 202 | task_specific_params={ 203 | 'model_path': model_path, 204 | 'num_tags': len(tag_system.tag_vocab), 205 | 'is_eng': is_eng, 206 | } 207 | ) 208 | elif model_type in BERT and tagging_schema in [TETRATAGGER, HEXATAGGER]: 209 | config = transformers.AutoConfig.from_pretrained( 210 | model_path, 211 | num_labels=len(tag_system.tag_vocab), 212 | id2label={i: label for i, label in enumerate(tag_system.tag_vocab)}, 213 | label2id={label: i for i, label in enumerate(tag_system.tag_vocab)}, 214 | task_specific_params={ 215 | 'model_path': model_path, 216 | 'num_even_tags': tag_system.decode_moderator.leaf_tag_vocab_size, 217 | 'num_odd_tags': tag_system.decode_moderator.internal_tag_vocab_size, 218 | 'pos_emb_dim': 256, 219 | 'num_pos_tags': 50, 220 | 'lstm_layers': 3, 221 | 'dropout': 0.33, 222 | 'is_eng': is_eng, 223 | 'use_pos': True 224 | } 225 | ) 226 | elif model_type in BERT and tagging_schema != TETRATAGGER and tagging_schema != HEXATAGGER: 227 | config = transformers.AutoConfig.from_pretrained( 228 | model_path, 229 | num_labels=2 * len(tag_system.tag_vocab), 230 | task_specific_params={ 231 | 'model_path': model_path, 232 | 'num_even_tags': len(tag_system.tag_vocab), 233 | 'num_odd_tags': len(tag_system.tag_vocab), 234 | 'is_eng': is_eng 235 | } 236 | ) 237 | else: 238 | logging.error("Invalid combination of model type and tagging schema") 239 | return 240 | return config 241 | 242 | 243 | def initialize_model(model_type, tagging_schema, tag_system, model_path, is_eng): 244 | config = generate_config( 245 | model_type, tagging_schema, tag_system, model_path, is_eng 246 | ) 247 | if model_type in BERT: 248 | model = ModelForTetratagging(config=config) 249 | else: 250 | logging.error("Invalid model type") 251 | return 
252 | return model 253 | 254 | 255 | def initialize_optimizer_and_scheduler(model, dataset_size, lr=5e-5, num_epochs=4, 256 | num_warmup_steps=160, weight_decay_rate=0.0): 257 | num_training_steps = num_epochs * dataset_size 258 | no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight'] 259 | grouped_parameters = [ 260 | { 261 | "params": [p for n, p in model.named_parameters() if "bert" not in n], 262 | "weight_decay": 0.0, 263 | "lr": lr * 50, "betas": (0.9, 0.9), 264 | }, 265 | { 266 | "params": [p for n, p in model.named_parameters() if 267 | "bert" in n and any(nd in n for nd in no_decay)], 268 | "weight_decay": 0.0, 269 | "lr": lr, "betas": (0.9, 0.999), 270 | }, 271 | { 272 | "params": [p for n, p in model.named_parameters() if 273 | "bert" in n and not any(nd in n for nd in no_decay)], 274 | "weight_decay": 0.1, 275 | "lr": lr, "betas": (0.9, 0.999), 276 | }, 277 | ] 278 | 279 | optimizer = AdamW( 280 | grouped_parameters, lr=lr 281 | ) 282 | scheduler = transformers.get_linear_schedule_with_warmup( 283 | optimizer=optimizer, 284 | num_warmup_steps=num_warmup_steps, 285 | num_training_steps=num_training_steps 286 | ) 287 | 288 | return optimizer, scheduler, num_training_steps 289 | 290 | 291 | def register_run_metrics(writer, run_name, lr, epochs, eval_loss, even_tag_accuracy, 292 | odd_tag_accuracy): 293 | writer.add_hparams({'run_name': run_name, 'lr': lr, 'epochs': epochs}, 294 | {'eval_loss': eval_loss, 'odd_tag_accuracy': odd_tag_accuracy, 295 | 'even_tag_accuracy': even_tag_accuracy}) 296 | 297 | 298 | def train_command(args): 299 | if args.tagger == HEXATAGGER: 300 | prefix = args.lang + ".bht" 301 | else: 302 | prefix = args.lang 303 | data_path = get_data_path(args.tagger) 304 | reader = BracketParseCorpusReader(data_path, [prefix + '.train', prefix + '.dev', 305 | prefix + '.test']) 306 | logging.info("Initializing Tag System") 307 | tag_system = initialize_tag_system( 308 | reader, args.tagger, args.lang, 309 | tag_vocab_path=args.tag_vocab_path, add_remove_top=True 310 | ) 311 | logging.info("Preparing Data") 312 | train_dataset, eval_dataset, train_dataloader, eval_dataloader = prepare_training_data( 313 | reader, tag_system, args.tagger, args.model_path, args.batch_size, args.lang) 314 | logging.info("Initializing The Model") 315 | is_eng = True if args.lang == ENG else False 316 | model = initialize_model( 317 | args.model, args.tagger, tag_system, args.model_path, is_eng 318 | ) 319 | model.to(device) 320 | 321 | train_set_size = len(train_dataloader) 322 | optimizer, scheduler, num_training_steps = initialize_optimizer_and_scheduler( 323 | model, train_set_size, args.lr, args.epochs, 324 | args.num_warmup_steps, args.weight_decay 325 | ) 326 | optimizer.zero_grad() 327 | run_name = args.lang + "-" + args.tagger + "-" + args.model + "-" + str( 328 | args.lr) + "-" + str(args.epochs) 329 | writer = None 330 | if args.use_tensorboard: 331 | writer = SummaryWriter(comment=run_name) 332 | 333 | num_leaf_labels, num_tags = calc_num_tags_per_task(args.tagger, tag_system) 334 | 335 | logging.info("Starting The Training Loop") 336 | model.train() 337 | n_iter = 0 338 | 339 | last_fscore = 0 340 | best_fscore = 0 341 | tol = 99999 342 | 343 | for epo in tq(range(args.epochs)): 344 | logging.info(f"*******************EPOCH {epo}*******************") 345 | t = 1 346 | model.train() 347 | 348 | with tq(train_dataloader, disable=False) as progbar: 349 | for batch in progbar: 350 | batch = {k: v.to(device) for k, v in batch.items()} 351 | 352 | if device == "cuda": 353 | with 
torch.cuda.amp.autocast( 354 | enabled=True, dtype=torch.bfloat16 355 | ): 356 | outputs = model(**batch) 357 | else: 358 | with torch.cpu.amp.autocast( 359 | enabled=True, dtype=torch.bfloat16 360 | ): 361 | outputs = model(**batch) 362 | 363 | loss = outputs[0] 364 | loss.mean().backward() 365 | if args.use_tensorboard: 366 | writer.add_scalar('Loss/train', torch.mean(loss), n_iter) 367 | progbar.set_postfix(loss=torch.mean(loss).item()) 368 | 369 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 370 | optimizer.step() 371 | scheduler.step() 372 | optimizer.zero_grad() 373 | 374 | n_iter += 1 375 | t += 1 376 | 377 | if True: # evaluation at the end of epoch 378 | predictions, eval_labels = predict( 379 | model, eval_dataloader, len(eval_dataset), 380 | num_tags, args.batch_size, device 381 | ) 382 | calc_tag_accuracy( 383 | predictions, eval_labels, 384 | num_leaf_labels, writer, args.use_tensorboard) 385 | 386 | if args.tagger == HEXATAGGER: 387 | dev_metrics_las, dev_metrics_uas = dependency_eval( 388 | predictions, eval_labels, eval_dataset, 389 | tag_system, None, "", args.max_depth, 390 | args.keep_per_depth, False) 391 | else: 392 | dev_metrics = calc_parse_eval( 393 | predictions, eval_labels, eval_dataset, 394 | tag_system, None, "", 395 | args.max_depth, args.keep_per_depth, False) 396 | 397 | eval_loss = 0.5 398 | if args.tagger == HEXATAGGER: 399 | writer.add_scalar('LAS_Fscore/dev', 400 | dev_metrics_las.fscore, n_iter) 401 | writer.add_scalar('LAS_Precision/dev', 402 | dev_metrics_las.precision, n_iter) 403 | writer.add_scalar('LAS_Recall/dev', 404 | dev_metrics_las.recall, n_iter) 405 | writer.add_scalar('loss/dev', eval_loss, n_iter) 406 | 407 | logging.info("current LAS {}".format(dev_metrics_las)) 408 | logging.info("current UAS {}".format(dev_metrics_uas)) 409 | logging.info("last LAS fscore {}".format(last_fscore)) 410 | logging.info("best LAS fscore {}".format(best_fscore)) 411 | # setting main metric for model selection 412 | dev_metrics = dev_metrics_las 413 | else: 414 | writer.add_scalar('Fscore/dev', dev_metrics.fscore, n_iter) 415 | writer.add_scalar('Precision/dev', dev_metrics.precision, n_iter) 416 | writer.add_scalar('Recall/dev', dev_metrics.recall, n_iter) 417 | writer.add_scalar('loss/dev', eval_loss, n_iter) 418 | 419 | logging.info("current fscore {}".format(dev_metrics.fscore)) 420 | logging.info("last fscore {}".format(last_fscore)) 421 | logging.info("best fscore {}".format(best_fscore)) 422 | 423 | # if dev_metrics.fscore > last_fscore or dev_loss < last... 424 | if dev_metrics.fscore > last_fscore: 425 | tol = 5 426 | if dev_metrics.fscore > best_fscore: # if dev_metrics.fscore > best_fscore: 427 | logging.info("save the best model") 428 | best_fscore = dev_metrics.fscore 429 | _save_best_model(model, args.output_path, run_name) 430 | elif dev_metrics.fscore > 0: # dev_metrics.fscore 431 | tol -= 1 432 | 433 | if tol < 0: 434 | _finish_training(model, tag_system, eval_dataloader, 435 | eval_dataset, eval_loss, run_name, writer, args) 436 | return 437 | if dev_metrics.fscore > 0: # not propagating the nan 438 | last_eval_loss = eval_loss 439 | last_fscore = dev_metrics.fscore 440 | 441 | # if dev_metrics.fscore > last_fscore or dev_loss < last... 
442 | if dev_metrics.fscore > best_fscore: 443 | tol = 99999 444 | logging.info("tol refill") 445 | logging.info("save the best model") 446 | best_eval_loss = eval_loss 447 | best_fscore = dev_metrics.fscore 448 | _save_best_model(model, args.output_path, run_name) 449 | elif eval_loss > 0: 450 | tol -= 1 451 | 452 | if tol < 0: 453 | _finish_training(model, tag_system, eval_dataloader, 454 | eval_dataset, eval_loss, run_name, writer, args) 455 | return 456 | if eval_loss > 0: # not propagating the nan 457 | last_eval_loss = eval_loss 458 | # end of epoch 459 | pass 460 | 461 | _finish_training(model, tag_system, eval_dataloader, eval_dataset, eval_loss, 462 | run_name, writer, args) 463 | 464 | 465 | def _save_best_model(model, output_path, run_name): 466 | logging.info("Saving The Newly Found Best Model") 467 | os.makedirs(output_path, exist_ok=True) 468 | to_save_file = os.path.join(output_path, run_name) 469 | torch.save(model.state_dict(), to_save_file) 470 | 471 | 472 | def _finish_training(model, tag_system, eval_dataloader, eval_dataset, eval_loss, 473 | run_name, writer, args): 474 | num_leaf_labels, num_tags = calc_num_tags_per_task(args.tagger, tag_system) 475 | predictions, eval_labels = predict(model, eval_dataloader, len(eval_dataset), 476 | num_tags, args.batch_size, 477 | device) 478 | even_acc, odd_acc = calc_tag_accuracy(predictions, eval_labels, num_leaf_labels, writer, 479 | args.use_tensorboard) 480 | register_run_metrics(writer, run_name, args.lr, 481 | args.epochs, eval_loss, even_acc, odd_acc) 482 | 483 | 484 | def decode_model_name(model_name): 485 | name_chunks = model_name.split("-") 486 | name_chunks = name_chunks[1:] 487 | if name_chunks[0] == "td" or name_chunks[0] == "bu": 488 | tagging_schema = name_chunks[0] + "-" + name_chunks[1] 489 | model_type = name_chunks[2] 490 | else: 491 | tagging_schema = name_chunks[0] 492 | model_type = name_chunks[1] 493 | return tagging_schema, model_type 494 | 495 | 496 | def calc_num_tags_per_task(tagging_schema, tag_system): 497 | if tagging_schema == TETRATAGGER or tagging_schema == HEXATAGGER: 498 | num_leaf_labels = tag_system.decode_moderator.leaf_tag_vocab_size 499 | num_tags = len(tag_system.tag_vocab) 500 | else: 501 | num_leaf_labels = len(tag_system.tag_vocab) 502 | num_tags = 2 * len(tag_system.tag_vocab) 503 | return num_leaf_labels, num_tags 504 | 505 | 506 | def evaluate_command(args): 507 | tagging_schema, model_type = decode_model_name(args.model_name) 508 | data_path = get_data_path(tagging_schema) # HexaTagger or others 509 | print("Evaluation Args", args) 510 | if args.tagger == HEXATAGGER: 511 | prefix = args.lang + ".bht" 512 | else: 513 | prefix = args.lang 514 | reader = BracketParseCorpusReader( 515 | data_path, 516 | [prefix + '.train', prefix + '.dev', prefix + '.test']) 517 | writer = SummaryWriter(comment=args.model_name) 518 | logging.info("Initializing Tag System") 519 | tag_system = initialize_tag_system(reader, tagging_schema, args.lang, 520 | tag_vocab_path=args.tag_vocab_path, 521 | add_remove_top=True) 522 | logging.info("Preparing Data") 523 | eval_dataset, eval_dataloader = prepare_test_data( 524 | reader, tag_system, tagging_schema, 525 | args.bert_model_path, args.batch_size, 526 | args.lang) 527 | 528 | is_eng = True if args.lang == ENG else False 529 | model = initialize_model( 530 | model_type, tagging_schema, tag_system, args.bert_model_path, is_eng 531 | ) 532 | model.load_state_dict(torch.load(args.model_path + args.model_name)) 533 | model.to(device) 534 | 535 | num_leaf_labels, 
num_tags = calc_num_tags_per_task(tagging_schema, tag_system) 536 | 537 | predictions, eval_labels = predict( 538 | model, eval_dataloader, len(eval_dataset), 539 | num_tags, args.batch_size, device) 540 | calc_tag_accuracy(predictions, eval_labels, 541 | num_leaf_labels, writer, args.use_tensorboard) 542 | if tagging_schema == HEXATAGGER: 543 | dev_metrics_las, dev_metrics_uas = dependency_eval( 544 | predictions, eval_labels, eval_dataset, 545 | tag_system, args.output_path, args.model_name, 546 | args.max_depth, args.keep_per_depth, False) 547 | print( 548 | "LAS: ", dev_metrics_las, "\n", 549 | "UAS: ", dev_metrics_uas, sep="" 550 | ) 551 | else: 552 | parse_metrics = calc_parse_eval(predictions, eval_labels, eval_dataset, tag_system, 553 | args.output_path, 554 | args.model_name, 555 | args.max_depth, 556 | args.keep_per_depth, 557 | args.is_greedy) 558 | print(parse_metrics) 559 | 560 | 561 | def predict_command(args): 562 | tagging_schema, model_type = decode_model_name(args.model_name) 563 | data_path = get_data_path(tagging_schema) # HexaTagger or others 564 | print("predict Args", args) 565 | 566 | if args.tagger == HEXATAGGER: 567 | prefix = args.lang + ".bht" 568 | else: 569 | prefix = args.lang 570 | reader = BracketParseCorpusReader(data_path, []) 571 | writer = SummaryWriter(comment=args.model_name) 572 | logging.info("Initializing Tag System") 573 | tag_system = initialize_tag_system(None, tagging_schema, args.lang, 574 | tag_vocab_path=args.tag_vocab_path, 575 | add_remove_top=True) 576 | logging.info("Preparing Data") 577 | eval_dataset, eval_dataloader = prepare_test_data( 578 | reader, tag_system, tagging_schema, 579 | args.bert_model_path, args.batch_size, 580 | "input") 581 | 582 | is_eng = True if args.lang == ENG else False 583 | model = initialize_model( 584 | model_type, tagging_schema, tag_system, args.bert_model_path, is_eng 585 | ) 586 | model.load_state_dict(torch.load(args.model_path + args.model_name)) 587 | model.to(device) 588 | 589 | num_leaf_labels, num_tags = calc_num_tags_per_task(tagging_schema, tag_system) 590 | 591 | predictions, eval_labels = predict( 592 | model, eval_dataloader, len(eval_dataset), 593 | num_tags, args.batch_size, device) 594 | 595 | if tagging_schema == HEXATAGGER: 596 | pred_output = dependency_decoding( 597 | predictions, eval_labels, eval_dataset, 598 | tag_system, args.output_path, args.model_name, 599 | args.max_depth, args.keep_per_depth, False) 600 | 601 | with open(args.output_path + args.model_name + ".pred.json", "w") as fout: 602 | print("Saving predictions to", args.output_path + args.model_name + ".pred.json") 603 | json.dump(pred_output, fout) 604 | 605 | """pred_output contains: { 606 | "predicted_dev_triples": predicted_dev_triples, 607 | "predicted_hexa_tags": pred_hexa_tags 608 | }""" 609 | 610 | 611 | 612 | 613 | def main(): 614 | args = parser.parse_args() 615 | if args.command == 'train': 616 | train_command(args) 617 | elif args.command == 'evaluate': 618 | evaluate_command(args) 619 | elif args.command == 'predict': 620 | predict_command(args) 621 | elif args.command == 'vocab': 622 | save_vocab(args) 623 | 624 | 625 | if __name__ == '__main__': 626 | main() 627 | -------------------------------------------------------------------------------- /tagging/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rycolab/parsing-as-tagging/e7a0be2d92dc44c8c5920a33db10dcfab8573dc1/tagging/__init__.py 
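Before the tagger implementations that follow, here is a minimal round-trip sketch using the pipeline methods they expose (`tree_to_tags_pipeline` / `tags_to_tree_pipeline`). It is an illustration only: the corpus root `data/bht`, the file name `English.bht.dev`, and the 100-tree slice are assumptions based on the default data layout, not part of the package.

```python
# Hedged sketch: assumes the preprocessed BHT files exist under data/bht.
from nltk.corpus.reader.bracket_parse import BracketParseCorpusReader
from tagging.hexatagger import HexaTagger

reader = BracketParseCorpusReader("data/bht", ["English.bht.dev"])  # assumed location
trees = reader.parsed_sents("English.bht.dev")

# Build the hexatag vocabulary from a small slice of trees, then round-trip one tree.
tagger = HexaTagger(trees=trees[:100], add_remove_top=False)
tags = tagger.tree_to_tags_pipeline(trees[0])[0]              # linearize to hexatags
rebuilt = tagger.tags_to_tree_pipeline(tags, trees[0].pos())  # decode tags back to a tree

print(tags[:10])
print(rebuilt.leaves() == trees[0].leaves())  # the decoded tree should keep the same leaves
```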
-------------------------------------------------------------------------------- /tagging/hexatagger.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | from nltk import ParentedTree as PTree 4 | from nltk import Tree 5 | 6 | from const import DUMMY_LABEL 7 | from tagging.tetratagger import BottomUpTetratagger 8 | from tagging.transform import RightCornerTransformer 9 | from tagging.tree_tools import debinarize_lex_tree, expand_unary 10 | 11 | 12 | class HexaTagger(BottomUpTetratagger, ABC): 13 | def preprocess(self, original_tree: Tree) -> PTree: 14 | tree = original_tree.copy(deep=True) 15 | tree.collapse_unary(collapsePOS=True, collapseRoot=True) 16 | 17 | ptree = PTree.convert(tree) 18 | root_label = ptree.label() 19 | tree_lc = PTree(root_label, []) 20 | RightCornerTransformer.transform(tree_lc, ptree, ptree) 21 | return tree_lc 22 | 23 | @staticmethod 24 | def create_shift_tag(label: str, left_or_right: str) -> str: 25 | arc_label = label.split("^^^")[-1] 26 | arc_label = arc_label.split("+")[0] 27 | return left_or_right + "/" + arc_label 28 | 29 | @staticmethod 30 | def _create_bi_reduce_tag(label: str, left_or_right: str) -> str: 31 | label = label.split("\\")[1] 32 | head_idx = label.split("^^^")[-1] 33 | # label = label.split("^^^")[0] 34 | if label.find("|") != -1: # drop extra node labels created after binarization 35 | return f'{left_or_right}' + "/" + f"{DUMMY_LABEL}^^^{head_idx}" 36 | else: 37 | return f'{left_or_right}' + "/" + label.replace("+", "/") 38 | 39 | @staticmethod 40 | def _create_unary_reduce_tag(label: str, left_or_right: str) -> str: 41 | label = label.split("\\")[0] 42 | head_idx = label.split("^^^")[-1] 43 | # label = label.split("^^^")[0] 44 | if label.find("|") != -1: # drop extra node labels created after binarization 45 | return f'{left_or_right}' + f"/{DUMMY_LABEL}^^^{head_idx}" 46 | else: 47 | return f'{left_or_right}' + "/" + label 48 | 49 | def tags_to_tree_pipeline(self, tags: [str], input_seq: []) -> Tree: 50 | ptree = self.tags_to_tree(tags, input_seq) 51 | return self.postprocess(ptree) 52 | 53 | @staticmethod 54 | def _create_pre_terminal_label(tag: str, default="X") -> str: 55 | arc_label = tag.split("/")[1] 56 | return f"X^^^{arc_label}+" 57 | 58 | @staticmethod 59 | def _create_unary_reduce_label(tag: str) -> str: 60 | idx = tag.find("/") 61 | if idx == -1: 62 | return DUMMY_LABEL 63 | return tag[idx + 1:].replace("/", "+") 64 | 65 | @staticmethod 66 | def _create_reduce_label(tag: str) -> str: 67 | idx = tag.find("/") 68 | if idx == -1: 69 | label = "X\\|" # to mark the second part as an extra node created via binarizaiton 70 | else: 71 | label = "X\\" + tag[idx + 1:].replace("/", "+") 72 | return label 73 | 74 | def postprocess(self, transformed_tree: PTree) -> Tree: 75 | tree = PTree("X", ["", ""]) 76 | tree = RightCornerTransformer.rev_transform(tree, transformed_tree) 77 | tree = Tree.convert(tree) 78 | if len(tree.leaves()) == 1: 79 | expand_unary(tree) 80 | # edge case with one node 81 | return tree 82 | 83 | debinarized_tree = Tree(tree.label(), []) 84 | debinarize_lex_tree(tree, debinarized_tree) 85 | expand_unary(debinarized_tree) 86 | # debinarized_tree.pretty_print() 87 | return debinarized_tree 88 | -------------------------------------------------------------------------------- /tagging/srtagger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC 3 | 4 | from nltk import ParentedTree as 
PTree 5 | from tqdm import tqdm as tq 6 | 7 | from learning.decode import BeamSearch, GreedySearch 8 | from tagging.tagger import Tagger, TagDecodeModerator 9 | from tagging.transform import LeftCornerTransformer 10 | 11 | import numpy as np 12 | 13 | from tagging.tree_tools import find_node_type, NodeType 14 | 15 | 16 | class SRTagDecodeModerator(TagDecodeModerator, ABC): 17 | def __init__(self, tag_vocab): 18 | super().__init__(tag_vocab) 19 | self.reduce_tag_size = len([tag for tag in tag_vocab if tag.startswith("r")]) 20 | self.shift_tag_size = len([tag for tag in tag_vocab if tag.startswith("s")]) 21 | 22 | self.rr_tag_size = len([tag for tag in tag_vocab if tag.startswith("rr")]) 23 | self.sr_tag_size = len([tag for tag in tag_vocab if tag.startswith("sr")]) 24 | 25 | self.rl_tag_size = self.reduce_tag_size - self.rr_tag_size # left reduce tag size 26 | self.sl_tag_size = self.shift_tag_size - self.sr_tag_size # left shift tag size 27 | 28 | self.mask_binarize = True 29 | 30 | def mask_scores_for_binarization(self, labels, scores) -> []: 31 | raise NotImplementedError 32 | 33 | 34 | class BUSRTagDecodeModerator(SRTagDecodeModerator): 35 | def __init__(self, tag_vocab): 36 | super().__init__(tag_vocab) 37 | stack_depth_change_by_id = [None] * len(tag_vocab) 38 | for i, tag in enumerate(tag_vocab): 39 | if tag.startswith("s"): 40 | stack_depth_change_by_id[i] = +1 41 | elif tag.startswith("r"): 42 | stack_depth_change_by_id[i] = -1 43 | assert None not in stack_depth_change_by_id 44 | self.stack_depth_change_by_id = np.array( 45 | stack_depth_change_by_id, dtype=int) 46 | 47 | self.reduce_only_mask = np.full((len(tag_vocab),), -np.inf) 48 | self.shift_only_mask = np.full((len(tag_vocab),), -np.inf) 49 | self.reduce_only_mask[:self.reduce_tag_size] = 0.0 50 | self.shift_only_mask[-self.shift_tag_size:] = 0.0 51 | 52 | def mask_scores_for_binarization(self, labels, scores) -> []: 53 | # after rr(right) -> only reduce, after rl(left) -> only shift 54 | # after sr(right) -> only reduce, after sl(left) -> only shift 55 | mask1 = np.where( 56 | (labels[:, None] >= self.rl_tag_size) & (labels[:, None] < self.reduce_tag_size), 57 | self.reduce_only_mask, 0.0) 58 | mask2 = np.where(labels[:, None] < self.rl_tag_size, self.shift_only_mask, 0.0) 59 | mask3 = np.where( 60 | labels[:, None] >= (self.sl_tag_size + self.reduce_tag_size), 61 | self.reduce_only_mask, 0.0) 62 | mask4 = np.where((labels[:, None] >= self.reduce_tag_size) & ( 63 | labels[:, None] < (self.sl_tag_size + self.reduce_tag_size)), 64 | self.shift_only_mask, 0.0) 65 | all_new_scores = scores + mask1 + mask2 + mask3 + mask4 66 | return all_new_scores 67 | 68 | 69 | class TDSRTagDecodeModerator(SRTagDecodeModerator): 70 | def __init__(self, tag_vocab): 71 | super().__init__(tag_vocab) 72 | is_shift_mask = np.concatenate( 73 | [ 74 | np.zeros(self.reduce_tag_size), 75 | np.ones(self.shift_tag_size), 76 | ] 77 | ) 78 | self.reduce_tags_only = np.asarray(-1e9 * is_shift_mask, dtype=float) 79 | 80 | stack_depth_change_by_id = [None] * len(tag_vocab) 81 | stack_depth_change_by_id_l2 = [None] * len(tag_vocab) 82 | for i, tag in enumerate(tag_vocab): 83 | if tag.startswith("s"): 84 | stack_depth_change_by_id_l2[i] = 0 85 | stack_depth_change_by_id[i] = -1 86 | elif tag.startswith("r"): 87 | stack_depth_change_by_id_l2[i] = -1 88 | stack_depth_change_by_id[i] = +2 89 | assert None not in stack_depth_change_by_id 90 | assert None not in stack_depth_change_by_id_l2 91 | self.stack_depth_change_by_id = np.array( 92 | stack_depth_change_by_id, 
dtype=int) 93 | self.stack_depth_change_by_id_l2 = np.array( 94 | stack_depth_change_by_id_l2, dtype=int) 95 | self._initialize_binarize_mask(tag_vocab) 96 | 97 | def _initialize_binarize_mask(self, tag_vocab) -> None: 98 | self.right_only_mask = np.full((len(tag_vocab),), -np.inf) 99 | self.left_only_mask = np.full((len(tag_vocab),), -np.inf) 100 | 101 | self.right_only_mask[self.rl_tag_size:self.reduce_tag_size] = 0.0 102 | self.right_only_mask[-self.sr_tag_size:] = 0.0 103 | 104 | self.left_only_mask[:self.rl_tag_size] = 0.0 105 | self.left_only_mask[ 106 | self.reduce_tag_size:self.reduce_tag_size + self.sl_tag_size] = 0.0 107 | 108 | def mask_scores_for_binarization(self, labels, scores) -> []: 109 | # if shift -> rr and sr, if reduce -> rl and sl 110 | mask1 = np.where(labels[:, None] >= self.reduce_tag_size, self.right_only_mask, 0.0) 111 | mask2 = np.where(labels[:, None] < self.reduce_tag_size, self.left_only_mask, 0.0) 112 | all_new_scores = scores + mask1 + mask2 113 | return all_new_scores 114 | 115 | 116 | class SRTagger(Tagger, ABC): 117 | def __init__(self, trees=None, tag_vocab=None, add_remove_top=False): 118 | super().__init__(trees, tag_vocab, add_remove_top) 119 | 120 | def add_trees_to_vocab(self, trees: []) -> None: 121 | self.label_vocab = set() 122 | for tree in tq(trees): 123 | for tag in self.tree_to_tags_pipeline(tree)[0]: 124 | self.tag_vocab.add(tag) 125 | idx = tag.find("/") 126 | if idx != -1: 127 | self.label_vocab.add(tag[idx + 1:]) 128 | else: 129 | self.label_vocab.add("") 130 | self.tag_vocab = sorted(self.tag_vocab) 131 | self.label_vocab = sorted(self.label_vocab) 132 | 133 | @staticmethod 134 | def create_shift_tag(label: str, is_right_child=False) -> str: 135 | suffix = "r" if is_right_child else "" 136 | if label.find("+") != -1: 137 | tag = "s" + suffix + "/" + "/".join(label.split("+")[:-1]) 138 | else: 139 | tag = "s" + suffix 140 | return tag 141 | 142 | @staticmethod 143 | def create_shift_label(tag: str) -> str: 144 | idx = tag.find("/") 145 | if idx != -1: 146 | return tag[idx + 1:].replace("/", "+") + "+" 147 | else: 148 | return "" 149 | 150 | @staticmethod 151 | def create_reduce_tag(label: str, is_right_child=False) -> str: 152 | if label.find("|") != -1: # drop extra node labels created after binarization 153 | tag = "r" 154 | else: 155 | tag = "r" + "/" + label.replace("+", "/") 156 | return "r" + tag if is_right_child else tag 157 | 158 | @staticmethod 159 | def _create_reduce_label(tag: str) -> str: 160 | idx = tag.find("/") 161 | if idx == -1: 162 | label = "|" # to mark the second part as an extra node created via binarizaiton 163 | else: 164 | label = tag[idx + 1:].replace("/", "+") 165 | return label 166 | 167 | 168 | class SRTaggerBottomUp(SRTagger): 169 | def __init__(self, trees=None, tag_vocab=None, add_remove_top=False): 170 | super().__init__(trees, tag_vocab, add_remove_top) 171 | self.decode_moderator = BUSRTagDecodeModerator(self.tag_vocab) 172 | 173 | def tree_to_tags(self, root: PTree) -> ([str], int): 174 | tags = [] 175 | lc = LeftCornerTransformer.extract_left_corner_no_eps(root) 176 | if len(root) == 1: # edge case 177 | tags.append(self.create_shift_tag(lc.label(), False)) 178 | return tags, 1 179 | 180 | is_right_child = lc.left_sibling() is not None 181 | tags.append(self.create_shift_tag(lc.label(), is_right_child)) 182 | 183 | logging.debug("SHIFT {}".format(lc.label())) 184 | stack = [lc] 185 | max_stack_len = 1 186 | 187 | while len(stack) > 0: 188 | node = stack[-1] 189 | max_stack_len = max(max_stack_len, 
len(stack)) 190 | 191 | if node.left_sibling() is None and node.right_sibling() is not None: 192 | lc = LeftCornerTransformer.extract_left_corner_no_eps(node.right_sibling()) 193 | stack.append(lc) 194 | logging.debug("SHIFT {}".format(lc.label())) 195 | is_right_child = lc.left_sibling() is not None 196 | tags.append(self.create_shift_tag(lc.label(), is_right_child)) 197 | 198 | elif len(stack) >= 2 and ( 199 | node.right_sibling() == stack[-2] or node.left_sibling() == stack[-2]): 200 | prev_node = stack[-2] 201 | logging.debug("REDUCE[ {0} {1} --> {2} ]".format( 202 | *(prev_node.label(), node.label(), node.parent().label()))) 203 | 204 | parent_is_right = node.parent().left_sibling() is not None 205 | tags.append(self.create_reduce_tag(node.parent().label(), parent_is_right)) 206 | stack.pop() 207 | stack.pop() 208 | stack.append(node.parent()) 209 | 210 | elif stack[0].parent() is None and len(stack) == 1: 211 | stack.pop() 212 | continue 213 | return tags, max_stack_len 214 | 215 | def tags_to_tree(self, tags: [str], input_seq: [str]) -> PTree: 216 | created_node_stack = [] 217 | node = None 218 | 219 | if len(tags) == 1: # base case 220 | assert tags[0].startswith('s') 221 | prefix = self.create_shift_label(tags[0]) 222 | return PTree(prefix + input_seq[0][1], [input_seq[0][0]]) 223 | for tag in tags: 224 | if tag.startswith('s'): 225 | prefix = self.create_shift_label(tag) 226 | created_node_stack.append(PTree(prefix + input_seq[0][1], [input_seq[0][0]])) 227 | input_seq.pop(0) 228 | else: 229 | last_node = created_node_stack.pop() 230 | last_2_node = created_node_stack.pop() 231 | node = PTree(self._create_reduce_label(tag), [last_2_node, last_node]) 232 | created_node_stack.append(node) 233 | 234 | if len(input_seq) != 0: 235 | raise ValueError("All the input sequence is not used") 236 | return node 237 | 238 | def logits_to_ids(self, logits: [], mask, max_depth, keep_per_depth, crf_transitions=None, 239 | is_greedy=False) -> [int]: 240 | if is_greedy: 241 | searcher = GreedySearch( 242 | self.decode_moderator, 243 | initial_stack_depth=0, 244 | crf_transitions=crf_transitions, 245 | max_depth=max_depth, 246 | keep_per_depth=keep_per_depth, 247 | ) 248 | else: 249 | 250 | searcher = BeamSearch( 251 | self.decode_moderator, 252 | initial_stack_depth=0, 253 | crf_transitions=crf_transitions, 254 | max_depth=max_depth, 255 | keep_per_depth=keep_per_depth, 256 | ) 257 | 258 | last_t = None 259 | seq_len = sum(mask) 260 | idx = 1 261 | for t in range(logits.shape[0]): 262 | if mask is not None and not mask[t]: 263 | continue 264 | if last_t is not None: 265 | searcher.advance( 266 | logits[last_t, :-len(self.tag_vocab)] 267 | ) 268 | if idx == seq_len: 269 | searcher.advance(logits[t, -len(self.tag_vocab):], is_last=True) 270 | else: 271 | searcher.advance(logits[t, -len(self.tag_vocab):]) 272 | last_t = t 273 | 274 | score, best_tag_ids = searcher.get_path() 275 | return best_tag_ids 276 | 277 | 278 | class SRTaggerTopDown(SRTagger): 279 | def __init__(self, trees=None, tag_vocab=None, add_remove_top=False): 280 | super().__init__(trees, tag_vocab, add_remove_top) 281 | self.decode_moderator = TDSRTagDecodeModerator(self.tag_vocab) 282 | 283 | def tree_to_tags(self, root: PTree) -> ([str], int): 284 | stack: [PTree] = [root] 285 | max_stack_len = 1 286 | tags = [] 287 | 288 | while len(stack) > 0: 289 | node = stack[-1] 290 | max_stack_len = max(max_stack_len, len(stack)) 291 | 292 | if find_node_type(node) == NodeType.NT: 293 | stack.pop() 294 | logging.debug("REDUCE[ {0} --> {1} 
{2}]".format( 295 | *(node.label(), node[0].label(), node[1].label()))) 296 | is_right_node = node.left_sibling() is not None 297 | tags.append(self.create_reduce_tag(node.label(), is_right_node)) 298 | stack.append(node[1]) 299 | stack.append(node[0]) 300 | 301 | else: 302 | logging.debug("-->\tSHIFT[ {0} ]".format(node.label())) 303 | is_right_node = node.left_sibling() is not None 304 | tags.append(self.create_shift_tag(node.label(), is_right_node)) 305 | stack.pop() 306 | 307 | return tags, max_stack_len 308 | 309 | def tags_to_tree(self, tags: [str], input_seq: [str]) -> PTree: 310 | if len(tags) == 1: # base case 311 | assert tags[0].startswith('s') 312 | prefix = self.create_shift_label(tags[0]) 313 | return PTree(prefix + input_seq[0][1], [input_seq[0][0]]) 314 | 315 | assert tags[0].startswith('r') 316 | node = PTree(self._create_reduce_label(tags[0]), []) 317 | created_node_stack: [PTree] = [node] 318 | 319 | for tag in tags[1:]: 320 | parent: PTree = created_node_stack[-1] 321 | if tag.startswith('s'): 322 | prefix = self.create_shift_label(tag) 323 | new_node = PTree(prefix + input_seq[0][1], [input_seq[0][0]]) 324 | input_seq.pop(0) 325 | else: 326 | label = self._create_reduce_label(tag) 327 | new_node = PTree(label, []) 328 | 329 | if len(parent) == 0: 330 | parent.insert(0, new_node) 331 | elif len(parent) == 1: 332 | parent.insert(1, new_node) 333 | created_node_stack.pop() 334 | 335 | if tag.startswith('r'): 336 | created_node_stack.append(new_node) 337 | 338 | if len(input_seq) != 0: 339 | raise ValueError("All the input sequence is not used") 340 | return node 341 | 342 | def logits_to_ids(self, logits: [], mask, max_depth, keep_per_depth, crf_transitions=None, 343 | is_greedy=False) -> [int]: 344 | if is_greedy: 345 | searcher = GreedySearch( 346 | self.decode_moderator, 347 | initial_stack_depth=1, 348 | crf_transitions=crf_transitions, 349 | max_depth=max_depth, 350 | min_depth=0, 351 | keep_per_depth=keep_per_depth, 352 | ) 353 | else: 354 | searcher = BeamSearch( 355 | self.decode_moderator, 356 | initial_stack_depth=1, 357 | crf_transitions=crf_transitions, 358 | max_depth=max_depth, 359 | min_depth=0, 360 | keep_per_depth=keep_per_depth, 361 | ) 362 | 363 | last_t = None 364 | seq_len = sum(mask) 365 | idx = 1 366 | is_last = False 367 | for t in range(logits.shape[0]): 368 | if mask is not None and not mask[t]: 369 | continue 370 | if last_t is not None: 371 | searcher.advance( 372 | logits[last_t, :-len(self.tag_vocab)] 373 | ) 374 | if idx == seq_len: 375 | is_last = True 376 | if last_t is None: 377 | searcher.advance( 378 | logits[t, -len(self.tag_vocab):] + self.decode_moderator.reduce_tags_only, 379 | is_last=is_last) 380 | else: 381 | searcher.advance(logits[t, -len(self.tag_vocab):], is_last=is_last) 382 | last_t = t 383 | 384 | score, best_tag_ids = searcher.get_path(required_stack_depth=0) 385 | return best_tag_ids 386 | -------------------------------------------------------------------------------- /tagging/tagger.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | from nltk import ParentedTree as PTree 4 | from nltk import Tree 5 | from tqdm import tqdm as tq 6 | 7 | from tagging.tree_tools import add_plus_to_tree, remove_plus_from_tree 8 | 9 | 10 | class TagDecodeModerator(ABC): 11 | def __init__(self, tag_vocab): 12 | self.vocab_len = len(tag_vocab) 13 | self.stack_depth_change_by_id = None 14 | self.stack_depth_change_by_id_l2 = None 15 | 16 | 17 | class Tagger(ABC): 18 | def 
__init__(self, trees=None, tag_vocab=None, add_remove_top=False): 19 | self.tag_vocab = set() 20 | self.add_remove_top = add_remove_top 21 | 22 | if tag_vocab is not None: 23 | self.tag_vocab = tag_vocab 24 | elif trees is not None: 25 | self.add_trees_to_vocab(trees) 26 | 27 | self.decode_moderator = None 28 | 29 | def add_trees_to_vocab(self, trees: []) -> None: 30 | for tree in tq(trees): 31 | tags = self.tree_to_tags_pipeline(tree)[0] 32 | for tag in tags: 33 | self.tag_vocab.add(tag) 34 | self.tag_vocab = sorted(self.tag_vocab) 35 | 36 | def tree_to_tags(self, root: PTree) -> [str]: 37 | raise NotImplementedError("tree to tags is not implemented") 38 | 39 | def tags_to_tree(self, tags: [str], input_seq: [str]) -> PTree: 40 | raise NotImplementedError("tags to tree is not implemented") 41 | 42 | def tree_to_tags_pipeline(self, tree: Tree) -> ([str], int): 43 | ptree = self.preprocess(tree) 44 | return self.tree_to_tags(ptree) 45 | 46 | def tree_to_ids_pipeline(self, tree: Tree) -> [int]: 47 | tags = self.tree_to_tags_pipeline(tree)[0] 48 | res = [] 49 | for tag in tags: 50 | if tag in self.tag_vocab: 51 | res.append(self.tag_vocab.index(tag)) 52 | elif tag[0:tag.find("/")] in self.tag_vocab: 53 | res.append(self.tag_vocab.index(tag[0:tag.find("/")])) 54 | else: 55 | res.append(0) 56 | 57 | return res 58 | 59 | def tags_to_tree_pipeline(self, tags: [str], input_seq: []) -> Tree: 60 | filtered_input_seq = [] 61 | for token, pos in input_seq: 62 | filtered_input_seq.append((token, pos.replace("+", "@"))) 63 | ptree = self.tags_to_tree(tags, filtered_input_seq) 64 | return self.postprocess(ptree) 65 | 66 | def ids_to_tree_pipeline(self, ids: [int], input_seq: []) -> Tree: 67 | tags = [self.tag_vocab[idx] for idx in ids] 68 | return self.tags_to_tree_pipeline(tags, input_seq) 69 | 70 | def logits_to_ids(self, logits: [], mask, max_depth, keep_per_depth, is_greedy=False) -> [int]: 71 | raise NotImplementedError("logits to ids is not implemented") 72 | 73 | def logits_to_tree(self, logits: [], leave_nodes: [], mask=None, max_depth=5, keep_per_depth=1, is_greedy=False) -> Tree: 74 | ids = self.logits_to_ids(logits, mask, max_depth, keep_per_depth, is_greedy=is_greedy) 75 | return self.ids_to_tree_pipeline(ids, leave_nodes) 76 | 77 | def preprocess(self, original_tree: Tree) -> PTree: 78 | tree = original_tree.copy(deep=True) 79 | if self.add_remove_top: 80 | cut_off_tree = tree[0] 81 | else: 82 | cut_off_tree = tree 83 | 84 | remove_plus_from_tree(cut_off_tree) 85 | cut_off_tree.collapse_unary(collapsePOS=True, collapseRoot=True) 86 | cut_off_tree.chomsky_normal_form() 87 | ptree = PTree.convert(cut_off_tree) 88 | return ptree 89 | 90 | def postprocess(self, tree: PTree) -> Tree: 91 | tree = Tree.convert(tree) 92 | tree.un_chomsky_normal_form() 93 | add_plus_to_tree(tree) 94 | if self.add_remove_top: 95 | return Tree("TOP", [tree]) 96 | else: 97 | return tree 98 | -------------------------------------------------------------------------------- /tagging/test.py: -------------------------------------------------------------------------------- 1 | # import logging 2 | import unittest 3 | 4 | import numpy as np 5 | from nltk import ParentedTree 6 | from nltk import Tree 7 | 8 | from learning.evaluate import evalb 9 | from tetratagger import BottomUpTetratagger, TopDownTetratagger 10 | from tagging.srtagger import SRTaggerBottomUp, SRTaggerTopDown 11 | from transform import LeftCornerTransformer, RightCornerTransformer 12 | from tree_tools import random_tree, is_topo_equal, create_dummy_tree 13 | 
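# --- Illustrative sketch (ours, not part of the original test suite) ---
# A minimal round trip through the pipeline defined in tagging/tagger.py and
# tagging/srtagger.py: linearize a tree into shift/reduce tag strings (tags
# beginning with "s"/"sr" for shifts and "r"/"rr" for reduces, optionally
# carrying the constituent label after a "/"), then rebuild the tree from the
# tags and the POS-tagged leaves. The helper name and the assertion are ours;
# the pattern mirrors TestSRTagger.test_tag_sequence_example below.
def _sr_round_trip_sketch():
    from nltk import Tree
    from tagging.srtagger import SRTaggerBottomUp

    tree = Tree.fromstring(
        "(TOP (S (NP (PRP She)) (VP (VBZ enjoys)"
        " (S (VP (VBG playing) (NP (NN tennis))))) (. .)))")
    original = tree.copy(deep=True)
    tagger = SRTaggerBottomUp(add_remove_top=True)
    tags, _ = tagger.tree_to_tags_pipeline(tree)                 # linearize
    tree_back = tagger.tags_to_tree_pipeline(tags, original.pos())  # decode
    assert tree_back == original  # expected to reconstruct the original tree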
14 | from original_tetratagger import TetraTagSequence, TetraTagSystem 15 | from nltk.corpus.reader.bracket_parse import BracketParseCorpusReader 16 | from tqdm import tqdm as tq 17 | 18 | # logging.getLogger().setLevel(logging.DEBUG) 19 | 20 | np.random.seed(0) 21 | 22 | 23 | class TestTransforms(unittest.TestCase): 24 | def test_transform(self): 25 | tree = ParentedTree.fromstring("(S (NP (det the) (N dog)) (VP (V ran) (Adv fast)))") 26 | tree.pretty_print() 27 | new_tree_lc = ParentedTree("S", []) 28 | LeftCornerTransformer.transform(new_tree_lc, tree, tree) 29 | new_tree_lc.pretty_print() 30 | 31 | new_tree_rc = ParentedTree("S", []) 32 | RightCornerTransformer.transform(new_tree_rc, tree, tree) 33 | new_tree_rc.pretty_print() 34 | 35 | def test_rev_rc_transform(self, trials=100): 36 | for _ in range(trials): 37 | t = ParentedTree("ROOT", []) 38 | random_tree(t, depth=0, cutoff=5) 39 | new_tree_rc = ParentedTree("S", []) 40 | RightCornerTransformer.transform(new_tree_rc, t, t) 41 | tree_back = ParentedTree("X", ["", ""]) 42 | tree_back = RightCornerTransformer.rev_transform(tree_back, new_tree_rc) 43 | self.assertEqual(tree_back, t) 44 | 45 | def test_rev_lc_transform(self, trials=100): 46 | for _ in range(trials): 47 | t = ParentedTree("ROOT", []) 48 | random_tree(t, depth=0, cutoff=5) 49 | new_tree_lc = ParentedTree("S", []) 50 | LeftCornerTransformer.transform(new_tree_lc, t, t) 51 | tree_back = ParentedTree("X", ["", ""]) 52 | tree_back = LeftCornerTransformer.rev_transform(tree_back, new_tree_lc) 53 | self.assertEqual(tree_back, t) 54 | 55 | 56 | class TestTagging(unittest.TestCase): 57 | def test_buttom_up(self): 58 | tree = ParentedTree.fromstring("(S (NP (det the) (N dog)) (VP (V ran) (Adv fast)))") 59 | tree.pretty_print() 60 | tree_rc = ParentedTree("S", []) 61 | RightCornerTransformer.transform(tree_rc, tree, tree) 62 | tree_rc.pretty_print() 63 | tagger = BottomUpTetratagger() 64 | tags, _ = tagger.tree_to_tags(tree_rc) 65 | 66 | for tag in tagger.tetra_visualize(tags): 67 | print(tag) 68 | print("--" * 20) 69 | 70 | def test_buttom_up_alternate(self, trials=100): 71 | for _ in range(trials): 72 | t = ParentedTree("ROOT", []) 73 | random_tree(t, depth=0, cutoff=5) 74 | t_rc = ParentedTree("S", []) 75 | RightCornerTransformer.transform(t_rc, t, t) 76 | tagger = BottomUpTetratagger() 77 | tags, _ = tagger.tree_to_tags(t_rc) 78 | self.assertTrue(tagger.is_alternating(tags)) 79 | self.assertTrue((2 * len(t.leaves()) - 1) == len(tags)) 80 | 81 | def test_round_trip_test_buttom_up(self, trials=100): 82 | for _ in range(trials): 83 | tree = ParentedTree("ROOT", []) 84 | random_tree(tree, depth=0, cutoff=5) 85 | tree_rc = ParentedTree("S", []) 86 | RightCornerTransformer.transform(tree_rc, tree, tree) 87 | tree.pretty_print() 88 | tree_rc.pretty_print() 89 | tagger = BottomUpTetratagger() 90 | tags, _ = tagger.tree_to_tags(tree_rc) 91 | root_from_tags = tagger.tags_to_tree(tags, tree.leaves()) 92 | tree_back = ParentedTree("X", ["", ""]) 93 | tree_back = RightCornerTransformer.rev_transform(tree_back, root_from_tags, 94 | pick_up_labels=False) 95 | self.assertTrue(is_topo_equal(tree, tree_back)) 96 | 97 | def test_top_down(self): 98 | tree = ParentedTree.fromstring("(S (NP (det the) (N dog)) (VP (V ran) (Adv fast)))") 99 | tree.pretty_print() 100 | tree_lc = ParentedTree("S", []) 101 | LeftCornerTransformer.transform(tree_lc, tree, tree) 102 | tree_lc.pretty_print() 103 | tagger = TopDownTetratagger() 104 | tags, _ = tagger.tree_to_tags(tree_lc) 105 | 106 | for tag in 
tagger.tetra_visualize(tags): 107 | print(tag) 108 | print("--" * 20) 109 | 110 | def test_top_down_alternate(self, trials=100): 111 | for _ in range(trials): 112 | t = ParentedTree("ROOT", []) 113 | random_tree(t, depth=0, cutoff=5) 114 | t_lc = ParentedTree("S", []) 115 | LeftCornerTransformer.transform(t_lc, t, t) 116 | tagger = TopDownTetratagger() 117 | tags = tagger.tree_to_tags(t_lc) 118 | self.assertTrue(tagger.is_alternating(tags)) 119 | self.assertTrue((2 * len(t.leaves()) - 1) == len(tags)) 120 | 121 | def round_trip_test_top_down(self, trials=100): 122 | for _ in range(trials): 123 | tree = ParentedTree("ROOT", []) 124 | random_tree(tree, depth=0, cutoff=5) 125 | tree_lc = ParentedTree("S", []) 126 | LeftCornerTransformer.transform(tree_lc, tree, tree) 127 | tagger = TopDownTetratagger() 128 | tags = tagger.tree_to_tags(tree_lc) 129 | root_from_tags = tagger.tags_to_tree(tags, tree.leaves()) 130 | tree_back = ParentedTree("X", ["", ""]) 131 | tree_back = LeftCornerTransformer.rev_transform(tree_back, root_from_tags, 132 | pick_up_labels=False) 133 | self.assertTrue(is_topo_equal(tree, tree_back)) 134 | 135 | 136 | class TestPipeline(unittest.TestCase): 137 | def test_example_colab(self): 138 | example_tree = Tree.fromstring( 139 | "(S (NP (PRP She)) (VP (VBZ enjoys) (S (VP (VBG playing) (NP (NN tennis))))) (. .))") 140 | tagger = BottomUpTetratagger() 141 | tags = tagger.tree_to_tags_pipeline(example_tree)[0] 142 | print(tags) 143 | for tag in tagger.tetra_visualize(tags): 144 | print(tag) 145 | 146 | def test_dummy_tree(self): 147 | example_tree = Tree.fromstring( 148 | "(S (NP (PRP She)) (VP (VBZ enjoys) (S (VP (VBG playing) (NP (NN tennis))))) (. .))") 149 | print(example_tree.pos()) 150 | dummy = create_dummy_tree(example_tree.pos()) 151 | dummy.pretty_print() 152 | example_tree.pretty_print() 153 | print(evalb("../EVALB/", [example_tree], [dummy])) 154 | 155 | def test_tree_linearizations(self): 156 | READER = BracketParseCorpusReader('../data/spmrl/', 157 | ['English.train', 'English.dev', 'English.test']) 158 | trees = READER.parsed_sents('English.test') 159 | for tree in trees: 160 | print(tree) 161 | print(" ".join(tree.leaves())) 162 | 163 | def test_compare_to_original_tetratagger(self): 164 | # import pickle 165 | # with open("../data/tetra.pkl", 'rb') as f: 166 | # tag_vocab = pickle.load(f) 167 | 168 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test']) 169 | trees = READER.parsed_sents('English.test') 170 | tagger = BottomUpTetratagger(add_remove_top=True) 171 | tetratagger = TetraTagSystem(trees=trees) 172 | for tree in tq(trees): 173 | original_tree = tree.copy(deep=True) 174 | # original_tags = TetraTagSequence.from_tree(original_tree) 175 | tags = tagger.tree_to_tags_pipeline(tree)[0] 176 | tetratags = tetratagger.tags_from_tree(tree) 177 | # ids = tagger.tree_to_ids_pipeline(tree) 178 | # ids = tagger.ids_from_tree(tree) 179 | # tree_back = tagger.tags_to_tree_pipeline(tags, tree.pos()) 180 | # self.assertEqual(original_tree, tree_back) 181 | self.assertEqual(list(tetratags), list(tags)) 182 | 183 | def test_example_colab_lc(self): 184 | example_tree = Tree.fromstring( 185 | "(S (NP (PRP She)) (VP (VBZ enjoys) (S (VP (VBG playing) (NP (NN tennis))))) (. 
.))") 186 | original_tree = example_tree.copy(deep=True) 187 | tagger = TopDownTetratagger() 188 | tags = tagger.tree_to_tags_pipeline(example_tree)[0] 189 | tree_back = tagger.tags_to_tree_pipeline(tags, example_tree.pos()) 190 | tree_back.pretty_print() 191 | self.assertEqual(original_tree, tree_back) 192 | 193 | def test_top_down_tetratagger(self): 194 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test']) 195 | trees = READER.parsed_sents('English.test') 196 | tagger = TopDownTetratagger(add_remove_top=True) 197 | for tree in tq(trees): 198 | original_tree = tree.copy(deep=True) 199 | tags = tagger.tree_to_tags_pipeline(tree)[0] 200 | tree_back = tagger.tags_to_tree_pipeline(tags, tree.pos()) 201 | self.assertEqual(original_tree, tree_back) 202 | 203 | def test_tag_ids_top_down(self): 204 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test']) 205 | trees = READER.parsed_sents('English.test') 206 | tagger = TopDownTetratagger(trees, add_remove_top=True) 207 | for tree in tq(trees): 208 | original_tree = tree.copy(deep=True) 209 | ids = tagger.tree_to_ids_pipeline(tree) 210 | tree_back = tagger.ids_to_tree_pipeline(ids, tree.pos()) 211 | self.assertEqual(original_tree, tree_back) 212 | 213 | def test_tag_ids_bottom_up(self): 214 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test']) 215 | trees = READER.parsed_sents('English.test') 216 | tagger = BottomUpTetratagger(trees, add_remove_top=True) 217 | for tree in tq(trees): 218 | original_tree = tree.copy(deep=True) 219 | ids = tagger.tree_to_ids_pipeline(tree) 220 | tree_back = tagger.ids_to_tree_pipeline(ids, tree.pos()) 221 | self.assertEqual(original_tree, tree_back) 222 | 223 | def test_decoder_edges(self): 224 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test']) 225 | trees = READER.parsed_sents('English.test') 226 | tagger_bu = SRTaggerBottomUp(add_remove_top=True) 227 | tagger_td = SRTaggerTopDown(add_remove_top=True) 228 | unique_tags = dict() 229 | counter = 0 230 | for tree in tq(trees): 231 | counter += 1 232 | tags = tagger_td.tree_to_tags_pipeline(tree)[0] 233 | for i, tag in enumerate(tags): 234 | tag_s = tag.split("/")[0] 235 | if i < len(tags) - 1: 236 | tag_next = tags[i+1].split("/")[0] 237 | else: 238 | tag_next = None 239 | if tag_s not in unique_tags and tag_next is not None: 240 | unique_tags[tag_s] = {tag_next} 241 | elif tag_next is not None: 242 | unique_tags[tag_s].add(tag_next) 243 | print(unique_tags) 244 | 245 | 246 | class TestSRTagger(unittest.TestCase): 247 | def test_tag_sequence_example(self): 248 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test']) 249 | trees = READER.parsed_sents('English.train') 250 | tagger = SRTaggerBottomUp(add_remove_top=True) 251 | for tree in tq(trees): 252 | original_tree = tree.copy(deep=True) 253 | tags = tagger.tree_to_tags_pipeline(tree)[0] 254 | for tag in tags: 255 | if tag.find('/X/') != -1: 256 | print(tag) 257 | original_tree.pretty_print() 258 | tree_back = tagger.tags_to_tree_pipeline(tags, tree.pos()) 259 | self.assertEqual(original_tree, tree_back) 260 | 261 | def test_example_both_version(self): 262 | import nltk 263 | example_tree = nltk.Tree.fromstring( 264 | "(TOP (S (NP (PRP She)) (VP (VBZ enjoys) (S (VP (VBG playing) (NP (NN tennis))))) (. 
.)))") 265 | example_tree.pretty_print() 266 | td_tagger = SRTaggerTopDown(add_remove_top=True) 267 | bu_tagger = SRTaggerBottomUp(add_remove_top=True) 268 | t1 = example_tree.copy(deep=True) 269 | t2 = example_tree.copy(deep=True) 270 | td_tags = td_tagger.tree_to_tags_pipeline(t1)[0] 271 | bu_tags = bu_tagger.tree_to_tags_pipeline(t2)[0] 272 | self.assertEqual(set(td_tags), set(bu_tags)) 273 | print(list(td_tags)) 274 | print(list(bu_tags)) 275 | 276 | def test_bu_binarize(self): 277 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test']) 278 | trees = READER.parsed_sents('English.dev') 279 | tagger = SRTaggerBottomUp(add_remove_top=True) 280 | for tree in tq(trees): 281 | tags = tagger.tree_to_tags_pipeline(tree)[0] 282 | for idx, tag in enumerate(tags): 283 | if tag.startswith("rr"): 284 | if (idx + 1) < len(tags): 285 | self.assertTrue(tags[idx + 1].startswith('r')) 286 | elif tag.startswith("r"): 287 | if (idx + 1) < len(tags): 288 | self.assertTrue(tags[idx + 1].startswith('s')) 289 | 290 | def test_td_binarize(self): 291 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test']) 292 | trees = READER.parsed_sents('English.dev') 293 | tagger = SRTaggerTopDown(add_remove_top=True) 294 | for tree in tq(trees): 295 | tags = tagger.tree_to_tags_pipeline(tree)[0] 296 | for idx, tag in enumerate(tags): 297 | if tag.startswith("s"): 298 | if (idx + 1) < len(tags): 299 | self.assertTrue(tags[idx + 1].startswith('rr') or tags[idx + 1].startswith('s')) 300 | elif tag.startswith("r"): 301 | if (idx + 1) < len(tags): 302 | self.assertTrue(not tags[idx + 1].startswith('rr')) 303 | 304 | def test_td_bu(self): 305 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test']) 306 | trees = READER.parsed_sents('English.dev') 307 | td_tagger = SRTaggerTopDown(add_remove_top=True) 308 | bu_tagger = SRTaggerBottomUp(add_remove_top=True) 309 | 310 | for tree in tq(trees): 311 | td_tags = td_tagger.tree_to_tags_pipeline(tree)[0] 312 | bu_tags = bu_tagger.tree_to_tags_pipeline(tree)[0] 313 | self.assertEqual(set(td_tags), set(bu_tags)) 314 | 315 | def test_tag_ids(self): 316 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test']) 317 | trees = READER.parsed_sents('English.train') 318 | tagger = SRTaggerBottomUp(trees, add_remove_top=False) 319 | for tree in tq(trees): 320 | ids = tagger.tree_to_ids_pipeline(tree) 321 | tree_back = tagger.ids_to_tree_pipeline(ids, tree.pos()) 322 | self.assertEqual(tree, tree_back) 323 | 324 | def test_max_length(self): 325 | READER = BracketParseCorpusReader('../data/spmrl/', ['English.train', 'English.dev', 'English.test']) 326 | trees = READER.parsed_sents('English.train') 327 | tagger = SRTaggerBottomUp(trees, add_remove_top=True) 328 | print(len(tagger.tag_vocab)) 329 | 330 | 331 | class TestSPMRL(unittest.TestCase): 332 | def test_reading_trees(self): 333 | langs = ['Korean'] 334 | for l in langs: 335 | READER = BracketParseCorpusReader('../data/spmrl/', [l+'.train', l+'.dev', l+'.test']) 336 | trees = READER.parsed_sents(l+'.test') 337 | trees[0].pretty_print() 338 | 339 | def test_tagging_bu_sr(self): 340 | langs = ['Basque', 'French', 'German', 'Hebrew', 'Hungarian', 'Korean', 'Polish', 341 | 'Swedish'] 342 | for l in langs: 343 | READER = BracketParseCorpusReader('../data/spmrl/', 344 | [l + '.train', l + '.dev', l + '.test']) 345 | trees = READER.parsed_sents(l + '.test') 346 | 
trees[0].pretty_print() 347 | tagger = SRTaggerBottomUp(trees, add_remove_top=True) 348 | print(tagger.tree_to_tags_pipeline(trees[0])) 349 | print(tagger.tree_to_ids_pipeline(trees[0])) 350 | 351 | def test_tagging_td_sr(self): 352 | langs = ['Basque', 'French', 'German', 'Hebrew', 'Hungarian', 'Korean', 'Polish', 353 | 'Swedish'] 354 | for l in langs: 355 | READER = BracketParseCorpusReader('../data/spmrl/', 356 | [l + '.train', l + '.dev', l + '.test']) 357 | trees = READER.parsed_sents(l + '.test') 358 | trees[0].pretty_print() 359 | tagger = SRTaggerBottomUp(trees, add_remove_top=True) 360 | print(tagger.tree_to_tags_pipeline(trees[0])) 361 | print(tagger.tree_to_ids_pipeline(trees[0])) 362 | 363 | def test_tagging_tetra(self): 364 | langs = ['Basque', 'French', 'German', 'Hebrew', 'Hungarian', 'Korean', 'Polish', 365 | 'Swedish'] 366 | for l in langs: 367 | READER = BracketParseCorpusReader('../data/spmrl/', 368 | [l + '.train', l + '.dev', l + '.test']) 369 | trees = READER.parsed_sents(l + '.test') 370 | trees[0].pretty_print() 371 | tagger = BottomUpTetratagger(trees, add_remove_top=True) 372 | print(tagger.tree_to_tags_pipeline(trees[0])) 373 | print(tagger.tree_to_ids_pipeline(trees[0])) 374 | 375 | def test_taggers(self): 376 | tagger = SRTaggerBottomUp(add_remove_top=False) 377 | langs = ['Basque', 'French', 'German', 'Hebrew', 'Hungarian', 'Korean', 'Polish', 378 | 'Swedish'] 379 | for l in tq(langs): 380 | READER = BracketParseCorpusReader('../data/spmrl/', 381 | [l + '.train', l + '.dev', l + '.test']) 382 | trees = READER.parsed_sents(l + '.test') 383 | for tree in tq(trees): 384 | tags, _ = tagger.tree_to_tags_pipeline(tree) 385 | tree_back = tagger.tags_to_tree_pipeline(tags, tree.pos()) 386 | self.assertEqual(tree, tree_back) 387 | 388 | def test_korean(self): 389 | import pickle 390 | with open("../data/vocab/Korean-bu-sr.pkl", 'rb') as f: 391 | tag_vocab = pickle.load(f) 392 | tagger = SRTaggerBottomUp(tag_vocab=tag_vocab, add_remove_top=False) 393 | l = "Korean" 394 | READER = BracketParseCorpusReader('../data/spmrl/', 395 | [l + '.train', l + '.dev', l + '.test']) 396 | trees = READER.parsed_sents(l + '.test') 397 | for tree in tq(trees): 398 | tags = tagger.tree_to_tags_pipeline(tree)[0] 399 | tree_back = tagger.tags_to_tree_pipeline(tags, tree.pos()) 400 | self.assertEqual(tree, tree_back) 401 | 402 | 403 | class TestStackSize(unittest.TestCase): 404 | def test_english_tetra(self): 405 | l = "English" 406 | READER = BracketParseCorpusReader('../data/spmrl', [l+'.train', l+'.dev', l+'.test']) 407 | trees = READER.parsed_sents([l+'.train', l+'.dev', l+'.test']) 408 | tagger = BottomUpTetratagger(add_remove_top=True) 409 | stack_size_list = [] 410 | for tree in tq(trees): 411 | tags, max_depth = tagger.tree_to_tags_pipeline(tree) 412 | stack_size_list.append(max_depth) 413 | 414 | print(stack_size_list) 415 | 416 | 417 | if __name__ == '__main__': 418 | unittest.main() 419 | -------------------------------------------------------------------------------- /tagging/tetratagger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC 3 | 4 | import numpy as np 5 | from nltk import ParentedTree as PTree 6 | from nltk import Tree 7 | 8 | from learning.decode import BeamSearch, GreedySearch 9 | from tagging.tagger import Tagger, TagDecodeModerator 10 | from tagging.transform import LeftCornerTransformer, RightCornerTransformer 11 | from tagging.tree_tools import find_node_type, is_node_epsilon, NodeType 
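# --- Illustrative sketch (ours, not part of the original module) ---
# The tetratagger emits four tag families: "l"/"r" for leaf (shift) actions and
# "L"/"R" for internal (reduce) actions; tetra_visualize() renders them as the
# arrows "<--", "-->", "<==" and "==>" respectively. A minimal usage sketch,
# following TestPipeline.test_example_colab in tagging/test.py (the helper name
# is ours, and BottomUpTetratagger is defined further down in this file):
def _tetratag_sketch():
    from nltk import Tree

    example = Tree.fromstring(
        "(S (NP (PRP She)) (VP (VBZ enjoys)"
        " (S (VP (VBG playing) (NP (NN tennis))))) (. .))")
    tagger = BottomUpTetratagger()
    tags, _ = tagger.tree_to_tags_pipeline(example)
    print(list(tagger.tetra_visualize(tags)))  # alternating leaf/internal arrows
    tree_back = tagger.tags_to_tree_pipeline(tags, example.pos())
    assert tree_back == example  # expected round trip (cf. test_example_colab_lc)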
12 | 13 | 14 | class TetraTagDecodeModerator(TagDecodeModerator): 15 | def __init__(self, tag_vocab): 16 | super().__init__(tag_vocab) 17 | self.internal_tag_vocab_size = len( 18 | [tag for tag in tag_vocab if tag[0] in "LR"] 19 | ) 20 | self.leaf_tag_vocab_size = len( 21 | [tag for tag in tag_vocab if tag[0] in "lr"] 22 | ) 23 | 24 | is_leaf_mask = np.concatenate( 25 | [ 26 | np.zeros(self.internal_tag_vocab_size), 27 | np.ones(self.leaf_tag_vocab_size), 28 | ] 29 | ) 30 | self.internal_tags_only = np.asarray(-1e9 * is_leaf_mask, dtype=float) 31 | self.leaf_tags_only = np.asarray( 32 | -1e9 * (1 - is_leaf_mask), dtype=float 33 | ) 34 | 35 | stack_depth_change_by_id = [None] * len(tag_vocab) 36 | for i, tag in enumerate(tag_vocab): 37 | if tag.startswith("l"): 38 | stack_depth_change_by_id[i] = +1 39 | elif tag.startswith("R"): 40 | stack_depth_change_by_id[i] = -1 41 | else: 42 | stack_depth_change_by_id[i] = 0 43 | assert None not in stack_depth_change_by_id 44 | self.stack_depth_change_by_id = np.array( 45 | stack_depth_change_by_id, dtype=np.int32 46 | ) 47 | self.mask_binarize = False 48 | 49 | 50 | class TetraTagger(Tagger, ABC): 51 | def __init__(self, trees=None, tag_vocab=None, add_remove_top=False): 52 | super().__init__(trees, tag_vocab, add_remove_top) 53 | self.decode_moderator = TetraTagDecodeModerator(self.tag_vocab) 54 | 55 | def expand_tags(self, tags: [str]) -> [str]: 56 | raise NotImplementedError("expand tags is not implemented") 57 | 58 | @staticmethod 59 | def tetra_visualize(tags: [str]): 60 | for tag in tags: 61 | if tag.startswith('r'): 62 | yield "-->" 63 | if tag.startswith('l'): 64 | yield "<--" 65 | if tag.startswith('R'): 66 | yield "==>" 67 | if tag.startswith('L'): 68 | yield "<==" 69 | 70 | @staticmethod 71 | def create_shift_tag(label: str, left_or_right: str) -> str: 72 | if label.find("+") != -1: 73 | return left_or_right + "/" + "/".join(label.split("+")[:-1]) 74 | else: 75 | return left_or_right 76 | 77 | @staticmethod 78 | def _create_bi_reduce_tag(label: str, left_or_right: str) -> str: 79 | label = label.split("\\")[1] 80 | if label.find("|") != -1: # drop extra node labels created after binarization 81 | return left_or_right 82 | else: 83 | return left_or_right + "/" + label.replace("+", "/") 84 | 85 | @staticmethod 86 | def _create_unary_reduce_tag(label: str, left_or_right: str) -> str: 87 | label = label.split("\\")[0] 88 | if label.find("|") != -1: # drop extra node labels created after binarization 89 | return left_or_right 90 | else: 91 | return left_or_right + "/" + label.replace("+", "/") 92 | 93 | @staticmethod 94 | def create_merge_shift_tag(label: str, left_or_right: str) -> str: 95 | if label.find("/") != -1: 96 | return left_or_right + "/" + "/".join(label.split("/")[1:]) 97 | else: 98 | return left_or_right 99 | 100 | @staticmethod 101 | def _create_pre_terminal_label(tag: str, default="X") -> str: 102 | idx = tag.find("/") 103 | if idx != -1: 104 | label = tag[idx + 1:].replace("/", "+") 105 | if default == "": 106 | return label + "+" 107 | else: 108 | return label 109 | else: 110 | return default 111 | 112 | @staticmethod 113 | def _create_unary_reduce_label(tag: str) -> str: 114 | idx = tag.find("/") 115 | if idx == -1: 116 | return "X|" 117 | return tag[idx + 1:].replace("/", "+") 118 | 119 | @staticmethod 120 | def _create_reduce_label(tag: str) -> str: 121 | idx = tag.find("/") 122 | if idx == -1: 123 | label = "X\\|" # to mark the second part as an extra node created via binarizaiton 124 | else: 125 | label = "X\\" + tag[idx + 
1:].replace("/", "+") 126 | return label 127 | 128 | @staticmethod 129 | def is_alternating(tags: [str]) -> bool: 130 | prev_state = True # true means reduce 131 | for tag in tags: 132 | if tag.startswith('r') or tag.startswith('l'): 133 | state = False 134 | else: 135 | state = True 136 | if state == prev_state: 137 | return False 138 | prev_state = state 139 | return True 140 | 141 | def logits_to_ids(self, logits: [], mask, max_depth, keep_per_depth, is_greedy=False) -> [int]: 142 | if is_greedy: 143 | searcher = GreedySearch( 144 | self.decode_moderator, 145 | initial_stack_depth=0, 146 | max_depth=max_depth, 147 | keep_per_depth=keep_per_depth, 148 | ) 149 | else: 150 | searcher = BeamSearch( 151 | self.decode_moderator, 152 | initial_stack_depth=0, 153 | max_depth=max_depth, 154 | keep_per_depth=keep_per_depth, 155 | ) 156 | 157 | last_t = None 158 | for t in range(logits.shape[0]): 159 | if mask is not None and not mask[t]: 160 | continue 161 | if last_t is not None: 162 | searcher.advance( 163 | logits[last_t, :] + self.decode_moderator.internal_tags_only 164 | ) 165 | searcher.advance(logits[t, :] + self.decode_moderator.leaf_tags_only) 166 | last_t = t 167 | 168 | score, best_tag_ids = searcher.get_path() 169 | return best_tag_ids 170 | 171 | 172 | class BottomUpTetratagger(TetraTagger): 173 | """ Kitaev and Klein (2020)""" 174 | 175 | def expand_tags(self, tags: [str]) -> [str]: 176 | new_tags = [] 177 | for tag in tags: 178 | if tag.startswith('r'): 179 | new_tags.append("l" + tag[1:]) 180 | new_tags.append("R") 181 | else: 182 | new_tags.append(tag) 183 | return new_tags 184 | 185 | def preprocess(self, tree: Tree) -> PTree: 186 | ptree: PTree = super().preprocess(tree) 187 | root_label = ptree.label() 188 | tree_rc = PTree(root_label, []) 189 | RightCornerTransformer.transform(tree_rc, ptree, ptree) 190 | return tree_rc 191 | 192 | def tree_to_tags(self, root: PTree) -> ([], int): 193 | tags = [] 194 | lc = LeftCornerTransformer.extract_left_corner_no_eps(root) 195 | tags.append(self.create_shift_tag(lc.label(), "l")) 196 | 197 | logging.debug("SHIFT {}".format(lc.label())) 198 | stack = [lc] 199 | max_stack_len = 1 200 | 201 | while len(stack) > 0: 202 | max_stack_len = max(max_stack_len, len(stack)) 203 | node = stack[-1] 204 | if find_node_type( 205 | node) == NodeType.NT: # special case: merge the reduce and last shift 206 | last_tag = tags.pop() 207 | last_two_tag = tags.pop() 208 | if not last_tag.startswith('R') or not last_two_tag.startswith('l'): 209 | raise ValueError( 210 | "When reaching NT the right PT should already be shifted") 211 | # merged shift 212 | tags.append(self.create_merge_shift_tag(last_two_tag, "r")) 213 | 214 | if node.left_sibling() is None and node.right_sibling() is not None: 215 | lc = LeftCornerTransformer.extract_left_corner_no_eps(node.right_sibling()) 216 | stack.append(lc) 217 | logging.debug("<-- \t SHIFT {}".format(lc.label())) 218 | # normal shift 219 | tags.append(self.create_shift_tag(lc.label(), "l")) 220 | 221 | elif len(stack) >= 2 and ( 222 | node.right_sibling() == stack[-2] or node.left_sibling() == stack[-2]): 223 | prev_node = stack[-2] 224 | logging.debug("==> \t REDUCE[ {0} {1} --> {2} ]".format( 225 | *(prev_node.label(), node.label(), node.parent().label()))) 226 | 227 | tags.append( 228 | self._create_bi_reduce_tag(prev_node.label(), "R")) # normal reduce 229 | stack.pop() 230 | stack.pop() 231 | stack.append(node.parent()) 232 | 233 | elif find_node_type(node) != NodeType.NT_NT: 234 | if stack[0].parent() is None and 
len(stack) == 1: 235 | stack.pop() 236 | continue 237 | logging.debug( 238 | "<== \t REDUCE[ {0} --> {1} ]".format( 239 | *(node.label(), node.parent().label()))) 240 | tags.append(self._create_unary_reduce_tag( 241 | node.parent().label(), "L")) # unary reduce 242 | stack.pop() 243 | stack.append(node.parent()) 244 | else: 245 | logging.error("ERROR: Undefined stack state") 246 | return 247 | logging.debug("=" * 20) 248 | return tags, max_stack_len 249 | 250 | def _unary_reduce(self, node, last_node, tag): 251 | label = self._create_unary_reduce_label(tag) 252 | node.insert(0, PTree(label + "\\" + label, ["EPS"])) 253 | node.insert(1, last_node) 254 | return node 255 | 256 | def _reduce(self, node, last_node, last_2_node, tag): 257 | label = self._create_reduce_label(tag) 258 | last_2_node.set_label(label) 259 | node.insert(0, last_2_node) 260 | node.insert(1, last_node) 261 | return node 262 | 263 | def postprocess(self, transformed_tree: PTree) -> Tree: 264 | tree = PTree("X", ["", ""]) 265 | tree = RightCornerTransformer.rev_transform(tree, transformed_tree) 266 | return super().postprocess(tree) 267 | 268 | def tags_to_tree(self, tags: [str], input_seq: [str]) -> PTree: 269 | created_node_stack = [] 270 | node = None 271 | expanded_tags = self.expand_tags(tags) 272 | if len(expanded_tags) == 1: # base case 273 | assert expanded_tags[0].startswith('l') 274 | prefix = self._create_pre_terminal_label(expanded_tags[0], "") 275 | return PTree(prefix + input_seq[0][1], [input_seq[0][0]]) 276 | for tag in expanded_tags: 277 | if tag.startswith('l'): # shift 278 | prefix = self._create_pre_terminal_label(tag, "") 279 | created_node_stack.append(PTree(prefix + input_seq[0][1], [input_seq[0][0]])) 280 | input_seq.pop(0) 281 | else: 282 | node = PTree("X", []) 283 | if tag.startswith('R'): # normal reduce 284 | last_node = created_node_stack.pop() 285 | last_2_node = created_node_stack.pop() 286 | created_node_stack.append(self._reduce(node, last_node, last_2_node, tag)) 287 | elif tag.startswith('L'): # unary reduce 288 | created_node_stack.append( 289 | self._unary_reduce(node, created_node_stack.pop(), tag)) 290 | if len(input_seq) != 0: 291 | raise ValueError("All the input sequence is not used") 292 | return node 293 | 294 | 295 | class TopDownTetratagger(TetraTagger): 296 | 297 | @staticmethod 298 | def create_merge_shift_tag(label: str, left_or_right: str) -> str: 299 | if label.find("+") != -1: 300 | return left_or_right + "/" + "/".join(label.split("+")[:-1]) 301 | else: 302 | return left_or_right 303 | 304 | def expand_tags(self, tags: [str]) -> [str]: 305 | new_tags = [] 306 | for tag in tags: 307 | if tag.startswith('l'): 308 | new_tags.append("L") 309 | new_tags.append("r" + tag[1:]) 310 | else: 311 | new_tags.append(tag) 312 | return new_tags 313 | 314 | def preprocess(self, tree: Tree) -> PTree: 315 | ptree = super(TopDownTetratagger, self).preprocess(tree) 316 | root_label = ptree.label() 317 | tree_lc = PTree(root_label, []) 318 | LeftCornerTransformer.transform(tree_lc, ptree, ptree) 319 | return tree_lc 320 | 321 | def tree_to_tags(self, root: PTree) -> ([str], int): 322 | """ convert left-corner transformed tree to shifts and reduces """ 323 | stack: [PTree] = [root] 324 | max_stack_len = 1 325 | logging.debug("SHIFT {}".format(root.label())) 326 | tags = [] 327 | while len(stack) > 0: 328 | max_stack_len = max(max_stack_len, len(stack)) 329 | node = stack[-1] 330 | if find_node_type(node) == NodeType.NT or find_node_type(node) == NodeType.NT_NT: 331 | stack.pop() 332 | 
logging.debug("REDUCE[ {0} --> {1} {2}]".format( 333 | *(node.label(), node[0].label(), node[1].label()))) 334 | if find_node_type(node) == NodeType.NT: 335 | if find_node_type(node[0]) != NodeType.PT: 336 | raise ValueError("Left child of NT should be a PT") 337 | stack.append(node[1]) 338 | tags.append( 339 | self.create_merge_shift_tag(node[0].label(), "l")) # merged shift 340 | else: 341 | if not is_node_epsilon(node[1]): 342 | stack.append(node[1]) 343 | tags.append(self._create_bi_reduce_tag(node[1].label(), "L")) 344 | # normal reduce 345 | else: 346 | tags.append(self._create_unary_reduce_tag(node[1].label(), "R")) 347 | # unary reduce 348 | stack.append(node[0]) 349 | 350 | elif find_node_type(node) == NodeType.PT: 351 | tags.append(self.create_shift_tag(node.label(), "r")) # normal shift 352 | logging.debug("-->\tSHIFT[ {0} ]".format(node.label())) 353 | stack.pop() 354 | return tags, max_stack_len 355 | 356 | def postprocess(self, transformed_tree: PTree) -> Tree: 357 | ptree = PTree("X", ["", ""]) 358 | ptree = LeftCornerTransformer.rev_transform(ptree, transformed_tree) 359 | return super().postprocess(ptree) 360 | 361 | def tags_to_tree(self, tags: [str], input_seq: [str]) -> PTree: 362 | expanded_tags = self.expand_tags(tags) 363 | root = PTree("X", []) 364 | created_node_stack = [root] 365 | if len(expanded_tags) == 1: # base case 366 | assert expanded_tags[0].startswith('r') 367 | prefix = self._create_pre_terminal_label(expanded_tags[0], "") 368 | return PTree(prefix + input_seq[0][1], [input_seq[0][0]]) 369 | for tag in expanded_tags: 370 | if tag.startswith('r'): # shift 371 | node = created_node_stack.pop() 372 | prefix = self._create_pre_terminal_label(tag, "") 373 | node.set_label(prefix + input_seq[0][1]) 374 | node.insert(0, input_seq[0][0]) 375 | input_seq.pop(0) 376 | elif tag.startswith('R') or tag.startswith('L'): 377 | parent = created_node_stack.pop() 378 | if tag.startswith('L'): # normal reduce 379 | label = self._create_reduce_label(tag) 380 | r_node = PTree(label, []) 381 | created_node_stack.append(r_node) 382 | else: 383 | label = self._create_unary_reduce_label(tag) 384 | r_node = PTree(label + "\\" + label, ["EPS"]) 385 | 386 | l_node_label = self._create_reduce_label(tag) 387 | l_node = PTree(l_node_label, []) 388 | created_node_stack.append(l_node) 389 | parent.insert(0, l_node) 390 | parent.insert(1, r_node) 391 | else: 392 | raise ValueError("Invalid tag type") 393 | if len(input_seq) != 0: 394 | raise ValueError("All the input sequence is not used") 395 | return root 396 | -------------------------------------------------------------------------------- /tagging/transform.py: -------------------------------------------------------------------------------- 1 | from nltk import ParentedTree as PTree 2 | from tagging.tree_tools import find_node_type, is_node_epsilon, NodeType 3 | 4 | 5 | class Transformer: 6 | @classmethod 7 | def expand_nt(cls, node: PTree, ref_node: PTree) -> (PTree, PTree, PTree, PTree): 8 | raise NotImplementedError("expand non-terminal is not implemented") 9 | 10 | @classmethod 11 | def expand_nt_nt(cls, node: PTree, ref_node1: PTree, ref_node2: PTree) -> ( 12 | PTree, PTree, PTree, PTree): 13 | raise NotImplementedError("expand paired non-terimnal is not implemented") 14 | 15 | @classmethod 16 | def extract_right_corner(cls, node: PTree) -> PTree: 17 | while type(node[0]) != str: 18 | if len(node) > 1: 19 | node = node[1] 20 | else: # unary rules 21 | node = node[0] 22 | return node 23 | 24 | @classmethod 25 | def 
extract_left_corner(cls, node: PTree) -> PTree: 26 | while len(node) > 1: 27 | node = node[0] 28 | return node 29 | 30 | @classmethod 31 | def transform(cls, node: PTree, ref_node1: PTree, ref_node2: PTree) -> None: 32 | if node is None: 33 | return 34 | type = find_node_type(node) 35 | if type == NodeType.NT: 36 | left_ref1, left_ref2, right_ref1, right_ref2, is_base_case = cls.expand_nt(node, 37 | ref_node1) 38 | elif type == NodeType.NT_NT: 39 | is_base_case = False 40 | left_ref1, left_ref2, right_ref1, right_ref2 = cls.expand_nt_nt( 41 | node, ref_node1, 42 | ref_node2) 43 | else: 44 | return 45 | if is_base_case: 46 | return 47 | cls.transform(node[0], left_ref1, left_ref2) 48 | cls.transform(node[1], right_ref1, right_ref2) 49 | 50 | 51 | class LeftCornerTransformer(Transformer): 52 | 53 | @classmethod 54 | def extract_left_corner_no_eps(cls, node: PTree) -> PTree: 55 | while len(node) > 1: 56 | if not is_node_epsilon(node[0]): 57 | node = node[0] 58 | else: 59 | node = node[1] 60 | return node 61 | 62 | @classmethod 63 | def expand_nt(cls, node: PTree, ref_node: PTree) -> (PTree, PTree, PTree, PTree, bool): 64 | leftcorner_node = cls.extract_left_corner(ref_node) 65 | if leftcorner_node == ref_node: # this only happens if the tree only consists of one terminal rule 66 | node.insert(0, ref_node.leaves()[0]) 67 | return None, None, None, None, True 68 | new_right_node = PTree(node.label() + "\\" + leftcorner_node.label(), []) 69 | new_left_node = PTree(leftcorner_node.label(), leftcorner_node.leaves()) 70 | 71 | node.insert(0, new_left_node) 72 | node.insert(1, new_right_node) 73 | return leftcorner_node, leftcorner_node, ref_node, leftcorner_node, False 74 | 75 | @classmethod 76 | def expand_nt_nt(cls, node: PTree, ref_node1: PTree, ref_node2: PTree) -> ( 77 | PTree, PTree, PTree, PTree): 78 | parent_node = ref_node2.parent() 79 | if ref_node1 == parent_node: 80 | new_right_node = PTree(node.label().split("\\")[0] + "\\" + parent_node.label(), 81 | ["EPS"]) 82 | else: 83 | new_right_node = PTree(node.label().split("\\")[0] + "\\" + parent_node.label(), 84 | []) 85 | 86 | sibling_node = ref_node2.right_sibling() 87 | if len(sibling_node) == 1: 88 | new_left_node = PTree(sibling_node.label(), sibling_node.leaves()) 89 | else: 90 | new_left_node = PTree(sibling_node.label(), []) 91 | 92 | node.insert(0, new_left_node) 93 | node.insert(1, new_right_node) 94 | return sibling_node, sibling_node, ref_node1, parent_node 95 | 96 | @classmethod 97 | def rev_transform(cls, node: PTree, ref_node: PTree, pick_up_labels=True) -> PTree: 98 | if find_node_type(ref_node) == NodeType.NT_NT and pick_up_labels: 99 | node.set_label(ref_node[1].label().split("\\")[1]) 100 | if len(ref_node) == 1 and find_node_type(ref_node) == NodeType.PT: # base case 101 | return ref_node 102 | if find_node_type(ref_node[0]) == NodeType.PT and not is_node_epsilon(ref_node[1]): 103 | # X -> word X 104 | pt_node = PTree(ref_node[0].label(), ref_node[0].leaves()) 105 | if node[0] == "": 106 | node[0] = pt_node 107 | else: 108 | node[1] = pt_node 109 | par_node = PTree("X", [node, ""]) 110 | node = par_node 111 | return cls.rev_transform(node, ref_node[1], pick_up_labels) 112 | elif find_node_type(ref_node[0]) != NodeType.PT and is_node_epsilon(ref_node[1]): 113 | # X -> X X-X 114 | if node[0] == "": 115 | raise ValueError( 116 | "When reaching the root the left branch should already exist") 117 | node[1] = cls.rev_transform(PTree("X", ["", ""]), ref_node[0], pick_up_labels) 118 | return node 119 | elif 
find_node_type(ref_node[0]) == NodeType.PT and is_node_epsilon(ref_node[1]): 120 | # X -> word X-X 121 | if node[0] == "": 122 | raise ValueError( 123 | "When reaching the end of the chain the left branch should already exist") 124 | node[1] = PTree(ref_node[0].label(), ref_node[0].leaves()) 125 | return node 126 | elif find_node_type(ref_node[0]) != NodeType.PT and find_node_type( 127 | ref_node[1]) != NodeType.PT: 128 | # X -> X X 129 | node[1] = cls.rev_transform(PTree("X", ["", ""]), ref_node[0], pick_up_labels) 130 | par_node = PTree("X", [node, ""]) 131 | return cls.rev_transform(par_node, ref_node[1], pick_up_labels) 132 | 133 | 134 | class RightCornerTransformer(Transformer): 135 | @classmethod 136 | def expand_nt(cls, node: PTree, ref_node: PTree) -> (PTree, PTree, PTree, PTree, bool): 137 | rightcorner_node = cls.extract_right_corner(ref_node) 138 | if rightcorner_node == ref_node: 139 | node.insert(0, ref_node.leaves()[0]) 140 | return None, None, None, None, True 141 | new_left_node = PTree(node.label() + "\\" + rightcorner_node.label(), []) 142 | new_right_node = PTree(rightcorner_node.label(), rightcorner_node.leaves()) 143 | 144 | node.insert(0, new_left_node) 145 | node.insert(1, new_right_node) 146 | return ref_node, rightcorner_node, rightcorner_node, rightcorner_node, False 147 | 148 | @classmethod 149 | def expand_nt_nt(cls, node: PTree, ref_node1: PTree, ref_node2: PTree) -> ( 150 | PTree, PTree, PTree, PTree): 151 | parent_node = ref_node2.parent() 152 | if ref_node1 == parent_node: 153 | new_left_node = PTree(node.label().split("\\")[0] + "\\" + parent_node.label(), 154 | ["EPS"]) 155 | else: 156 | new_left_node = PTree(node.label().split("\\")[0] + "\\" + parent_node.label(), []) 157 | 158 | sibling_node = ref_node2.left_sibling() 159 | if len(sibling_node) == 1: 160 | new_right_node = PTree(sibling_node.label(), sibling_node.leaves()) 161 | else: 162 | new_right_node = PTree(sibling_node.label(), []) 163 | 164 | node.insert(0, new_left_node) 165 | node.insert(1, new_right_node) 166 | return ref_node1, parent_node, sibling_node, sibling_node 167 | 168 | @classmethod 169 | def rev_transform(cls, node: PTree, ref_node: PTree, pick_up_labels=True) -> PTree: 170 | if find_node_type(ref_node) == NodeType.NT_NT and pick_up_labels: 171 | node.set_label(ref_node[0].label().split("\\")[1]) 172 | if len(ref_node) == 1 and find_node_type(ref_node) == NodeType.PT: # base case 173 | return ref_node 174 | if find_node_type(ref_node[1]) == NodeType.PT and not is_node_epsilon(ref_node[0]): 175 | # X -> X word 176 | pt_node = PTree(ref_node[1].label(), ref_node[1].leaves()) 177 | if node[1] == "": 178 | node[1] = pt_node 179 | else: 180 | node[0] = pt_node 181 | par_node = PTree("X", ["", node]) 182 | node = par_node 183 | return cls.rev_transform(node, ref_node[0], pick_up_labels) 184 | elif find_node_type(ref_node[1]) != NodeType.PT and is_node_epsilon(ref_node[0]): 185 | # X -> X-X X 186 | if node[1] == "": 187 | raise ValueError( 188 | "When reaching the root the right branch should already exist") 189 | node[0] = cls.rev_transform(PTree("X", ["", ""]), ref_node[1], pick_up_labels) 190 | return node 191 | elif find_node_type(ref_node[1]) == NodeType.PT and is_node_epsilon(ref_node[0]): 192 | # X -> X-X word 193 | if node[1] == "": 194 | raise ValueError( 195 | "When reaching the end of the chain the right branch should already exist") 196 | node[0] = PTree(ref_node[1].label(), ref_node[1].leaves()) 197 | return node 198 | elif find_node_type(ref_node[1]) != NodeType.PT and 
find_node_type( 199 | ref_node[0]) != NodeType.PT: 200 | # X -> X X 201 | node[0] = cls.rev_transform(PTree("X", ["", ""]), ref_node[1], pick_up_labels) 202 | par_node = PTree("X", ["", node]) 203 | return cls.rev_transform(par_node, ref_node[0], pick_up_labels) 204 | -------------------------------------------------------------------------------- /tagging/tree_tools.py: -------------------------------------------------------------------------------- 1 | import random 2 | import string 3 | from enum import Enum 4 | 5 | import numpy as np 6 | from nltk import ParentedTree 7 | from nltk import Tree 8 | 9 | from const import DUMMY_LABEL 10 | 11 | 12 | class NodeType(Enum): 13 | NT = 0 14 | NT_NT = 1 15 | PT = 2 16 | 17 | 18 | def find_node_type(node: ParentedTree) -> NodeType: 19 | if len(node) == 1: 20 | return NodeType.PT 21 | elif node.label().find("\\") != -1: 22 | return NodeType.NT_NT 23 | else: 24 | return NodeType.NT 25 | 26 | 27 | def is_node_epsilon(node: ParentedTree) -> bool: 28 | node_leaves = node.leaves() 29 | if len(node_leaves) == 1 and node_leaves[0] == "EPS": 30 | return True 31 | return False 32 | 33 | 34 | def is_topo_equal(first: ParentedTree, second: ParentedTree) -> bool: 35 | if len(first) == 1 and len(second) != 1: 36 | return False 37 | if len(first) != 1 and len(second) == str: 38 | return False 39 | if len(first) == 1 and len(second) == 1: 40 | return True 41 | return is_topo_equal(first[0], second[0]) and is_topo_equal(first[1], second[1]) 42 | 43 | 44 | def random_tree(node: ParentedTree, depth=0, p=.75, cutoff=7) -> None: 45 | """ sample a random tree 46 | @param input_str: list of sampled terminals 47 | """ 48 | if np.random.binomial(1, p) == 1 and depth < cutoff: 49 | # add the left child tree 50 | left_label = "X/" + str(depth) 51 | left = ParentedTree(left_label, []) 52 | node.insert(0, left) 53 | random_tree(left, depth=depth + 1, p=p, cutoff=cutoff) 54 | else: 55 | label = "X/" + str(depth) 56 | left = ParentedTree(label, [random.choice(string.ascii_letters)]) 57 | node.insert(0, left) 58 | 59 | if np.random.binomial(1, p) == 1 and depth < cutoff: 60 | # add the right child tree 61 | right_label = "X/" + str(depth) 62 | right = ParentedTree(right_label, []) 63 | node.insert(1, right) 64 | random_tree(right, depth=depth + 1, p=p, cutoff=cutoff) 65 | else: 66 | label = "X/" + str(depth) 67 | right = ParentedTree(label, [random.choice(string.ascii_letters)]) 68 | node.insert(1, right) 69 | 70 | 71 | def create_dummy_tree(leaves): 72 | dummy_tree = Tree("S", []) 73 | idx = 0 74 | for token, pos in leaves: 75 | dummy_tree.insert(idx, Tree(pos, [token])) 76 | idx += 1 77 | 78 | return dummy_tree 79 | 80 | 81 | def remove_plus_from_tree(tree): 82 | if type(tree) == str: 83 | return 84 | label = tree.label() 85 | new_label = label.replace("+", "@") 86 | tree.set_label(new_label) 87 | for child in tree: 88 | remove_plus_from_tree(child) 89 | 90 | 91 | def add_plus_to_tree(tree): 92 | if type(tree) == str: 93 | return 94 | label = tree.label() 95 | new_label = label.replace("@", "+") 96 | tree.set_label(new_label) 97 | for child in tree: 98 | add_plus_to_tree(child) 99 | 100 | 101 | def process_label(label, node_type): 102 | # label format is NT^^^HEAD_IDX 103 | suffix = label[label.rindex("^^^"):] 104 | head_idx = int(label[label.rindex("^^^") + 3:]) 105 | return node_type + suffix, head_idx 106 | 107 | 108 | def extract_head_idx(node): 109 | label = node.label() 110 | return int(label[label.rindex("^^^") + 3:]) 111 | 112 | 113 | def attach_to_tree(ref_node, 
parent_node, stack): 114 | if len(ref_node) == 1: 115 | new_node = Tree(ref_node.label(), [ref_node[0]]) 116 | parent_node.insert(len(parent_node), new_node) 117 | return stack 118 | 119 | new_node = Tree(ref_node.label(), []) 120 | stack.append((ref_node, new_node)) 121 | parent_node.insert(len(parent_node), new_node) 122 | return stack 123 | 124 | 125 | def relabel_tree(root): 126 | stack = [root] 127 | while len(stack) != 0: 128 | cur_node = stack.pop() 129 | if type(cur_node) == str: 130 | continue 131 | cur_node.set_label(f"X^^^{extract_head_idx(cur_node)}") 132 | for child in cur_node: 133 | stack.append(child) 134 | 135 | 136 | def debinarize_lex_tree(root, new_root): 137 | stack = [(root, new_root)] # (node in binarized tree, node in debinarized tree) 138 | while len(stack) != 0: 139 | ref_node, new_node = stack.pop() 140 | head_idx = extract_head_idx(ref_node) 141 | 142 | # attach the left child 143 | stack = attach_to_tree(ref_node[0], new_node, stack) 144 | 145 | cur_node = ref_node[1] 146 | while cur_node.label().startswith(DUMMY_LABEL): 147 | right_idx = extract_head_idx(cur_node) 148 | head_idx += right_idx 149 | # attach the left child 150 | stack = attach_to_tree(cur_node[0], new_node, stack) 151 | cur_node = cur_node[1] 152 | 153 | # attach the right child 154 | stack = attach_to_tree(cur_node, new_node, stack) 155 | # update the label 156 | new_node.set_label(f"X^^^{head_idx}") 157 | 158 | 159 | def binarize_lex_tree(children, node, node_type): 160 | # node type is X if normal node, Y|X (DUMMY_LABEL) if a dummy node created because of binarization 161 | 162 | if len(children) == 1: 163 | node.insert(0, children[0]) 164 | return node 165 | new_label, head_idx = process_label(node.label(), node_type) 166 | node.set_label(new_label) 167 | if len(children) == 2: 168 | left_node = Tree(children[0].label(), []) 169 | left_node = binarize_lex_tree(children[0], left_node, "X") 170 | right_node = Tree(children[1].label(), []) 171 | right_node = binarize_lex_tree(children[1], right_node, "X") 172 | node.insert(0, left_node) 173 | node.insert(1, right_node) 174 | return node 175 | elif len(children) > 2: 176 | if head_idx > 1: 177 | node.set_label(f"{node_type}^^^1") 178 | if head_idx >= 1: 179 | head_idx -= 1 180 | left_node = Tree(children[0].label(), []) 181 | left_node = binarize_lex_tree(children[0], left_node, "X") 182 | right_node = Tree(f"{DUMMY_LABEL}^^^{head_idx}", []) 183 | right_node = binarize_lex_tree(children[1:], right_node, DUMMY_LABEL) 184 | node.insert(0, left_node) 185 | node.insert(1, right_node) 186 | return node 187 | else: 188 | raise ValueError("node has zero children!") 189 | 190 | 191 | def expand_unary(tree): 192 | if len(tree) == 1: 193 | label = tree.label().split("+") 194 | pos_label = label[1] 195 | tree.set_label(label[0]) 196 | tree[0] = Tree(pos_label, [tree[0]]) 197 | return 198 | for child in tree: 199 | expand_unary(child) 200 | --------------------------------------------------------------------------------
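As a quick sanity check on the transforms above, the right-corner transform in `tagging/transform.py` and its `rev_transform` are mutual inverses on randomly sampled binary trees. The sketch below is ours and simply mirrors `TestTransforms.test_rev_rc_transform` from `tagging/test.py`:

```python
from nltk import ParentedTree

from tagging.transform import RightCornerTransformer
from tagging.tree_tools import random_tree

t = ParentedTree("ROOT", [])
random_tree(t, depth=0, cutoff=5)             # sample a random binary tree

t_rc = ParentedTree("S", [])
RightCornerTransformer.transform(t_rc, t, t)  # right-corner transform

back = ParentedTree("X", ["", ""])
back = RightCornerTransformer.rev_transform(back, t_rc)
assert back == t                              # the inverse recovers the original tree
```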
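The binary-headed trees used by the hexatagger store the head position in each node label using the `NT^^^HEAD_IDX` format that `process_label` and `extract_head_idx` in `tagging/tree_tools.py` parse. A tiny sketch of reading that convention; the example tree, and the reading of the index as the position of the head child, are our gloss of the comment in `process_label`:

```python
from nltk import Tree

from tagging.tree_tools import extract_head_idx

# Hypothetical BHT-style node: an "NP" whose head index is 1 (its second child).
node = Tree("NP^^^1", [Tree("DT^^^0", ["the"]), Tree("NN^^^0", ["dog"])])
print(extract_head_idx(node))  # -> 1
```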