├── .gitignore
├── LICENSE
├── README.md
├── archive_bert.py
├── concat_treebanks.py
├── config
│   ├── archive
│   │   ├── bert-base-multilingual-cased
│   │   │   ├── bert_config.json
│   │   │   └── vocab.txt
│   │   └── bert-large-cased
│   │       ├── bert_config.json
│   │       └── vocab.txt
│   ├── create_vocab.json
│   ├── create_vocab_sigmorphon.json
│   ├── sigmorphon
│   │   └── multilingual
│   │       └── udify_bert_sigmorphon_multilingual.json
│   ├── ud
│   │   ├── en
│   │   │   └── udify_bert_finetune_en_ewt.json
│   │   └── multilingual
│   │       └── udify_bert_finetune_multilingual.json
│   └── udify_base.json
├── create_vocabs.py
├── data
│   └── .gitkeep
├── docs
│   ├── udify-architecture.pdf
│   └── udify-architecture.png
├── logs
│   └── .gitkeep
├── predict.py
├── requirements.txt
├── scripts
│   ├── concat_ud_data.sh
│   ├── conll18_ud_eval.py
│   ├── download_sigmorphon_data.sh
│   ├── download_ud_data.sh
│   └── evaluate_2019_task2.py
├── train.py
└── udify
    ├── __init__.py
    ├── dataset_readers
    │   ├── __init__.py
    │   ├── conll18_ud_eval.py
    │   ├── evaluate_2019_task2.py
    │   ├── lemma_edit.py
    │   ├── parser.py
    │   ├── sigmorphon_2019_task_2.py
    │   └── universal_dependencies.py
    ├── models
    │   ├── __init__.py
    │   ├── dependency_decoder.py
    │   ├── tag_decoder.py
    │   └── udify_model.py
    ├── modules
    │   ├── __init__.py
    │   ├── bert_pretrained.py
    │   ├── residual_rnn.py
    │   ├── scalar_mix.py
    │   ├── text_field_embedder.py
    │   └── token_characters_encoder.py
    ├── optimizers
    │   ├── __init__.py
    │   └── ulmfit_sqrt.py
    ├── predictors
    │   ├── __init__.py
    │   ├── predictor.py
    │   └── text_predictor.py
    └── util.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | 107 | # Project specific 108 | /.idea/ 109 | /data/ 110 | /logs/ 111 | /tmp/ 112 | notebooks/bertviz/ 113 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Dan Kondratyuk 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UDify 2 | 3 | [![MIT License](https://img.shields.io/badge/License-MIT-green.svg)](LICENSE) 4 | 5 | UDify is a single model that parses Universal Dependencies (UPOS, UFeats, Lemmas, Deps) jointly, accepting any of 75 6 | supported languages as input (trained on UD v2.3 with 124 treebanks). This repository accompanies the paper, 7 | "[75 Languages, 1 Model: Parsing Universal Dependencies Universally](https://arxiv.org/abs/1904.02099)," 8 | providing tools to train a multilingual model capable of parsing any Universal Dependencies treebank with high 9 | accuracy. 
This project also supports training and evaluating for the
10 | [SIGMORPHON 2019 Shared Task #2](https://sigmorphon.github.io/sharedtasks/2019/task2/), where it achieved 1st place in
11 | morphology tagging (the paper can be found [here](https://www.aclweb.org/anthology/W19-4203)).
12 | 
13 | Integration with spaCy is supported by [Camphr](https://github.com/PKSHATechnology-Research/camphr).
14 | 
15 | [![UDify Model Architecture](docs/udify-architecture.png)](https://arxiv.org/pdf/1904.02099.pdf)
16 | 
17 | The project is built using [AllenNLP](https://allennlp.org/) and [PyTorch](https://pytorch.org/).
18 | 
19 | ## Getting Started
20 | 
21 | Install the Python packages in `requirements.txt`. UDify depends on AllenNLP and PyTorch. On Windows, use
22 | [WSL](https://docs.microsoft.com/en-us/windows/wsl/install-win10). Optionally, install TensorFlow to enable
23 | TensorBoard for rich visualizations of model performance on each UD task.
24 | 
25 | ```bash
26 | pip install -r ./requirements.txt
27 | ```
28 | 
29 | Download the UD corpus by running the script
30 | 
31 | ```bash
32 | bash ./scripts/download_ud_data.sh
33 | ```
34 | 
35 | or alternatively download the data from [universaldependencies.org](https://universaldependencies.org/), extract it
36 | into `data/ud-treebanks-v2.3/`, and then run `scripts/concat_ud_data.sh` to generate the multilingual UD dataset.
37 | 
38 | ### Training the Model
39 | 
40 | Before training, make sure the dataset is downloaded and extracted into the `data` directory and the multilingual
41 | dataset is generated with `scripts/concat_ud_data.sh`. To train the multilingual model (fine-tune UD on BERT),
42 | run the command
43 | 
44 | ```bash
45 | python train.py --config config/ud/multilingual/udify_bert_finetune_multilingual.json --name multilingual
46 | ```
47 | 
48 | which will begin loading the dataset and model before training the network. The model metrics, vocab, and weights will
49 | be saved under `logs/multilingual`. Note that this process is highly memory intensive and requires 16+ GB of RAM and
50 | 12+ GB of GPU memory (requirements are halved if fp16 is enabled in AllenNLP, but this [requires custom changes to the library](https://github.com/allenai/allennlp/issues/2149)).
51 | Training may take 20 or more days to complete all 80 epochs, depending on your GPU.
52 | 
53 | ### Training on Other Datasets
54 | 
55 | An example config is provided for fine-tuning on English EWT alone. Just run:
56 | 
57 | ```bash
58 | python train.py --config config/ud/en/udify_bert_finetune_en_ewt.json --name en_ewt --dataset_dir data/ud-treebanks-v2.3/
59 | ```
60 | 
61 | To train on your own dataset, copy `config/ud/multilingual/udify_bert_finetune_multilingual.json` and modify the
62 | following JSON parameters (an example override is sketched below this list):
63 | 
64 | - `train_data_path`, `validation_data_path`, and `test_data_path` to the paths of the dataset's conllu files. These
65 | can optionally be `null`.
66 | - `directory_path` to `data/vocab/<dataset_name>/vocabulary`.
67 | - `warmup_steps` and `start_step` to be equal to the number of steps in the first epoch. A good initial value is in the
68 | range `100-1000`. Alternatively, run the training script first to see the number of steps to the right of the progress
69 | bar.
70 | - If using just one treebank, optionally add `xpos` to the `tasks` list.
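For example, a minimal set of overrides might look like the following sketch. The treebank paths, vocab directory, and step counts here are illustrative placeholders, not files shipped with this repository:

```json
{
  "train_data_path": "data/ud-treebanks-v2.3/UD_Czech-PDT/cs_pdt-ud-train.conllu",
  "validation_data_path": "data/ud-treebanks-v2.3/UD_Czech-PDT/cs_pdt-ud-dev.conllu",
  "test_data_path": null,
  "vocabulary": {
    "directory_path": "data/vocab/cs_pdt/vocabulary"
  },
  "trainer": {
    "learning_rate_scheduler": {
      "warmup_steps": 500,
      "start_step": 500
    }
  }
}
```

Parameters you do not override this way fall through to `config/udify_base.json`, which `train.py` merges in last via `--base_config`.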
71 | 
72 | ### Viewing Model Performance
73 | 
74 | One can view how well the models are performing by running TensorBoard:
75 | 
76 | ```bash
77 | tensorboard --logdir logs
78 | ```
79 | 
80 | This should show the currently trained model as well as any other previously trained models. Each model is stored
81 | in a folder specified by the `--name` parameter along with a date stamp, e.g., `logs/multilingual/2019.07.03_11.08.51`.
82 | 
83 | ## Pretrained Models
84 | 
85 | [Pretrained models can be found here](http://hdl.handle.net/11234/1-3042). These can be used for predicting conllu
86 | annotations or for fine-tuning. The link contains the following:
87 | 
88 | - `udify-model.tar.gz` - The full UDify model archive that can be used for prediction with `predict.py`. Note that this
89 | model has been trained for extra epochs, and may differ slightly from the model shown in the original research paper.
90 | - `udify-bert.tar.gz` - The extracted BERT weights from the UDify model, in
91 | [huggingface transformers (pytorch-pretrained-bert)](https://github.com/huggingface/transformers) format.
92 | 
93 | ## Predicting Universal Dependencies from a Trained Model
94 | 
95 | To predict UD annotations, one can supply the path to the trained model archive, an input `conllu`-formatted file, and an output path:
96 | 
97 | ```bash
98 | python predict.py <archive> <input_file> <pred_file> [--eval_file results.json]
99 | ```
100 | 
101 | For instance, predicting the dev set of English EWT with the trained model saved under
102 | `logs/model.tar.gz` and UD treebanks at `data/ud-treebanks-v2.3` can be done with
103 | 
104 | ```bash
105 | python predict.py logs/model.tar.gz data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-dev.conllu logs/pred.conllu --eval_file logs/pred.json
106 | ```
107 | 
108 | which will save the output predictions to `logs/pred.conllu` and the evaluation to `logs/pred.json`.
109 | 
110 | ## Configuration Options
111 | 
112 | 1. One can specify the type of device to run on. For a single GPU, use the flag `--device 0`, or `--device -1` for CPU.
113 | 2. To skip waiting for the dataset to be fully loaded into memory, use the flag `--lazy`. Note that the dataset won't be shuffled.
114 | 3. Resume an existing training run with `--resume <log_dir>`.
115 | 4. Specify a config file with `--config <config_file>`.
116 | 
117 | ## SIGMORPHON 2019 Shared Task
118 | 
119 | A modification to the basic UDify model is available for parsing morphology in the
120 | [SIGMORPHON 2019 Shared Task #2](https://sigmorphon.github.io/sharedtasks/2019/task2/). The following paper describes
121 | the model in more detail: "[Cross-Lingual Lemmatization and Morphology Tagging with Two-Stage Multilingual BERT Fine-Tuning](https://www.aclweb.org/anthology/W19-4203)".
122 | 
123 | Training is similar to UD: just
124 | run `scripts/download_sigmorphon_data.sh` and then use the configuration file under `config/sigmorphon/multilingual`, e.g.,
125 | 
126 | ```bash
127 | python train.py --config config/sigmorphon/multilingual/udify_bert_sigmorphon_multilingual.json --name sigmorphon
128 | ```
129 | 
130 | ## FAQ
131 | 
132 | 1. When fine-tuning, my scores/metrics show poor performance.
133 | 
134 | It should take about 10 epochs to start seeing good scores from all the metrics, and 80 epochs to be competitive
135 | with UDPipe Future.
136 | 
137 | One caveat is that if you use a subset of treebanks for fine-tuning instead of all 124 UD v2.3 treebanks,
138 | *you must modify the configuration file*. Make sure to tune the learning rate scheduler to the number of
139 | training steps.
Copy the [`udify_bert_finetune_multilingual.json`](https://github.com/Hyperparticle/udify/blob/master/config/ud/multilingual/udify_bert_finetune_multilingual.json) 140 | config and modify the `"warmup_steps"` and `"start_step"` values. A good initial choice would be to set both to be 141 | equal to the number of training batches of one epoch (run the training script first to see the batches remaining, to 142 | the right of the progress bar). 143 | 144 | Have a question not listed here? [Open a GitHub Issue](https://github.com/Hyperparticle/udify/issues). 145 | 146 | ## Citing This Research 147 | 148 | If you use UDify for your research, please cite this work as: 149 | 150 | ```latex 151 | @inproceedings{kondratyuk-straka-2019-75, 152 | title = {75 Languages, 1 Model: Parsing Universal Dependencies Universally}, 153 | author = {Kondratyuk, Dan and Straka, Milan}, 154 | booktitle = {Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)}, 155 | year = {2019}, 156 | address = {Hong Kong, China}, 157 | publisher = {Association for Computational Linguistics}, 158 | url = {https://www.aclweb.org/anthology/D19-1279}, 159 | pages = {2779--2795} 160 | } 161 | ``` 162 | -------------------------------------------------------------------------------- /archive_bert.py: -------------------------------------------------------------------------------- 1 | """ 2 | Extracts a BERT archive from an existing model 3 | """ 4 | 5 | import logging 6 | import argparse 7 | 8 | from udify import util 9 | 10 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 11 | level=logging.INFO) 12 | logger = logging.getLogger(__name__) 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("archive_dir", type=str, help="The directory where model.tar.gz resides") 16 | parser.add_argument("--output_path", default=None, type=str, help="The path to output the file") 17 | 18 | args = parser.parse_args() 19 | 20 | bert_config = "config/archive/bert-base-multilingual-cased/bert_config.json" 21 | util.archive_bert_model(args.archive_dir, bert_config, args.output_path) 22 | -------------------------------------------------------------------------------- /concat_treebanks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Concatenates all treebanks together 3 | """ 4 | 5 | import os 6 | import shutil 7 | import logging 8 | import argparse 9 | 10 | from udify import util 11 | 12 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 13 | level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument("output_dir", type=str, help="The path to output the concatenated files") 19 | parser.add_argument("--dataset_dir", default="data/ud-treebanks-v2.3", type=str, 20 | help="The path containing all UD treebanks") 21 | parser.add_argument("--treebanks", default=[], type=str, nargs="+", 22 | help="Specify a list of treebanks to use; leave blank to default to all treebanks available") 23 | 24 | args = parser.parse_args() 25 | 26 | treebanks = util.get_ud_treebank_files(args.dataset_dir, args.treebanks) 27 | train, dev, test = list(zip(*[treebanks[k] for k in treebanks])) 28 | 29 | for treebank, name in zip([train, dev, test], ["train.conllu", "dev.conllu", "test.conllu"]): 30 | with open(os.path.join(args.output_dir, name), 'w') as 
write: 31 | for t in treebank: 32 | if not t: 33 | continue 34 | with open(t, 'r') as read: 35 | shutil.copyfileobj(read, write) 36 | -------------------------------------------------------------------------------- /config/archive/bert-base-multilingual-cased/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 768, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 3072, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 12, 11 | "num_hidden_layers": 12, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 119547 19 | } 20 | -------------------------------------------------------------------------------- /config/archive/bert-large-cased/bert_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "attention_probs_dropout_prob": 0.1, 3 | "directionality": "bidi", 4 | "hidden_act": "gelu", 5 | "hidden_dropout_prob": 0.1, 6 | "hidden_size": 1024, 7 | "initializer_range": 0.02, 8 | "intermediate_size": 4096, 9 | "max_position_embeddings": 512, 10 | "num_attention_heads": 16, 11 | "num_hidden_layers": 24, 12 | "pooler_fc_size": 768, 13 | "pooler_num_attention_heads": 12, 14 | "pooler_num_fc_layers": 3, 15 | "pooler_size_per_head": 128, 16 | "pooler_type": "first_token_transform", 17 | "type_vocab_size": 2, 18 | "vocab_size": 28996 19 | } 20 | -------------------------------------------------------------------------------- /config/create_vocab.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "udify_universal_dependencies", 4 | "token_indexers": { 5 | "tokens": { 6 | "type": "single_id", 7 | "lowercase_tokens": false 8 | }, 9 | "token_characters": { 10 | "type": "characters" 11 | } 12 | } 13 | }, 14 | "vocabulary": { 15 | "non_padded_namespaces": ["upos", "xpos", "feats", "lemmas", "*tags", "*labels"], 16 | // Add the special "@@UNKNOWN@@" token to handle dev/test set labels that don't appear in training 17 | "tokens_to_add": { 18 | "upos": ["@@UNKNOWN@@"], 19 | "xpos": ["@@UNKNOWN@@"], 20 | "feats": ["@@UNKNOWN@@"], 21 | "lemmas": ["@@UNKNOWN@@"], 22 | "head_tags": ["@@UNKNOWN@@"] 23 | } 24 | }, 25 | "train_file_paths": {}, 26 | "validation_file_paths": {} 27 | } -------------------------------------------------------------------------------- /config/create_vocab_sigmorphon.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "type": "udify_sigmorphon_2019_task_2", 4 | "token_indexers": { 5 | "tokens": { 6 | "type": "single_id", 7 | "lowercase_tokens": false 8 | }, 9 | "token_characters": { 10 | "type": "characters" 11 | } 12 | } 13 | }, 14 | "vocabulary": { 15 | "non_padded_namespaces": ["feats", "lemmas", "*tags", "*labels"], 16 | // Add the special "@@UNKNOWN@@" token to handle dev/test set labels that don't appear in training 17 | "tokens_to_add": { 18 | "feats": ["@@UNKNOWN@@"], 19 | "lemmas": ["@@UNKNOWN@@"] 20 | } 21 | } 22 | } -------------------------------------------------------------------------------- /config/sigmorphon/multilingual/udify_bert_sigmorphon_multilingual.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "lazy": false, 4 | "type": "udify_sigmorphon_2019_task_2", 5 | "token_indexers": { 6 | "tokens": { 7 | "type": "single_id", 8 | "lowercase_tokens": true 9 | }, 10 | "token_characters": { 11 | "type": "characters" 12 | }, 13 | "bert": { 14 | "type": "udify-bert-pretrained", 15 | "pretrained_model": "config/archive/bert-base-multilingual-cased/vocab.txt", 16 | "do_lowercase": false, 17 | "use_starting_offsets": true 18 | } 19 | } 20 | }, 21 | "train_data_path": "data/sigmorphon-2019/multilingual/train.conllu", 22 | "validation_data_path": "data/sigmorphon-2019/multilingual/dev.conllu", 23 | "test_data_path": "data/sigmorphon-2019/multilingual/test.conllu", 24 | "vocabulary": { 25 | "directory_path": "data/vocab/multilingual/sigmorphon/vocabulary" 26 | }, 27 | "model": { 28 | "word_dropout": 0.25, 29 | "mix_embedding": 12, 30 | "layer_dropout": 0.2, 31 | "tasks": ["feats", "lemmas"], 32 | "text_field_embedder": { 33 | "type": "udify_embedder", 34 | "dropout": 0.5, 35 | "allow_unmatched_keys": true, 36 | "embedder_to_indexer_map": { 37 | "bert": ["bert", "bert-offsets"] 38 | }, 39 | "token_embedders": { 40 | "bert": { 41 | "type": "udify-bert-pretrained", 42 | "pretrained_model": "bert-base-multilingual-cased", 43 | "requires_grad": true, 44 | "dropout": 0.25, 45 | "layer_dropout": 0.25, 46 | "combine_layers": "all" 47 | } 48 | } 49 | }, 50 | "post_encoder_embedder": { 51 | "type": "udify_embedder", 52 | "dropout": 0.5, 53 | "allow_unmatched_keys": true, 54 | "embedder_to_indexer_map": { 55 | "token_characters": ["token_characters"] 56 | }, 57 | "token_characters": { 58 | "type": "udify_character_encoding", 59 | "dropout": 0.5, 60 | "embedding": { 61 | "embedding_dim": 256 62 | }, 63 | "encoder": { 64 | "type": "lstm", 65 | "input_size": 256, 66 | "hidden_size": 384, 67 | "num_layers": 2, 68 | "dropout": 0.05, 69 | "bidirectional": true 70 | } 71 | } 72 | }, 73 | "encoder": { 74 | "type": "pass_through", 75 | "input_dim": 768 76 | }, 77 | "decoders": { 78 | "feats": { 79 | "encoder": { 80 | "type": "udify_residual_rnn", 81 | "input_size": 768, 82 | "hidden_size": 768, 83 | "num_layers": 1, 84 | "dropout": 0.5 85 | }, 86 | "adaptive": false, 87 | "features": [ 88 | "interrogativity", 89 | "deixis", 90 | "evidentiality", 91 | "switch-reference", 92 | "case", 93 | "number", 94 | "language-specific_features", 95 | "valency", 96 | "information_structure", 97 | "tense", 98 | "person", 99 | "part_of_speech", 100 | "possession", 101 | "politeness", 102 | "voice", 103 | "finiteness", 104 | "definiteness", 105 | "mood", 106 | "animacy", 107 | "aspect", 108 | "comparison", 109 | "aktionsart", 110 | "polarity", 111 | "gender", 112 | "argument_marking" 113 | ] 114 | }, 115 | "lemmas": { 116 | "encoder": { 117 | "type": "udify_residual_rnn", 118 | "input_size": 768, 119 | "hidden_size": 768, 120 | "num_layers": 2, 121 | "dropout": 0.5 122 | }, 123 | "adaptive": false 124 | } 125 | } 126 | }, 127 | "iterator": { 128 | "batch_size": 32, 129 | "maximum_samples_per_batch": ["num_tokens", 32 * 100] 130 | }, 131 | "trainer": { 132 | "num_epochs": 50, 133 | "patience": 12, 134 | "num_serialized_models_to_keep": 4, 135 | "should_log_learning_rate": true, 136 | "should_log_parameter_statistics": true, 137 | "summary_interval": 200, 138 | "optimizer": { 139 | "type": "bert_adam", 140 | "b1": 0.9, 141 | "b2": 0.99, 142 | "weight_decay": 0.01, 143 | "lr": 2e-3, 144 | "parameter_groups": [ 145 | 
[["^text_field_embedder.*.bert_model.embeddings", 146 | "^text_field_embedder.*.bert_model.encoder"], {}], 147 | [["^text_field_embedder.*._scalar_mix", 148 | "^text_field_embedder.*.pooler", 149 | "^post_encoder_embedder", 150 | "^scalar_mix", 151 | "^decoders", 152 | "^shared_encoder"], {}] 153 | ] 154 | }, 155 | "learning_rate_scheduler": { 156 | "type": "ulmfit_sqrt", 157 | "model_size": 1, 158 | "warmup_steps": 100, 159 | "start_step": 100, 160 | "factor": 5.0, 161 | "gradual_unfreezing": true, 162 | "discriminative_fine_tuning": true, 163 | "decay_factor": 0.02, 164 | "steepness": 0.50 165 | } 166 | }, 167 | "udify_replace": [ 168 | "dataset_reader.token_indexers", 169 | "model.text_field_embedder", 170 | "model.encoder", 171 | "model.decoders.xpos", 172 | "model.decoders.deps", 173 | "model.decoders.upos", 174 | "model.decoders.feats.encoder", 175 | "model.decoders.lemmas.encoder", 176 | "trainer.learning_rate_scheduler", 177 | "trainer.optimizer", 178 | "vocabulary.directory_path" 179 | ] 180 | } -------------------------------------------------------------------------------- /config/ud/en/udify_bert_finetune_en_ewt.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "lazy": false, 4 | "token_indexers": { 5 | "tokens": { 6 | "type": "single_id", 7 | "lowercase_tokens": true 8 | }, 9 | "bert": { 10 | "type": "udify-bert-pretrained", 11 | "pretrained_model": "config/archive/bert-base-multilingual-cased/vocab.txt", 12 | "do_lowercase": false, 13 | "use_starting_offsets": true 14 | } 15 | } 16 | }, 17 | "train_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-train.conllu", 18 | "validation_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-dev.conllu", 19 | "test_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-test.conllu", 20 | "vocabulary": { 21 | "directory_path": "data/vocab/en_ewt/vocabulary" 22 | }, 23 | "model": { 24 | "word_dropout": 0.2, 25 | "mix_embedding": 12, 26 | "layer_dropout": 0.1, 27 | "tasks": ["upos", "feats", "lemmas", "deps"], 28 | "text_field_embedder": { 29 | "type": "udify_embedder", 30 | "dropout": 0.5, 31 | "allow_unmatched_keys": true, 32 | "embedder_to_indexer_map": { 33 | "bert": ["bert", "bert-offsets"] 34 | }, 35 | "token_embedders": { 36 | "bert": { 37 | "type": "udify-bert-pretrained", 38 | "pretrained_model": "bert-base-multilingual-cased", 39 | "requires_grad": true, 40 | "dropout": 0.15, 41 | "layer_dropout": 0.1, 42 | "combine_layers": "all" 43 | } 44 | } 45 | }, 46 | "encoder": { 47 | "type": "pass_through", 48 | "input_dim": 768 49 | }, 50 | "decoders": { 51 | "upos": { 52 | "encoder": { 53 | "type": "pass_through", 54 | "input_dim": 768 55 | } 56 | }, 57 | "feats": { 58 | "encoder": { 59 | "type": "pass_through", 60 | "input_dim": 768 61 | }, 62 | "adaptive": true 63 | }, 64 | "lemmas": { 65 | "encoder": { 66 | "type": "pass_through", 67 | "input_dim": 768 68 | }, 69 | "adaptive": true 70 | }, 71 | "deps": { 72 | "tag_representation_dim": 256, 73 | "arc_representation_dim": 768, 74 | "encoder": { 75 | "type": "pass_through", 76 | "input_dim": 768 77 | } 78 | } 79 | } 80 | }, 81 | "iterator": { 82 | "batch_size": 32, 83 | "maximum_samples_per_batch": ["num_tokens", 32 * 100] 84 | }, 85 | "trainer": { 86 | "num_epochs": 80, 87 | "patience": 80, 88 | "num_serialized_models_to_keep": 1, 89 | "should_log_learning_rate": true, 90 | "summary_interval": 100, 91 | "optimizer": { 92 | "type": "bert_adam", 93 | "b1": 0.9, 94 | "b2": 0.99, 95 | 
"weight_decay": 0.01, 96 | "lr": 1e-3, 97 | "parameter_groups": [ 98 | [["^text_field_embedder.*.bert_model.embeddings", 99 | "^text_field_embedder.*.bert_model.encoder"], {}], 100 | [["^text_field_embedder.*._scalar_mix", 101 | "^text_field_embedder.*.pooler", 102 | "^scalar_mix", 103 | "^decoders", 104 | "^shared_encoder"], {}] 105 | ] 106 | }, 107 | "learning_rate_scheduler": { 108 | "type": "ulmfit_sqrt", 109 | "model_size": 1, 110 | "warmup_steps": 392, 111 | "start_step": 392, 112 | "factor": 5.0, 113 | "gradual_unfreezing": true, 114 | "discriminative_fine_tuning": true, 115 | "decay_factor": 0.04 116 | } 117 | }, 118 | "udify_replace": [ 119 | "dataset_reader.token_indexers", 120 | "model.text_field_embedder", 121 | "model.encoder", 122 | "model.decoders.xpos", 123 | "model.decoders.deps.encoder", 124 | "model.decoders.upos.encoder", 125 | "model.decoders.feats.encoder", 126 | "model.decoders.lemmas.encoder", 127 | "trainer.learning_rate_scheduler", 128 | "trainer.optimizer" 129 | ] 130 | } -------------------------------------------------------------------------------- /config/ud/multilingual/udify_bert_finetune_multilingual.json: -------------------------------------------------------------------------------- 1 | { 2 | "dataset_reader": { 3 | "lazy": false, 4 | "token_indexers": { 5 | "tokens": { 6 | "type": "single_id", 7 | "lowercase_tokens": true 8 | }, 9 | "bert": { 10 | "type": "udify-bert-pretrained", 11 | "pretrained_model": "config/archive/bert-base-multilingual-cased/vocab.txt", 12 | "do_lowercase": false, 13 | "use_starting_offsets": true 14 | } 15 | } 16 | }, 17 | "train_data_path": "data/ud/multilingual/train.conllu", 18 | "validation_data_path": "data/ud/multilingual/dev.conllu", 19 | "test_data_path": "data/ud/multilingual/test.conllu", 20 | "vocabulary": { 21 | "directory_path": "data/vocab/multilingual/vocabulary" 22 | }, 23 | "model": { 24 | "word_dropout": 0.2, 25 | "mix_embedding": 12, 26 | "layer_dropout": 0.1, 27 | "tasks": ["upos", "feats", "lemmas", "deps"], 28 | "text_field_embedder": { 29 | "type": "udify_embedder", 30 | "dropout": 0.5, 31 | "allow_unmatched_keys": true, 32 | "embedder_to_indexer_map": { 33 | "bert": ["bert", "bert-offsets"] 34 | }, 35 | "token_embedders": { 36 | "bert": { 37 | "type": "udify-bert-pretrained", 38 | "pretrained_model": "bert-base-multilingual-cased", 39 | "requires_grad": true, 40 | "dropout": 0.15, 41 | "layer_dropout": 0.1, 42 | "combine_layers": "all" 43 | } 44 | } 45 | }, 46 | "encoder": { 47 | "type": "pass_through", 48 | "input_dim": 768 49 | }, 50 | "decoders": { 51 | "upos": { 52 | "encoder": { 53 | "type": "pass_through", 54 | "input_dim": 768 55 | } 56 | }, 57 | "feats": { 58 | "encoder": { 59 | "type": "pass_through", 60 | "input_dim": 768 61 | }, 62 | "adaptive": true 63 | }, 64 | "lemmas": { 65 | "encoder": { 66 | "type": "pass_through", 67 | "input_dim": 768 68 | }, 69 | "adaptive": true 70 | }, 71 | "deps": { 72 | "tag_representation_dim": 256, 73 | "arc_representation_dim": 768, 74 | "encoder": { 75 | "type": "pass_through", 76 | "input_dim": 768 77 | } 78 | } 79 | } 80 | }, 81 | "iterator": { 82 | "batch_size": 32, 83 | "maximum_samples_per_batch": ["num_tokens", 32 * 100] 84 | }, 85 | "trainer": { 86 | "num_epochs": 80, 87 | "patience": 80, 88 | "num_serialized_models_to_keep": 10, 89 | "keep_serialized_model_every_num_seconds": 2 * 60 * 60, 90 | "model_save_interval": 1 * 60 * 60, 91 | "should_log_learning_rate": true, 92 | "should_log_parameter_statistics": true, 93 | "summary_interval": 2500, 94 | 
"optimizer": { 95 | "type": "bert_adam", 96 | "b1": 0.9, 97 | "b2": 0.99, 98 | "weight_decay": 0.01, 99 | "lr": 1e-3, 100 | "parameter_groups": [ 101 | [["^text_field_embedder.*.bert_model.embeddings", 102 | "^text_field_embedder.*.bert_model.encoder"], {}], 103 | [["^text_field_embedder.*._scalar_mix", 104 | "^text_field_embedder.*.pooler", 105 | "^scalar_mix", 106 | "^decoders", 107 | "^shared_encoder"], {}] 108 | ] 109 | }, 110 | "learning_rate_scheduler": { 111 | "type": "ulmfit_sqrt", 112 | "model_size": 1, 113 | "warmup_steps": 8000, 114 | "start_step": 21695, 115 | "factor": 5.0, 116 | "gradual_unfreezing": true, 117 | "discriminative_fine_tuning": true, 118 | "decay_factor": 0.04 119 | } 120 | }, 121 | "udify_replace": [ 122 | "dataset_reader.token_indexers", 123 | "model.text_field_embedder", 124 | "model.encoder", 125 | "model.decoders.xpos", 126 | "model.decoders.deps.encoder", 127 | "model.decoders.upos.encoder", 128 | "model.decoders.feats.encoder", 129 | "model.decoders.lemmas.encoder", 130 | "trainer.learning_rate_scheduler", 131 | "trainer.optimizer" 132 | ] 133 | } -------------------------------------------------------------------------------- /config/udify_base.json: -------------------------------------------------------------------------------- 1 | // The base configuration for UD parsing 2 | { 3 | "dataset_reader": { 4 | "type": "udify_universal_dependencies", 5 | "lazy": false, 6 | "token_indexers": { 7 | "tokens": { 8 | "type": "single_id", 9 | "lowercase_tokens": true 10 | }, 11 | "token_characters": { 12 | "type": "characters", 13 | "min_padding_length": 1 14 | } 15 | } 16 | }, 17 | "vocabulary": { 18 | "directory_path": "data/vocab/UD_English-EWT/vocabulary", 19 | "non_padded_namespaces": ["upos", "xpos", "feats", "lemmas", "*tags", "*labels"] 20 | }, 21 | "train_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-train.conllu", 22 | "validation_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-dev.conllu", 23 | "test_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-test.conllu", 24 | "evaluate_on_test": true, 25 | "model": { 26 | local global_dropout = 0.5, 27 | local label_smoothing = 0.03, 28 | 29 | local word_embedding_dim = 512, 30 | local char_embedding_dim = 256, 31 | local hidden_dim = 512, 32 | 33 | local shared_layers = 2, 34 | local task_layers = 1, 35 | 36 | "type": "udify_model", 37 | "tasks": ["upos", "xpos", "feats", "lemmas", "deps"], 38 | "dropout": global_dropout, 39 | "word_dropout": 0.2, 40 | "text_field_embedder": { 41 | "type": "udify_embedder", 42 | "sum_embeddings": ["tokens", "token_characters"], 43 | "dropout": global_dropout, 44 | "token_embedders": { 45 | "tokens": { 46 | "type": "embedding", 47 | "embedding_dim": word_embedding_dim 48 | }, 49 | "token_characters": { 50 | "type": "udify_character_encoding", 51 | "dropout": global_dropout, 52 | "embedding": { 53 | "embedding_dim": char_embedding_dim 54 | }, 55 | "encoder": { 56 | "type": "lstm", 57 | "input_size": char_embedding_dim, 58 | "hidden_size": char_embedding_dim, 59 | "num_layers": task_layers, 60 | "dropout": 0.0, 61 | "bidirectional": true 62 | } 63 | } 64 | } 65 | }, 66 | "encoder": { 67 | "type": "udify_residual_rnn", 68 | "input_size": hidden_dim, 69 | "hidden_size": hidden_dim, 70 | "num_layers": shared_layers, 71 | "dropout": global_dropout 72 | }, 73 | "decoders": { 74 | "upos": { 75 | "type": "udify_tag_decoder", 76 | "task": "upos", 77 | "encoder": { 78 | "type": "udify_residual_rnn", 79 | "input_size": hidden_dim, 80 | "hidden_size": 
hidden_dim, 81 | "num_layers": task_layers, 82 | "dropout": global_dropout 83 | }, 84 | "label_smoothing": label_smoothing, 85 | "dropout": global_dropout 86 | }, 87 | "xpos": { 88 | "type": "udify_tag_decoder", 89 | "task": "xpos", 90 | "encoder": { 91 | "type": "udify_residual_rnn", 92 | "input_size": hidden_dim, 93 | "hidden_size": hidden_dim, 94 | "num_layers": task_layers, 95 | "dropout": global_dropout 96 | }, 97 | "label_smoothing": label_smoothing, 98 | "dropout": global_dropout 99 | }, 100 | "feats": { 101 | "type": "udify_tag_decoder", 102 | "task": "feats", 103 | "encoder": { 104 | "type": "udify_residual_rnn", 105 | "input_size": hidden_dim, 106 | "hidden_size": hidden_dim, 107 | "num_layers": task_layers, 108 | "dropout": global_dropout 109 | }, 110 | "label_smoothing": label_smoothing, 111 | "dropout": global_dropout 112 | }, 113 | "lemmas": { 114 | "type": "udify_tag_decoder", 115 | "task": "lemmas", 116 | "encoder": { 117 | "type": "udify_residual_rnn", 118 | "input_size": hidden_dim, 119 | "hidden_size": hidden_dim, 120 | "num_layers": task_layers, 121 | "dropout": global_dropout 122 | }, 123 | "label_smoothing": label_smoothing, 124 | "dropout": global_dropout 125 | }, 126 | "deps": { 127 | "type": "udify_dependency_decoder", 128 | "pos_embed_dim": null, 129 | "tag_representation_dim": 128, 130 | "arc_representation_dim": 512, 131 | "dropout": global_dropout, 132 | "encoder": { 133 | "type": "udify_residual_rnn", 134 | "input_size": hidden_dim, 135 | "hidden_size": hidden_dim, 136 | "num_layers": task_layers, 137 | "dropout": global_dropout, 138 | "residual": false 139 | } 140 | } 141 | } 142 | }, 143 | "iterator": { 144 | "type": "bucket", 145 | "batch_size": 32, 146 | "sorting_keys": [["tokens", "num_tokens"]], 147 | "biggest_batch_first": true 148 | // "track_epoch": true 149 | }, 150 | "trainer": { 151 | "optimizer": { 152 | "type": "adam", 153 | "lr": 4e-3, 154 | "betas": [0.9, 0.99] 155 | }, 156 | "learning_rate_scheduler": { 157 | "type": "multi_step", 158 | "milestones": [0, 1, 15, 20, 30, 40], 159 | "gamma": 0.5 160 | }, 161 | "num_epochs": 50, 162 | "patience": 50, 163 | "validation_metric": "+.run/.sum", 164 | "should_log_learning_rate": false, 165 | "should_log_parameter_statistics": false, 166 | "summary_interval": 500, 167 | "num_serialized_models_to_keep": 1, 168 | "grad_norm": 5.0, 169 | "grad_clipping": 10.0, 170 | "cuda_device": 0 171 | }, 172 | // "random_seed": 13370, 173 | // "numpy_seed": 1337, 174 | // "pytorch_seed": 133 175 | } 176 | -------------------------------------------------------------------------------- /create_vocabs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Creates vocab files for all treebanks in the given directory 3 | """ 4 | 5 | import os 6 | import json 7 | import logging 8 | import argparse 9 | 10 | from allennlp.commands.make_vocab import make_vocab_from_params 11 | from allennlp.common import Params 12 | from allennlp.common.util import import_submodules 13 | 14 | from udify import util 15 | 16 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 17 | level=logging.INFO) 18 | logger = logging.getLogger(__name__) 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--dataset_dir", default="data/ud-treebanks-v2.3", type=str, 22 | help="The path containing all UD treebanks") 23 | parser.add_argument("--output_dir", default="data/vocab", type=str, help="The path to save all vocabulary files") 24 | 
parser.add_argument("--treebanks", default=[], type=str, nargs="+", 25 | help="Specify a list of treebanks to use; leave blank to default to all treebanks available") 26 | parser.add_argument("--params_file", default=None, type=str, help="The path to the vocab params") 27 | args = parser.parse_args() 28 | 29 | import_submodules("udify") 30 | 31 | params_file = util.VOCAB_CONFIG_PATH if not args.params_file else args.params_file 32 | 33 | treebanks = sorted(util.get_ud_treebank_files(args.dataset_dir, args.treebanks).items()) 34 | for treebank, (train_file, dev_file, test_file) in treebanks: 35 | logger.info(f"Creating vocabulary for treebank {treebank}") 36 | 37 | if not train_file: 38 | logger.info(f"No training data for {treebank}, skipping") 39 | continue 40 | 41 | overrides = json.dumps({ 42 | "train_data_path": train_file, 43 | "validation_data_path": dev_file, 44 | "test_data_path": test_file 45 | }) 46 | params = Params.from_file(params_file, overrides) 47 | output_file = os.path.join(args.output_dir, treebank) 48 | 49 | make_vocab_from_params(params, output_file) 50 | -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hyperparticle/udify/18d63ac1b2da5a1afea58f317ade79bc84910450/data/.gitkeep -------------------------------------------------------------------------------- /docs/udify-architecture.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hyperparticle/udify/18d63ac1b2da5a1afea58f317ade79bc84910450/docs/udify-architecture.pdf -------------------------------------------------------------------------------- /docs/udify-architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hyperparticle/udify/18d63ac1b2da5a1afea58f317ade79bc84910450/docs/udify-architecture.png -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Hyperparticle/udify/18d63ac1b2da5a1afea58f317ade79bc84910450/logs/.gitkeep -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | Predict conllu files given a trained model 3 | """ 4 | 5 | import os 6 | import shutil 7 | import logging 8 | import argparse 9 | import tarfile 10 | from pathlib import Path 11 | 12 | from allennlp.common import Params 13 | from allennlp.common.util import import_submodules 14 | from allennlp.models.archival import archive_model 15 | 16 | from udify import util 17 | 18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 19 | level=logging.INFO) 20 | logger = logging.getLogger(__name__) 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("archive", type=str, help="The archive file") 24 | parser.add_argument("input_file", type=str, help="The input file to predict") 25 | parser.add_argument("pred_file", type=str, help="The output prediction file") 26 | parser.add_argument("--eval_file", default=None, type=str, 27 | help="If set, evaluate the prediction and store it in the given file") 28 | parser.add_argument("--device", default=0, type=int, help="CUDA device number; set to -1 for CPU") 29 | 
parser.add_argument("--batch_size", default=1, type=int, help="The size of each prediction batch") 30 | parser.add_argument("--lazy", action="store_true", help="Lazy load dataset") 31 | parser.add_argument("--raw_text", action="store_true", help="Input raw sentences, one per line in the input file.") 32 | 33 | args = parser.parse_args() 34 | 35 | import_submodules("udify") 36 | 37 | archive_dir = Path(args.archive).resolve().parent 38 | 39 | if not os.path.isfile(archive_dir / "weights.th"): 40 | with tarfile.open(args.archive) as tar: 41 | tar.extractall(archive_dir) 42 | 43 | config_file = archive_dir / "config.json" 44 | 45 | overrides = {} 46 | if args.device is not None: 47 | overrides["trainer"] = {"cuda_device": args.device} 48 | if args.lazy: 49 | overrides["dataset_reader"] = {"lazy": args.lazy} 50 | configs = [Params(overrides), Params.from_file(config_file)] 51 | params = util.merge_configs(configs) 52 | 53 | predictor = "udify_predictor" if not args.raw_text else "udify_text_predictor" 54 | 55 | if not args.eval_file: 56 | util.predict_model_with_archive(predictor, params, archive_dir, args.input_file, args.pred_file, 57 | batch_size=args.batch_size) 58 | else: 59 | util.predict_and_evaluate_model_with_archive(predictor, params, archive_dir, args.input_file, 60 | args.pred_file, args.eval_file, batch_size=args.batch_size) 61 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | allennlp==0.9.0 2 | torch==1.4.0 3 | tensorflow 4 | pandas 5 | jupyter 6 | conllu 7 | -------------------------------------------------------------------------------- /scripts/concat_ud_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATASET_DIR="data/ud-treebanks-v2.3" 4 | 5 | echo "Generating multilingual dataset..." 6 | 7 | mkdir -p "data/ud" 8 | mkdir -p "data/ud/multilingual" 9 | 10 | python concat_treebanks.py data/ud/multilingual --dataset_dir ${DATASET_DIR} -------------------------------------------------------------------------------- /scripts/conll18_ud_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Runs the CoNLL 2018 Shared Task Evaluation 5 | """ 6 | 7 | from udify.dataset_readers.conll18_ud_eval import main 8 | 9 | if __name__ == "__main__": 10 | main() 11 | -------------------------------------------------------------------------------- /scripts/download_sigmorphon_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Download SIGMORPHON 2019 Shared Task data 4 | DATA="https://github.com/sigmorphon/2019.git" 5 | 6 | GIT_DIR="data/sigmorphon-2019" 7 | DATASET_DIR="${GIT_DIR}/task2" 8 | 9 | 10 | echo "Downloading SIGMORPHON data..." 11 | 12 | git clone ${DATA} ${GIT_DIR} 13 | 14 | 15 | echo "Generating multilingual dataset..." 
16 | 17 | mkdir -p "data/sigmorphon-2019/multilingual" 18 | 19 | python concat_treebanks.py data/sigmorphon-2019/multilingual --dataset_dir ${DATASET_DIR} 20 | -------------------------------------------------------------------------------- /scripts/download_ud_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Can download UD 2.3 or 2.4 4 | UD_2_3="https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-2895/ud-treebanks-v2.3.tgz?sequence=1&isAllowed=y" 5 | UD_2_4="https://lindat.mff.cuni.cz/repository/xmlui/bitstream/handle/11234/1-2988/ud-treebanks-v2.4.tgz?sequence=4&isAllowed=y" 6 | 7 | DATASET_DIR="data/ud-treebanks-v2.3" 8 | 9 | ARCHIVE="ud_data.tgz" 10 | 11 | 12 | echo "Downloading UD data..." 13 | 14 | curl ${UD_2_3} -o ${ARCHIVE} 15 | 16 | tar -xvzf ${ARCHIVE} -C ./data 17 | mv ${ARCHIVE} ./data 18 | 19 | 20 | echo "Generating multilingual dataset..." 21 | 22 | mkdir -p "data/ud" 23 | mkdir -p "data/ud/multilingual" 24 | 25 | python concat_treebanks.py data/ud/multilingual --dataset_dir ${DATASET_DIR} 26 | -------------------------------------------------------------------------------- /scripts/evaluate_2019_task2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Evaluation for the SIGMORPHON 2019 shared task, task 2. 3 | 4 | Computes various metrics on input. 5 | 6 | Author: Arya D. McCarthy 7 | Last update: 2018-12-21 8 | """ 9 | 10 | import argparse 11 | import contextlib 12 | import io 13 | import logging 14 | import sys 15 | import json 16 | 17 | import numpy as np 18 | 19 | from collections import Counter, namedtuple 20 | from pathlib import Path 21 | 22 | log = logging.getLogger(Path(__file__).stem) 23 | 24 | 25 | COLUMNS = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC".split() 26 | ConlluRow = namedtuple("ConlluRow", COLUMNS) 27 | SEPARATOR = ";" 28 | 29 | 30 | def distance(str1, str2): 31 | """Simple Levenshtein implementation.""" 32 | m = np.zeros([len(str2)+1, len(str1)+1]) 33 | for x in range(1, len(str2) + 1): 34 | m[x][0] = m[x-1][0] + 1 35 | for y in range(1, len(str1) + 1): 36 | m[0][y] = m[0][y-1] + 1 37 | for x in range(1, len(str2) + 1): 38 | for y in range(1, len(str1) + 1): 39 | if str1[y-1] == str2[x-1]: 40 | dg = 0 41 | else: 42 | dg = 1 43 | m[x][y] = min(m[x-1][y] + 1, m[x][y-1] + 1, m[x-1][y-1] + dg) 44 | return int(m[len(str2)][len(str1)]) 45 | 46 | 47 | def set_equal(str1, str2): 48 | set1 = set(str1.split(SEPARATOR)) 49 | set2 = set(str2.split(SEPARATOR)) 50 | return set1 == set2 51 | 52 | 53 | def manipulate_data(pairs): 54 | log.info("Lemma acc, Lemma Levenshtein, morph acc, morph F1") 55 | 56 | count = 0 57 | lemma_acc = 0 58 | lemma_lev = 0 59 | morph_acc = 0 60 | 61 | f1_precision_scores = 0 62 | f1_precision_counts = 0 63 | f1_recall_scores = 0 64 | f1_recall_counts = 0 65 | 66 | for r, o in pairs: 67 | log.debug("{}\t{}\t{}\t{}".format(r.LEMMA, o.LEMMA, r.FEATS, o.FEATS)) 68 | count += 1 69 | lemma_acc += (r.LEMMA == o.LEMMA) 70 | lemma_lev += distance(r.LEMMA, o.LEMMA) 71 | morph_acc += set_equal(r.FEATS, o.FEATS) 72 | 73 | r_feats = set(r.FEATS.split(SEPARATOR)) - {"_"} 74 | o_feats = set(o.FEATS.split(SEPARATOR)) - {"_"} 75 | 76 | union_size = len(r_feats & o_feats) 77 | reference_size = len(r_feats) 78 | output_size = len(o_feats) 79 | 80 | f1_precision_scores += union_size 81 | f1_recall_scores += union_size 82 | f1_precision_counts += output_size 83 | f1_recall_counts += 
reference_size 84 | 85 | f1_precision = f1_precision_scores / (f1_precision_counts or 1) 86 | f1_recall = f1_recall_scores / (f1_recall_counts or 1) 87 | f1 = 2 * (f1_precision * f1_recall) / (f1_precision + f1_recall + 1E-20) 88 | 89 | return (100 * lemma_acc / count, lemma_lev / count, 100 * morph_acc / count, 100 * f1) 90 | 91 | 92 | def parse_args(): 93 | """Parse command line arguments.""" 94 | parser = argparse.ArgumentParser(description=__doc__) 95 | parser.add_argument('-r', '--reference', 96 | type=Path, required=True) 97 | parser.add_argument('-o', '--output', 98 | type=Path, required=True) 99 | # Set the verbosity level for the logger. The `-v` option will set it to 100 | # the debug level, while the `-q` will set it to the warning level. 101 | # Otherwise use the info level. 102 | verbosity = parser.add_mutually_exclusive_group() 103 | verbosity.add_argument('-v', '--verbose', action='store_const', 104 | const=logging.DEBUG, default=logging.INFO) 105 | verbosity.add_argument('-q', '--quiet', dest='verbose', 106 | action='store_const', const=logging.WARNING) 107 | return parser.parse_args() 108 | 109 | 110 | def strip_comments(lines): 111 | for line in lines: 112 | if not line.startswith("#"): 113 | yield line 114 | 115 | 116 | def read_conllu(file: Path): 117 | with open(file) as f: 118 | yield from strip_comments(f) 119 | 120 | 121 | def input_pairs(reference, output): 122 | output = [o for o in output if len(o.split()) == 0 or '.' not in o.split()[0]] 123 | reference = [r for r in reference if len(r.split()) == 0 or '.' not in r.split()[0]] 124 | 125 | count = 0 126 | 127 | for r, o in zip(reference, output): 128 | count += 1 129 | assert r.count("\t") == o.count("\t"), (r.count("\t"), o.count("\t"), o) 130 | if r.count("\t") > 0: 131 | r_conllu = ConlluRow._make(r.split("\t")) 132 | o_conllu = ConlluRow._make(o.split("\t")) 133 | yield r_conllu, o_conllu 134 | 135 | 136 | def main(): 137 | args = parse_args() 138 | logging.basicConfig(level=args.verbose) 139 | reference = read_conllu(args.reference) 140 | output = read_conllu(args.output) 141 | results = manipulate_data(input_pairs(reference, output)) 142 | print(*["{0:.2f}".format(v) for v in results], sep='\t') 143 | 144 | results_keys = ["lemma_acc", "lemma_dist", "msd_acc", "msd_f1"] 145 | output_dict = {k: v for k, v in zip(results_keys, results)} 146 | print(json.dumps(output_dict, indent=4)) 147 | 148 | if __name__ == "__main__": 149 | main() 150 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training script useful for debugging UDify and AllenNLP code 3 | """ 4 | 5 | import os 6 | import copy 7 | import datetime 8 | import logging 9 | import argparse 10 | import glob 11 | 12 | from allennlp.common import Params 13 | from allennlp.common.util import import_submodules 14 | from allennlp.commands.train import train_model 15 | 16 | from udify import util 17 | 18 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 19 | level=logging.INFO) 20 | logger = logging.getLogger(__name__) 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--name", default="", type=str, help="Log dir name") 24 | parser.add_argument("--base_config", default="config/udify_base.json", type=str, help="Base configuration file") 25 | parser.add_argument("--config", default=[], type=str, nargs="+", help="Overriding configuration files") 26 | 
parser.add_argument("--dataset_dir", default="data/ud-treebanks-v2.5", type=str, help="The path containing all UD treebanks")
27 | parser.add_argument("--batch_size", default=32, type=int, help="The batch size; the number of training sentences divided by this value estimates the steps per epoch.")
28 | parser.add_argument("--device", default=None, type=int, help="CUDA device; set to -1 for CPU")
29 | parser.add_argument("--resume", type=str, help="Resume training with the given model")
30 | parser.add_argument("--lazy", default=None, action="store_true", help="Lazy load the dataset")
31 | parser.add_argument("--cleanup_archive", action="store_true", help="Delete the model archive")
32 | parser.add_argument("--replace_vocab", action="store_true", help="Create a new vocab and replace the cached one")
33 | parser.add_argument("--archive_bert", action="store_true", help="Archives the finetuned BERT model after training")
34 | parser.add_argument("--predictor", default="udify_predictor", type=str, help="The type of predictor to use")
35 | 
36 | args = parser.parse_args()
37 | 
38 | log_dir_name = args.name
39 | if not log_dir_name:
40 |     file_name = args.config[0] if args.config else args.base_config
41 |     log_dir_name = os.path.basename(file_name).split(".")[0]
42 | 
43 | if args.name != "multilingual":
44 |     train_file = args.name + "-ud-train.conllu"
45 |     pathname = os.path.join(args.dataset_dir, "*", train_file)
46 |     train_paths = glob.glob(pathname)  # guard against a missing treebank instead of calling .pop() on an empty list
47 |     train_path = train_paths.pop() if train_paths else None
48 |     treebank_path = os.path.dirname(train_path) if train_path else None
49 |     if train_path:
50 |         logger.info(f"found training file: {train_path}, calculating the warmup and start steps")
51 | 
52 |         sentence_count = 0
53 |         with open(train_path, 'r', encoding="utf-8") as f:
54 |             for line in f:  # conllu sentences are separated by blank lines
55 |                 if line.isspace():
56 |                     sentence_count += 1
57 |         num_warmup_steps = round(sentence_count / args.batch_size)
58 | 
59 | configs = []
60 | 
61 | if not args.resume:
62 |     serialization_dir = os.path.join("logs", log_dir_name, datetime.datetime.now().strftime("%Y.%m.%d_%H.%M.%S"))
63 | 
64 |     overrides = {}
65 |     if args.device is not None:
66 |         overrides["trainer"] = {"cuda_device": args.device}
67 |     if args.lazy is not None:
68 |         overrides["dataset_reader"] = {"lazy": args.lazy}
69 |     configs.append(Params(overrides))
70 |     for config_file in args.config:
71 |         configs.append(Params.from_file(config_file))
72 |     configs.append(Params.from_file(args.base_config))
73 | else:
74 |     serialization_dir = args.resume
75 |     configs.append(Params.from_file(os.path.join(serialization_dir, "config.json")))
76 | 
77 | train_params = util.merge_configs(configs)
78 | 
79 | if args.name != "multilingual":
80 |     # overwrite the default params with the language-specific ones
81 |     for param in train_params:
82 |         if param == "train_data_path":
83 |             train_params["train_data_path"] = os.path.join(treebank_path, f"{args.name}-ud-train.conllu")
84 |         if param == "validation_data_path":
85 |             train_params["validation_data_path"] = os.path.join(treebank_path, f"{args.name}-ud-dev.conllu")
86 |         if param == "test_data_path":
87 |             train_params["test_data_path"] = os.path.join(treebank_path, f"{args.name}-ud-test.conllu")
88 | 
89 |         if param == "vocabulary":
90 |             train_params["vocabulary"]["directory_path"] = f"data/vocab/{args.name}/vocabulary"
91 | 
92 |         if param == "trainer":
93 |             for sub_param in train_params["trainer"]:
94 |                 if sub_param == "learning_rate_scheduler":
95 |                     train_params["trainer"]["learning_rate_scheduler"]["warmup_steps"] = num_warmup_steps
96 | 
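                    # start_step is set to the same value as warmup_steps: per the README FAQ,
                    # a good initial choice is to set both to one epoch of training batches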
train_params["trainer"]["learning_rate_scheduler"]["start_step"] = num_warmup_steps 97 | 98 | logger.info(f"changing warmup and start steps for {train_path} to {num_warmup_steps}") 99 | 100 | if "vocabulary" in train_params: 101 | # Remove this key to make AllenNLP happy 102 | train_params["vocabulary"].pop("non_padded_namespaces", None) 103 | 104 | predict_params = train_params.duplicate() 105 | 106 | import_submodules("udify") 107 | 108 | try: 109 | util.cache_vocab(train_params) 110 | train_model(train_params, serialization_dir, recover=bool(args.resume)) 111 | except KeyboardInterrupt: 112 | logger.warning("KeyboardInterrupt, skipping training") 113 | 114 | dev_file = predict_params["validation_data_path"] 115 | test_file = predict_params["test_data_path"] 116 | 117 | dev_pred, dev_eval, test_pred, test_eval = [ 118 | os.path.join(serialization_dir, name) 119 | for name in ["dev.conllu", "dev_results.json", "test.conllu", "test_results.json"] 120 | ] 121 | 122 | if dev_file != test_file: 123 | util.predict_and_evaluate_model(args.predictor, predict_params, serialization_dir, dev_file, dev_pred, dev_eval) 124 | 125 | util.predict_and_evaluate_model(args.predictor, predict_params, serialization_dir, test_file, test_pred, test_eval) 126 | 127 | if args.archive_bert: 128 | bert_config = "config/archive/bert-base-multilingual-cased/bert_config.json" 129 | util.archive_bert_model(serialization_dir, bert_config) 130 | 131 | util.cleanup_training(serialization_dir, keep_archive=not args.cleanup_archive) 132 | -------------------------------------------------------------------------------- /udify/__init__.py: -------------------------------------------------------------------------------- 1 | from udify.dataset_readers import * 2 | from udify.models import * 3 | from udify.modules import * 4 | from udify.optimizers import * 5 | from udify.predictors import * 6 | from udify import * 7 | -------------------------------------------------------------------------------- /udify/dataset_readers/__init__.py: -------------------------------------------------------------------------------- 1 | from udify.dataset_readers.universal_dependencies import UniversalDependenciesDatasetReader 2 | from udify.dataset_readers.sigmorphon_2019_task_2 import Sigmorphon2019Task2DatasetReader 3 | -------------------------------------------------------------------------------- /udify/dataset_readers/conll18_ud_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Compatible with Python 2.7 and 3.2+, can be used either as a module 4 | # or a standalone executable. 5 | # 6 | # Copyright 2017, 2018 Institute of Formal and Applied Linguistics (UFAL), 7 | # Faculty of Mathematics and Physics, Charles University, Czech Republic. 8 | # 9 | # This Source Code Form is subject to the terms of the Mozilla Public 10 | # License, v. 2.0. If a copy of the MPL was not distributed with this 11 | # file, You can obtain one at http://mozilla.org/MPL/2.0/. 12 | # 13 | # Authors: Milan Straka, Martin Popel 14 | # 15 | # Changelog: 16 | # - [12 Apr 2018] Version 0.9: Initial release. 17 | # - [19 Apr 2018] Version 1.0: Fix bug in MLAS (duplicate entries in functional_children). 18 | # Add --counts option. 19 | # - [02 May 2018] Version 1.1: When removing spaces to match gold and system characters, 20 | # consider all Unicode characters of category Zs instead of 21 | # just ASCII space. 22 | # - [25 Jun 2018] Version 1.2: Use python3 in the she-bang (instead of python). 
23 | # In Python2, make the whole computation use `unicode` strings. 24 | 25 | # Command line usage 26 | # ------------------ 27 | # conll18_ud_eval.py [-v] gold_conllu_file system_conllu_file 28 | # 29 | # - if no -v is given, only the official CoNLL18 UD Shared Task evaluation metrics 30 | # are printed 31 | # - if -v is given, more metrics are printed (as precision, recall, F1 score, 32 | # and in case the metric is computed on aligned words also accuracy on these): 33 | # - Tokens: how well do the gold tokens match system tokens 34 | # - Sentences: how well do the gold sentences match system sentences 35 | # - Words: how well can the gold words be aligned to system words 36 | # - UPOS: using aligned words, how well does UPOS match 37 | # - XPOS: using aligned words, how well does XPOS match 38 | # - UFeats: using aligned words, how well does universal FEATS match 39 | # - AllTags: using aligned words, how well does UPOS+XPOS+FEATS match 40 | # - Lemmas: using aligned words, how well does LEMMA match 41 | # - UAS: using aligned words, how well does HEAD match 42 | # - LAS: using aligned words, how well does HEAD+DEPREL(ignoring subtypes) match 43 | # - CLAS: using aligned words with content DEPREL, how well does 44 | # HEAD+DEPREL(ignoring subtypes) match 45 | # - MLAS: using aligned words with content DEPREL, how well does 46 | # HEAD+DEPREL(ignoring subtypes)+UPOS+UFEATS+FunctionalChildren(DEPREL+UPOS+UFEATS) match 47 | # - BLEX: using aligned words with content DEPREL, how well does 48 | # HEAD+DEPREL(ignoring subtypes)+LEMMAS match 49 | # - if -c is given, raw counts of correct/gold_total/system_total/aligned words are printed 50 | # instead of precision/recall/F1/AlignedAccuracy for all metrics. 51 | 52 | # API usage 53 | # --------- 54 | # - load_conllu(file) 55 | # - loads CoNLL-U file from given file object to an internal representation 56 | # - the file object should return str in both Python 2 and Python 3 57 | # - raises UDError exception if the given file cannot be loaded 58 | # - evaluate(gold_ud, system_ud) 59 | # - evaluate the given gold and system CoNLL-U files (loaded with load_conllu) 60 | # - raises UDError if the concatenated tokens of gold and system file do not match 61 | # - returns a dictionary with the metrics described above, each metric having 62 | # three fields: precision, recall and f1 63 | 64 | # Description of token matching 65 | # ----------------------------- 66 | # In order to match tokens of gold file and system file, we consider the text 67 | # resulting from concatenation of gold tokens and text resulting from 68 | # concatenation of system tokens. These texts should match -- if they do not, 69 | # the evaluation fails. 70 | # 71 | # If the texts do match, every token is represented as a range in this original 72 | # text, and tokens are equal only if their range is the same. 73 | 74 | # Description of word matching 75 | # ---------------------------- 76 | # When matching words of gold file and system file, we first match the tokens. 77 | # The words which are also tokens are matched as tokens, but words in multi-word 78 | # tokens have to be handled differently. 79 | # 80 | # To handle multi-word tokens, we start by finding "multi-word spans". 
81 | # Multi-word span is a span in the original text such that 82 | # - it contains at least one multi-word token 83 | # - all multi-word tokens in the span (considering both gold and system ones) 84 | # are completely inside the span (i.e., they do not "stick out") 85 | # - the multi-word span is as small as possible 86 | # 87 | # For every multi-word span, we align the gold and system words completely 88 | # inside this span using LCS on their FORMs. The words not intersecting 89 | # (even partially) any multi-word span are then aligned as tokens. 90 | 91 | 92 | from __future__ import division 93 | from __future__ import print_function 94 | 95 | import argparse 96 | import io 97 | import sys 98 | import unicodedata 99 | import unittest 100 | 101 | # CoNLL-U column names 102 | ID, FORM, LEMMA, UPOS, XPOS, FEATS, HEAD, DEPREL, DEPS, MISC = range(10) 103 | 104 | # Content and functional relations 105 | CONTENT_DEPRELS = { 106 | "nsubj", "obj", "iobj", "csubj", "ccomp", "xcomp", "obl", "vocative", 107 | "expl", "dislocated", "advcl", "advmod", "discourse", "nmod", "appos", 108 | "nummod", "acl", "amod", "conj", "fixed", "flat", "compound", "list", 109 | "parataxis", "orphan", "goeswith", "reparandum", "root", "dep" 110 | } 111 | 112 | FUNCTIONAL_DEPRELS = { 113 | "aux", "cop", "mark", "det", "clf", "case", "cc" 114 | } 115 | 116 | UNIVERSAL_FEATURES = { 117 | "PronType", "NumType", "Poss", "Reflex", "Foreign", "Abbr", "Gender", 118 | "Animacy", "Number", "Case", "Definite", "Degree", "VerbForm", "Mood", 119 | "Tense", "Aspect", "Voice", "Evident", "Polarity", "Person", "Polite" 120 | } 121 | 122 | # UD Error is used when raising exceptions in this module 123 | class UDError(Exception): 124 | pass 125 | 126 | # Conversion methods handling `str` <-> `unicode` conversions in Python2 127 | def _decode(text): 128 | return text if sys.version_info[0] >= 3 or not isinstance(text, str) else text.decode("utf-8") 129 | 130 | def _encode(text): 131 | return text if sys.version_info[0] >= 3 or not isinstance(text, unicode) else text.encode("utf-8") 132 | 133 | # Load given CoNLL-U file into internal representation 134 | def load_conllu(file): 135 | # Internal representation classes 136 | class UDRepresentation: 137 | def __init__(self): 138 | # Characters of all the tokens in the whole file. 139 | # Whitespace between tokens is not included. 140 | self.characters = [] 141 | # List of UDSpan instances with start&end indices into `characters`. 142 | self.tokens = [] 143 | # List of UDWord instances. 144 | self.words = [] 145 | # List of UDSpan instances with start&end indices into `characters`. 146 | self.sentences = [] 147 | class UDSpan: 148 | def __init__(self, start, end): 149 | self.start = start 150 | # Note that self.end marks the first position **after the end** of span, 151 | # so we can use characters[start:end] or range(start, end). 152 | self.end = end 153 | class UDWord: 154 | def __init__(self, span, columns, is_multiword): 155 | # Span of this word (or MWT, see below) within ud_representation.characters. 156 | self.span = span 157 | # 10 columns of the CoNLL-U file: ID, FORM, LEMMA,... 158 | self.columns = columns 159 | # is_multiword==True means that this word is part of a multi-word token. 160 | # In that case, self.span marks the span of the whole multi-word token. 161 | self.is_multiword = is_multiword 162 | # Reference to the UDWord instance representing the HEAD (or None if root). 
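# (Filled in later by process_word inside load_conllu, once the whole
# sentence has been read; it stays None for the root word.)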
163 | self.parent = None 164 | # List of references to UDWord instances representing functional-deprel children. 165 | self.functional_children = [] 166 | # Only consider universal FEATS. 167 | self.columns[FEATS] = "|".join(sorted(feat for feat in columns[FEATS].split("|") 168 | if feat.split("=", 1)[0] in UNIVERSAL_FEATURES)) 169 | # Let's ignore language-specific deprel subtypes. 170 | self.columns[DEPREL] = columns[DEPREL].split(":")[0] 171 | # Precompute which deprels are CONTENT_DEPRELS and which FUNCTIONAL_DEPRELS 172 | self.is_content_deprel = self.columns[DEPREL] in CONTENT_DEPRELS 173 | self.is_functional_deprel = self.columns[DEPREL] in FUNCTIONAL_DEPRELS 174 | 175 | ud = UDRepresentation() 176 | 177 | # Load the CoNLL-U file 178 | index, sentence_start = 0, None 179 | while True: 180 | line = file.readline() 181 | if not line: 182 | break 183 | line = _decode(line.rstrip("\r\n")) 184 | 185 | # Handle sentence start boundaries 186 | if sentence_start is None: 187 | # Skip comments 188 | if line.startswith("#"): 189 | continue 190 | # Start a new sentence 191 | ud.sentences.append(UDSpan(index, 0)) 192 | sentence_start = len(ud.words) 193 | if not line: 194 | # Add parent and children UDWord links and check there are no cycles 195 | def process_word(word): 196 | if word.parent == "remapping": 197 | raise UDError("There is a cycle in a sentence") 198 | if word.parent is None: 199 | try: 200 | head = int(word.columns[HEAD]) 201 | except: 202 | head = 0 203 | if head < 0 or head > len(ud.words) - sentence_start: 204 | raise UDError("HEAD '{}' points outside of the sentence".format(_encode(word.columns[HEAD]))) 205 | if head: 206 | parent = ud.words[sentence_start + head - 1] 207 | word.parent = "remapping" 208 | process_word(parent) 209 | word.parent = parent 210 | 211 | for word in ud.words[sentence_start:]: 212 | process_word(word) 213 | # func_children cannot be assigned within process_word 214 | # because it is called recursively and may result in adding one child twice. 215 | for word in ud.words[sentence_start:]: 216 | if word.parent and word.is_functional_deprel: 217 | word.parent.functional_children.append(word) 218 | 219 | # Check there is a single root node 220 | # if len([word for word in ud.words[sentence_start:] if word.parent is None]) != 1: 221 | # raise UDError("There are multiple roots in a sentence") 222 | 223 | # End the sentence 224 | ud.sentences[-1].end = index 225 | sentence_start = None 226 | continue 227 | 228 | # Read next token/word 229 | columns = line.split("\t") 230 | if len(columns) != 10: 231 | raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(line))) 232 | 233 | # Skip empty nodes 234 | if "." in columns[ID]: 235 | continue 236 | 237 | # Delete spaces from FORM, so gold.characters == system.characters 238 | # even if one of them tokenizes the space. Use any Unicode character 239 | # with category Zs. 
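# (Illustrative example, not part of the original script: a gold FORM
# "vis à vis" and a system FORM written with a no-break space both
# reduce to "visàvis", so the concatenated character sequences of the
# two files still align.)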
240 | columns[FORM] = "".join(filter(lambda c: unicodedata.category(c) != "Zs", columns[FORM])) 241 | if not columns[FORM]: 242 | raise UDError("There is an empty FORM in the CoNLL-U file") 243 | 244 | # Save token 245 | ud.characters.extend(columns[FORM]) 246 | ud.tokens.append(UDSpan(index, index + len(columns[FORM]))) 247 | index += len(columns[FORM]) 248 | 249 | # Handle multi-word tokens to save word(s) 250 | if "-" in columns[ID]: 251 | try: 252 | start, end = map(int, columns[ID].split("-")) 253 | except: 254 | raise UDError("Cannot parse multi-word token ID '{}'".format(_encode(columns[ID]))) 255 | 256 | for _ in range(start, end + 1): 257 | word_line = _decode(file.readline().rstrip("\r\n")) 258 | word_columns = word_line.split("\t") 259 | if len(word_columns) != 10: 260 | raise UDError("The CoNLL-U line does not contain 10 tab-separated columns: '{}'".format(_encode(word_line))) 261 | ud.words.append(UDWord(ud.tokens[-1], word_columns, is_multiword=True)) 262 | # Basic tokens/words 263 | else: 264 | try: 265 | word_id = int(columns[ID]) 266 | except: 267 | raise UDError("Cannot parse word ID '{}'".format(_encode(columns[ID]))) 268 | if word_id != len(ud.words) - sentence_start + 1: 269 | raise UDError("Incorrect word ID '{}' for word '{}', expected '{}'".format( 270 | _encode(columns[ID]), _encode(columns[FORM]), len(ud.words) - sentence_start + 1)) 271 | 272 | try: 273 | head_id = int(columns[HEAD]) 274 | except: 275 | head_id = 0 276 | # raise UDError("Cannot parse HEAD '{}'".format(_encode(columns[HEAD]))) 277 | if head_id < 0: 278 | raise UDError("HEAD cannot be negative") 279 | 280 | ud.words.append(UDWord(ud.tokens[-1], columns, is_multiword=False)) 281 | 282 | if sentence_start is not None: 283 | raise UDError("The CoNLL-U file does not end with empty line") 284 | 285 | return ud 286 | 287 | # Evaluate the gold and system treebanks (loaded using load_conllu). 
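# A minimal usage sketch of the module API (illustrative; the file names
# are hypothetical, and load_conllu_file is defined later in this module):
#
#     gold_ud = load_conllu_file("gold.conllu")
#     system_ud = load_conllu_file("system.conllu")
#     results = evaluate(gold_ud, system_ud)
#     print("LAS F1: {:.2f}".format(100 * results["LAS"].f1))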
288 | def evaluate(gold_ud, system_ud): 289 | class Score: 290 | def __init__(self, gold_total, system_total, correct, aligned_total=None): 291 | self.correct = correct 292 | self.gold_total = gold_total 293 | self.system_total = system_total 294 | self.aligned_total = aligned_total 295 | self.precision = correct / system_total if system_total else 0.0 296 | self.recall = correct / gold_total if gold_total else 0.0 297 | self.f1 = 2 * correct / (system_total + gold_total) if system_total + gold_total else 0.0 298 | self.aligned_accuracy = correct / aligned_total if aligned_total else aligned_total 299 | class AlignmentWord: 300 | def __init__(self, gold_word, system_word): 301 | self.gold_word = gold_word 302 | self.system_word = system_word 303 | class Alignment: 304 | def __init__(self, gold_words, system_words): 305 | self.gold_words = gold_words 306 | self.system_words = system_words 307 | self.matched_words = [] 308 | self.matched_words_map = {} 309 | def append_aligned_words(self, gold_word, system_word): 310 | self.matched_words.append(AlignmentWord(gold_word, system_word)) 311 | self.matched_words_map[system_word] = gold_word 312 | 313 | def spans_score(gold_spans, system_spans): 314 | correct, gi, si = 0, 0, 0 315 | while gi < len(gold_spans) and si < len(system_spans): 316 | if system_spans[si].start < gold_spans[gi].start: 317 | si += 1 318 | elif gold_spans[gi].start < system_spans[si].start: 319 | gi += 1 320 | else: 321 | correct += gold_spans[gi].end == system_spans[si].end 322 | si += 1 323 | gi += 1 324 | 325 | return Score(len(gold_spans), len(system_spans), correct) 326 | 327 | def alignment_score(alignment, key_fn=None, filter_fn=None): 328 | if filter_fn is not None: 329 | gold = sum(1 for gold in alignment.gold_words if filter_fn(gold)) 330 | system = sum(1 for system in alignment.system_words if filter_fn(system)) 331 | aligned = sum(1 for word in alignment.matched_words if filter_fn(word.gold_word)) 332 | else: 333 | gold = len(alignment.gold_words) 334 | system = len(alignment.system_words) 335 | aligned = len(alignment.matched_words) 336 | 337 | if key_fn is None: 338 | # Return score for whole aligned words 339 | return Score(gold, system, aligned) 340 | 341 | def gold_aligned_gold(word): 342 | return word 343 | def gold_aligned_system(word): 344 | return alignment.matched_words_map.get(word, "NotAligned") if word is not None else None 345 | correct = 0 346 | for words in alignment.matched_words: 347 | if filter_fn is None or filter_fn(words.gold_word): 348 | if key_fn(words.gold_word, gold_aligned_gold) == key_fn(words.system_word, gold_aligned_system): 349 | correct += 1 350 | 351 | return Score(gold, system, correct, aligned) 352 | 353 | def beyond_end(words, i, multiword_span_end): 354 | if i >= len(words): 355 | return True 356 | if words[i].is_multiword: 357 | return words[i].span.start >= multiword_span_end 358 | return words[i].span.end > multiword_span_end 359 | 360 | def extend_end(word, multiword_span_end): 361 | if word.is_multiword and word.span.end > multiword_span_end: 362 | return word.span.end 363 | return multiword_span_end 364 | 365 | def find_multiword_span(gold_words, system_words, gi, si): 366 | # We know gold_words[gi].is_multiword or system_words[si].is_multiword. 367 | # Find the start of the multiword span (gs, ss), so the multiword span is minimal. 368 | # Initialize multiword_span_end characters index. 
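# (The span end starts at whichever side's word is a multi-word token and
# is then extended until no word on either side still crosses it.)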
369 | if gold_words[gi].is_multiword: 370 | multiword_span_end = gold_words[gi].span.end 371 | if not system_words[si].is_multiword and system_words[si].span.start < gold_words[gi].span.start: 372 | si += 1 373 | else: # if system_words[si].is_multiword 374 | multiword_span_end = system_words[si].span.end 375 | if not gold_words[gi].is_multiword and gold_words[gi].span.start < system_words[si].span.start: 376 | gi += 1 377 | gs, ss = gi, si 378 | 379 | # Find the end of the multiword span 380 | # (so both gi and si are pointing to the word following the multiword span end). 381 | while not beyond_end(gold_words, gi, multiword_span_end) or \ 382 | not beyond_end(system_words, si, multiword_span_end): 383 | if gi < len(gold_words) and (si >= len(system_words) or 384 | gold_words[gi].span.start <= system_words[si].span.start): 385 | multiword_span_end = extend_end(gold_words[gi], multiword_span_end) 386 | gi += 1 387 | else: 388 | multiword_span_end = extend_end(system_words[si], multiword_span_end) 389 | si += 1 390 | return gs, ss, gi, si 391 | 392 | def compute_lcs(gold_words, system_words, gi, si, gs, ss): 393 | lcs = [[0] * (si - ss) for i in range(gi - gs)] 394 | for g in reversed(range(gi - gs)): 395 | for s in reversed(range(si - ss)): 396 | if gold_words[gs + g].columns[FORM].lower() == system_words[ss + s].columns[FORM].lower(): 397 | lcs[g][s] = 1 + (lcs[g+1][s+1] if g+1 < gi-gs and s+1 < si-ss else 0) 398 | lcs[g][s] = max(lcs[g][s], lcs[g+1][s] if g+1 < gi-gs else 0) 399 | lcs[g][s] = max(lcs[g][s], lcs[g][s+1] if s+1 < si-ss else 0) 400 | return lcs 401 | 402 | def align_words(gold_words, system_words): 403 | alignment = Alignment(gold_words, system_words) 404 | 405 | gi, si = 0, 0 406 | while gi < len(gold_words) and si < len(system_words): 407 | if gold_words[gi].is_multiword or system_words[si].is_multiword: 408 | # A: Multi-word tokens => align via LCS within the whole "multiword span". 409 | gs, ss, gi, si = find_multiword_span(gold_words, system_words, gi, si) 410 | 411 | if si > ss and gi > gs: 412 | lcs = compute_lcs(gold_words, system_words, gi, si, gs, ss) 413 | 414 | # Store aligned words 415 | s, g = 0, 0 416 | while g < gi - gs and s < si - ss: 417 | if gold_words[gs + g].columns[FORM].lower() == system_words[ss + s].columns[FORM].lower(): 418 | alignment.append_aligned_words(gold_words[gs+g], system_words[ss+s]) 419 | g += 1 420 | s += 1 421 | elif lcs[g][s] == (lcs[g+1][s] if g+1 < gi-gs else 0): 422 | g += 1 423 | else: 424 | s += 1 425 | else: 426 | # B: No multi-word token => align according to spans. 427 | if (gold_words[gi].span.start, gold_words[gi].span.end) == (system_words[si].span.start, system_words[si].span.end): 428 | alignment.append_aligned_words(gold_words[gi], system_words[si]) 429 | gi += 1 430 | si += 1 431 | elif gold_words[gi].span.start <= system_words[si].span.start: 432 | gi += 1 433 | else: 434 | si += 1 435 | 436 | return alignment 437 | 438 | # Check that the underlying character sequences do match. 
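# (If they differ, word alignment is impossible, so a UDError reporting
# the first point of divergence is raised below.)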
439 | if gold_ud.characters != system_ud.characters: 440 | index = 0 441 | while index < len(gold_ud.characters) and index < len(system_ud.characters) and \ 442 | gold_ud.characters[index] == system_ud.characters[index]: 443 | index += 1 444 | 445 | raise UDError( 446 | "The concatenation of tokens in gold file and in system file differ!\n" + 447 | "First 20 differing characters in gold file: '{}' and system file: '{}'".format( 448 | "".join(map(_encode, gold_ud.characters[index:index + 20])), 449 | "".join(map(_encode, system_ud.characters[index:index + 20])) 450 | ) 451 | ) 452 | 453 | # Align words 454 | alignment = align_words(gold_ud.words, system_ud.words) 455 | 456 | # Compute the F1-scores 457 | return { 458 | "Tokens": spans_score(gold_ud.tokens, system_ud.tokens), 459 | "Sentences": spans_score(gold_ud.sentences, system_ud.sentences), 460 | "Words": alignment_score(alignment), 461 | "UPOS": alignment_score(alignment, lambda w, _: w.columns[UPOS]), 462 | "XPOS": alignment_score(alignment, lambda w, _: w.columns[XPOS]), 463 | "UFeats": alignment_score(alignment, lambda w, _: w.columns[FEATS]), 464 | "AllTags": alignment_score(alignment, lambda w, _: (w.columns[UPOS], w.columns[XPOS], w.columns[FEATS])), 465 | "Lemmas": alignment_score(alignment, lambda w, ga: w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"), 466 | "UAS": alignment_score(alignment, lambda w, ga: ga(w.parent)), 467 | "LAS": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL])), 468 | "CLAS": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL]), 469 | filter_fn=lambda w: w.is_content_deprel), 470 | "MLAS": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL], w.columns[UPOS], w.columns[FEATS], 471 | [(ga(c), c.columns[DEPREL], c.columns[UPOS], c.columns[FEATS]) 472 | for c in w.functional_children]), 473 | filter_fn=lambda w: w.is_content_deprel), 474 | "BLEX": alignment_score(alignment, lambda w, ga: (ga(w.parent), w.columns[DEPREL], 475 | w.columns[LEMMA] if ga(w).columns[LEMMA] != "_" else "_"), 476 | filter_fn=lambda w: w.is_content_deprel), 477 | } 478 | 479 | 480 | def load_conllu_file(path): 481 | _file = open(path, mode="r", **({"encoding": "utf-8"} if sys.version_info >= (3, 0) else {})) 482 | return load_conllu(_file) 483 | 484 | def evaluate_wrapper(args): 485 | # Load CoNLL-U files 486 | gold_ud = load_conllu_file(args.gold_file) 487 | system_ud = load_conllu_file(args.system_file) 488 | return evaluate(gold_ud, system_ud) 489 | 490 | def main(): 491 | # Parse arguments 492 | parser = argparse.ArgumentParser() 493 | parser.add_argument("gold_file", type=str, 494 | help="Name of the CoNLL-U file with the gold data.") 495 | parser.add_argument("system_file", type=str, 496 | help="Name of the CoNLL-U file with the predicted data.") 497 | parser.add_argument("--verbose", "-v", default=False, action="store_true", 498 | help="Print all metrics.") 499 | parser.add_argument("--counts", "-c", default=False, action="store_true", 500 | help="Print raw counts of correct/gold/system/aligned words instead of prec/rec/F1 for all metrics.") 501 | args = parser.parse_args() 502 | 503 | # Evaluate 504 | evaluation = evaluate_wrapper(args) 505 | 506 | # Print the evaluation 507 | if not args.verbose and not args.counts: 508 | print("LAS F1 Score: {:.2f}".format(100 * evaluation["LAS"].f1)) 509 | print("MLAS Score: {:.2f}".format(100 * evaluation["MLAS"].f1)) 510 | print("BLEX Score: {:.2f}".format(100 * evaluation["BLEX"].f1)) 511 | else: 512 | if 
args.counts: 513 | print("Metric | Correct | Gold | Predicted | Aligned") 514 | else: 515 | print("Metric | Precision | Recall | F1 Score | AligndAcc") 516 | print("-----------+-----------+-----------+-----------+-----------") 517 | for metric in["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats", "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX"]: 518 | if args.counts: 519 | print("{:11}|{:10} |{:10} |{:10} |{:10}".format( 520 | metric, 521 | evaluation[metric].correct, 522 | evaluation[metric].gold_total, 523 | evaluation[metric].system_total, 524 | evaluation[metric].aligned_total or (evaluation[metric].correct if metric == "Words" else "") 525 | )) 526 | else: 527 | print("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format( 528 | metric, 529 | 100 * evaluation[metric].precision, 530 | 100 * evaluation[metric].recall, 531 | 100 * evaluation[metric].f1, 532 | "{:10.2f}".format(100 * evaluation[metric].aligned_accuracy) if evaluation[metric].aligned_accuracy is not None else "" 533 | )) 534 | 535 | if __name__ == "__main__": 536 | main() 537 | 538 | # Tests, which can be executed with `python -m unittest conll18_ud_eval`. 539 | class TestAlignment(unittest.TestCase): 540 | @staticmethod 541 | def _load_words(words): 542 | """Prepare fake CoNLL-U files with fake HEAD to prevent multiple roots errors.""" 543 | lines, num_words = [], 0 544 | for w in words: 545 | parts = w.split(" ") 546 | if len(parts) == 1: 547 | num_words += 1 548 | lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, parts[0], int(num_words>1))) 549 | else: 550 | lines.append("{}-{}\t{}\t_\t_\t_\t_\t_\t_\t_\t_".format(num_words + 1, num_words + len(parts) - 1, parts[0])) 551 | for part in parts[1:]: 552 | num_words += 1 553 | lines.append("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\t_".format(num_words, part, int(num_words>1))) 554 | return load_conllu((io.StringIO if sys.version_info >= (3, 0) else io.BytesIO)("\n".join(lines+["\n"]))) 555 | 556 | def _test_exception(self, gold, system): 557 | self.assertRaises(UDError, evaluate, self._load_words(gold), self._load_words(system)) 558 | 559 | def _test_ok(self, gold, system, correct): 560 | metrics = evaluate(self._load_words(gold), self._load_words(system)) 561 | gold_words = sum((max(1, len(word.split(" ")) - 1) for word in gold)) 562 | system_words = sum((max(1, len(word.split(" ")) - 1) for word in system)) 563 | self.assertEqual((metrics["Words"].precision, metrics["Words"].recall, metrics["Words"].f1), 564 | (correct / system_words, correct / gold_words, 2 * correct / (gold_words + system_words))) 565 | 566 | def test_exception(self): 567 | self._test_exception(["a"], ["b"]) 568 | 569 | def test_equal(self): 570 | self._test_ok(["a"], ["a"], 1) 571 | self._test_ok(["a", "b", "c"], ["a", "b", "c"], 3) 572 | 573 | def test_equal_with_multiword(self): 574 | self._test_ok(["abc a b c"], ["a", "b", "c"], 3) 575 | self._test_ok(["a", "bc b c", "d"], ["a", "b", "c", "d"], 4) 576 | self._test_ok(["abcd a b c d"], ["ab a b", "cd c d"], 4) 577 | self._test_ok(["abc a b c", "de d e"], ["a", "bcd b c d", "e"], 5) 578 | 579 | def test_alignment(self): 580 | self._test_ok(["abcd"], ["a", "b", "c", "d"], 0) 581 | self._test_ok(["abc", "d"], ["a", "b", "c", "d"], 1) 582 | self._test_ok(["a", "bc", "d"], ["a", "b", "c", "d"], 2) 583 | self._test_ok(["a", "bc b c", "d"], ["a", "b", "cd"], 2) 584 | self._test_ok(["abc a BX c", "def d EX f"], ["ab a b", "cd c d", "ef e f"], 4) 585 | self._test_ok(["ab a b", "cd bc d"], ["a", "bc", "d"], 2) 586 | self._test_ok(["a", "bc b c", 
"d"], ["ab AX BX", "cd CX a"], 1) 587 | -------------------------------------------------------------------------------- /udify/dataset_readers/evaluate_2019_task2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """Evaluation for the SIGMORPHON 2019 shared task, task 2. 3 | 4 | Computes various metrics on input. 5 | 6 | Author: Arya D. McCarthy 7 | Last update: 2018-12-21 8 | """ 9 | 10 | import argparse 11 | import contextlib 12 | import io 13 | import logging 14 | import sys 15 | 16 | import numpy as np 17 | 18 | from collections import Counter, namedtuple 19 | from pathlib import Path 20 | 21 | log = logging.getLogger(Path(__file__).stem) 22 | 23 | 24 | COLUMNS = "ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC".split() 25 | ConlluRow = namedtuple("ConlluRow", COLUMNS) 26 | SEPARATOR = ";" 27 | 28 | 29 | def distance(str1, str2): 30 | """Simple Levenshtein implementation.""" 31 | m = np.zeros([len(str2)+1, len(str1)+1]) 32 | for x in range(1, len(str2) + 1): 33 | m[x][0] = m[x-1][0] + 1 34 | for y in range(1, len(str1) + 1): 35 | m[0][y] = m[0][y-1] + 1 36 | for x in range(1, len(str2) + 1): 37 | for y in range(1, len(str1) + 1): 38 | if str1[y-1] == str2[x-1]: 39 | dg = 0 40 | else: 41 | dg = 1 42 | m[x][y] = min(m[x-1][y] + 1, m[x][y-1] + 1, m[x-1][y-1] + dg) 43 | return int(m[len(str2)][len(str1)]) 44 | 45 | 46 | def set_equal(str1, str2): 47 | set1 = set(str1.split(SEPARATOR)) 48 | set2 = set(str2.split(SEPARATOR)) 49 | return set1 == set2 50 | 51 | 52 | def manipulate_data(pairs): 53 | log.info("Lemma acc, Lemma Levenshtein, morph acc, morph F1") 54 | 55 | count = 0 56 | lemma_acc = 0 57 | lemma_lev = 0 58 | morph_acc = 0 59 | 60 | f1_precision_scores = 0 61 | f1_precision_counts = 0 62 | f1_recall_scores = 0 63 | f1_recall_counts = 0 64 | 65 | for r, o in pairs: 66 | log.debug("{}\t{}\t{}\t{}".format(r.LEMMA, o.LEMMA, r.FEATS, o.FEATS)) 67 | count += 1 68 | lemma_acc += (r.LEMMA == o.LEMMA) 69 | lemma_lev += distance(r.LEMMA, o.LEMMA) 70 | morph_acc += set_equal(r.FEATS, o.FEATS) 71 | 72 | r_feats = set(r.FEATS.split(SEPARATOR)) - {"_"} 73 | o_feats = set(o.FEATS.split(SEPARATOR)) - {"_"} 74 | 75 | union_size = len(r_feats & o_feats) 76 | reference_size = len(r_feats) 77 | output_size = len(o_feats) 78 | 79 | f1_precision_scores += union_size 80 | f1_recall_scores += union_size 81 | f1_precision_counts += output_size 82 | f1_recall_counts += reference_size 83 | 84 | f1_precision = f1_precision_scores / (f1_precision_counts or 1) 85 | f1_recall = f1_recall_scores / (f1_recall_counts or 1) 86 | f1 = 2 * (f1_precision * f1_recall) / (f1_precision + f1_recall + 1E-20) 87 | 88 | return (100 * lemma_acc / count, lemma_lev / count, 100 * morph_acc / count, 100 * f1) 89 | 90 | 91 | def parse_args(): 92 | """Parse command line arguments.""" 93 | parser = argparse.ArgumentParser(description=__doc__) 94 | parser.add_argument('-r', '--reference', 95 | type=Path, required=True) 96 | parser.add_argument('-o', '--output', 97 | type=Path, required=True) 98 | # Set the verbosity level for the logger. The `-v` option will set it to 99 | # the debug level, while the `-q` will set it to the warning level. 100 | # Otherwise use the info level. 
101 | verbosity = parser.add_mutually_exclusive_group() 102 | verbosity.add_argument('-v', '--verbose', action='store_const', 103 | const=logging.DEBUG, default=logging.INFO) 104 | verbosity.add_argument('-q', '--quiet', dest='verbose', 105 | action='store_const', const=logging.WARNING) 106 | return parser.parse_args() 107 | 108 | 109 | def strip_comments(lines): 110 | for line in lines: 111 | if not line.startswith("#"): 112 | yield line 113 | 114 | 115 | def read_conllu(file: Path): 116 | with open(file) as f: 117 | yield from strip_comments(f) 118 | 119 | 120 | def input_pairs(reference, output): 121 | output = [o for o in output if len(o.split()) == 0 or '.' not in o.split()[0]] 122 | reference = [r for r in reference if len(r.split()) == 0 or '.' not in r.split()[0]] 123 | 124 | for r, o in zip(reference, output): 125 | assert r.count("\t") == o.count("\t"), (r.count("\t"), o.count("\t"), o) 126 | if r.count("\t") > 0: 127 | r_conllu = ConlluRow._make(r.split("\t")) 128 | o_conllu = ConlluRow._make(o.split("\t")) 129 | yield r_conllu, o_conllu 130 | 131 | 132 | def main(): 133 | args = parse_args() 134 | logging.basicConfig(level=args.verbose) 135 | reference = read_conllu(args.reference) 136 | output = read_conllu(args.output) 137 | results = manipulate_data(input_pairs(reference, output)) 138 | print(*["{0:.2f}".format(v) for v in results], sep='\t') 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /udify/dataset_readers/lemma_edit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for processing lemmas 3 | 4 | Adopted from UDPipe Future 5 | https://github.com/CoNLL-UD-2018/UDPipe-Future 6 | """ 7 | 8 | 9 | def min_edit_script(source, target, allow_copy=False): 10 | """ 11 | Finds the minimum edit script to transform the source to the target 12 | """ 13 | a = [[(len(source) + len(target) + 1, None)] * (len(target) + 1) for _ in range(len(source) + 1)] 14 | for i in range(0, len(source) + 1): 15 | for j in range(0, len(target) + 1): 16 | if i == 0 and j == 0: 17 | a[i][j] = (0, "") 18 | else: 19 | if allow_copy and i and j and source[i - 1] == target[j - 1] and a[i-1][j-1][0] < a[i][j][0]: 20 | a[i][j] = (a[i-1][j-1][0], a[i-1][j-1][1] + "→") 21 | if i and a[i-1][j][0] < a[i][j][0]: 22 | a[i][j] = (a[i-1][j][0] + 1, a[i-1][j][1] + "-") 23 | if j and a[i][j-1][0] < a[i][j][0]: 24 | a[i][j] = (a[i][j-1][0] + 1, a[i][j-1][1] + "+" + target[j - 1]) 25 | return a[-1][-1][1] 26 | 27 | 28 | def gen_lemma_rule(form, lemma, allow_copy=False): 29 | """ 30 | Generates a lemma rule to transform the source to the target 31 | """ 32 | form = form.lower() 33 | 34 | previous_case = -1 35 | lemma_casing = "" 36 | for i, c in enumerate(lemma): 37 | case = "↑" if c.lower() != c else "↓" 38 | if case != previous_case: 39 | lemma_casing += "{}{}{}".format("¦" if lemma_casing else "", case, i if i <= len(lemma) // 2 else i - len(lemma)) 40 | previous_case = case 41 | lemma = lemma.lower() 42 | 43 | best, best_form, best_lemma = 0, 0, 0 44 | for l in range(len(lemma)): 45 | for f in range(len(form)): 46 | cpl = 0 47 | while f + cpl < len(form) and l + cpl < len(lemma) and form[f + cpl] == lemma[l + cpl]: cpl += 1 48 | if cpl > best: 49 | best = cpl 50 | best_form = f 51 | best_lemma = l 52 | 53 | rule = lemma_casing + ";" 54 | if not best: 55 | rule += "a" + lemma 56 | else: 57 | rule += "d{}¦{}".format( 58 | min_edit_script(form[:best_form], lemma[:best_lemma], 
allow_copy), 59 | min_edit_script(form[best_form + best:], lemma[best_lemma + best:], allow_copy), 60 | ) 61 | return rule 62 | 63 | 64 | def apply_lemma_rule(form, lemma_rule): 65 | """ 66 | Applies the lemma rule to the form to generate the lemma 67 | """ 68 | casing, rule = lemma_rule.split(";", 1) 69 | if rule.startswith("a"): 70 | lemma = rule[1:] 71 | else: 72 | form = form.lower() 73 | rules, rule_sources = rule[1:].split("¦"), [] 74 | assert len(rules) == 2 75 | for rule in rules: 76 | source, i = 0, 0 77 | while i < len(rule): 78 | if rule[i] == "→" or rule[i] == "-": 79 | source += 1 80 | else: 81 | assert rule[i] == "+" 82 | i += 1 83 | i += 1 84 | rule_sources.append(source) 85 | 86 | try: 87 | lemma, form_offset = "", 0 88 | for i in range(2): 89 | j, offset = 0, (0 if i == 0 else len(form) - rule_sources[1]) 90 | while j < len(rules[i]): 91 | if rules[i][j] == "→": 92 | lemma += form[offset] 93 | offset += 1 94 | elif rules[i][j] == "-": 95 | offset += 1 96 | else: 97 | assert(rules[i][j] == "+") 98 | lemma += rules[i][j + 1] 99 | j += 1 100 | j += 1 101 | if i == 0: 102 | lemma += form[rule_sources[0]: len(form) - rule_sources[1]] 103 | except: 104 | lemma = form 105 | 106 | for rule in casing.split("¦"): 107 | if rule == "↓0": continue # The lemma is lowercased initially 108 | case, offset = rule[0], int(rule[1:]) 109 | lemma = lemma[:offset] + (lemma[offset:].upper() if case == "↑" else lemma[offset:].lower()) 110 | 111 | return lemma 112 | -------------------------------------------------------------------------------- /udify/dataset_readers/parser.py: -------------------------------------------------------------------------------- 1 | """ 2 | Parses Universal Dependencies data 3 | """ 4 | 5 | from __future__ import unicode_literals 6 | 7 | import re 8 | from collections import OrderedDict 9 | 10 | DEFAULT_FIELDS = ('id', 'form', 'lemma', 'upostag', 'xpostag', 'feats', 'head', 'deprel', 'deps', 'misc') 11 | 12 | deps_pattern = r"\d+:[a-z][a-z_-]*(:[a-z][a-z_-]*)?" 13 | MULTI_DEPS_PATTERN = re.compile(r"^{}(\|{})*$".format(deps_pattern, deps_pattern)) 14 | 15 | 16 | class ParseException(Exception): 17 | pass 18 | 19 | 20 | def process_multiword_tokens(annotation): 21 | """ 22 | Processes CoNLLU annotations for multi-word tokens. 23 | If the token id returned by the conllu library is a tuple object (either a multi-word token or an elided token), 24 | then the token id is set to None so that the token won't be used later on by the model. 25 | """ 26 | 27 | for i in range(len(annotation)): 28 | conllu_id = annotation[i]["id"] 29 | if type(conllu_id) == tuple: 30 | if "-" in conllu_id: 31 | conllu_id = str(conllu_id[0]) + "-" + str(conllu_id[2]) 32 | annotation[i]["multi_id"] = conllu_id 33 | annotation[i]["id"] = None 34 | elif "." 
in conllu_id: 35 | annotation[i]["id"] = None 36 | annotation[i]["multi_id"] = None 37 | else: 38 | annotation[i]["multi_id"] = None 39 | 40 | return annotation 41 | 42 | 43 | def parse_token_and_metadata(data, fields=None): 44 | if not data: 45 | raise ParseException("Can't create TokenList, no data sent to constructor.") 46 | 47 | fields = fields or DEFAULT_FIELDS 48 | 49 | tokens = [] 50 | metadata = OrderedDict() 51 | 52 | for line in data.split('\n'): 53 | line = line.strip() 54 | 55 | if not line: 56 | continue 57 | 58 | if line.startswith('#'): 59 | var_name, var_value = parse_comment_line(line) 60 | if var_name: 61 | metadata[var_name] = var_value 62 | else: 63 | tokens.append(parse_line(line, fields=fields)) 64 | 65 | return tokens, metadata 66 | 67 | 68 | def parse_line(line, fields=DEFAULT_FIELDS, parse_feats=True): 69 | line = re.split(r"\t| {2,}", line) 70 | 71 | if len(line) == 1 and " " in line[0]: 72 | raise ParseException("Invalid line format, line must contain either tabs or two spaces.") 73 | 74 | data = OrderedDict() 75 | 76 | for i, field in enumerate(fields): 77 | # Allow parsing CoNNL-U files with fewer columns 78 | if i >= len(line): 79 | break 80 | 81 | if field == "id": 82 | value = parse_id_value(line[i]) 83 | data["multi_id"] = parse_multi_id_value(line[i]) 84 | 85 | elif field == "xpostag": 86 | value = parse_nullable_value(line[i]) 87 | 88 | elif field == "feats": 89 | if parse_feats: 90 | value = parse_dict_value(line[i]) 91 | else: 92 | value = line[i] 93 | 94 | elif field == "head": 95 | value = parse_int_value(line[i]) 96 | 97 | elif field == "deps": 98 | value = parse_paired_list_value(line[i]) 99 | 100 | elif field == "misc": 101 | value = parse_dict_value(line[i]) 102 | 103 | else: 104 | value = line[i] 105 | 106 | data[field] = value 107 | 108 | return data 109 | 110 | 111 | def parse_comment_line(line): 112 | line = line.strip() 113 | if line[0] != '#': 114 | raise ParseException("Invalid comment format, comment must start with '#'") 115 | if '=' not in line: 116 | return None, None 117 | var_name, var_value = line[1:].split('=', 1) 118 | var_name = var_name.strip() 119 | var_value = var_value.strip() 120 | return var_name, var_value 121 | 122 | 123 | def parse_int_value(value): 124 | if value == '_': 125 | return None 126 | try: 127 | return int(value) 128 | except ValueError: 129 | return None 130 | 131 | 132 | def parse_id_value(value): 133 | # return value if "-" not in value else None 134 | return value if "-" not in value and "." not in value else None 135 | # TODO: handle special ids with "." 
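# (Illustrative behaviour, given the rules above and parse_multi_id_value
# just below:
#      parse_id_value("5")          -> "5"
#      parse_id_value("5-6")        -> None    # multi-word token line
#      parse_id_value("5.1")        -> None    # elided token line
#      parse_multi_id_value("5-6")  -> "5-6"
#      parse_multi_id_value("5")    -> None
# )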
136 | 137 | 138 | def parse_multi_id_value(value): 139 | if len(value.split('-')) == 2: 140 | return value 141 | return None 142 | 143 | 144 | def parse_paired_list_value(value): 145 | if re.match(MULTI_DEPS_PATTERN, value): 146 | return [ 147 | (part.split(":", 1)[1], parse_int_value(part.split(":", 1)[0])) 148 | for part in value.split("|") 149 | ] 150 | 151 | return parse_nullable_value(value) 152 | 153 | 154 | def parse_dict_value(value): 155 | if "=" in value: 156 | return OrderedDict([ 157 | (part.split("=")[0], parse_nullable_value(part.split("=")[1])) 158 | for part in value.split("|") if len(part.split('=')) == 2 159 | ]) 160 | 161 | return parse_nullable_value(value) 162 | 163 | 164 | def parse_nullable_value(value): 165 | if not value or value == "_": 166 | return None 167 | 168 | return value -------------------------------------------------------------------------------- /udify/dataset_readers/sigmorphon_2019_task_2.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, List, Any, Callable 2 | 3 | from overrides import overrides 4 | from udify.dataset_readers.parser import parse_line, DEFAULT_FIELDS, process_multiword_tokens 5 | from conllu import parse_incr 6 | 7 | import re 8 | 9 | from allennlp.common.file_utils import cached_path 10 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 11 | from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField 12 | from allennlp.data.instance import Instance 13 | from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer 14 | from allennlp.data.tokenizers import Token 15 | 16 | from udify.dataset_readers.lemma_edit import gen_lemma_rule 17 | 18 | import logging 19 | 20 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 21 | 22 | 23 | # A dictionary version of the UniMorph schema specified at https://unimorph.github.io/doc/unimorph-schema.pdf 24 | unimorph_schema = { 25 | 'aktionsart': ['accmp', 'ach', 'acty', 'atel', 'dur', 'dyn', 'pct', 'semel', 'stat', 'tel'], 26 | 'animacy': ['anim', 'hum', 'inan', 'nhum'], 27 | 'argument_marking': ['argac3s'], 28 | 'aspect': ['hab', 'ipfv', 'iter', 'pfv', 'prf', 'prog', 'prosp'], 29 | 'case': ['abl', 'abs', 'acc', 'all', 'ante', 'apprx', 'apud', 'at', 'avr', 'ben', 'byway', 'circ', 30 | 'com', 'compv', 'dat', 'eqtv', 'erg', 'ess', 'frml', 'gen', 'in', 'ins', 'inter', 'nom', 31 | 'noms', 'on', 'onhr', 'onvr', 'post', 'priv', 'prol', 'propr', 'prox', 'prp', 'prt', 'rel', 32 | 'rem', 'sub', 'term', 'trans', 'vers', 'voc'], 33 | 'comparison': ['ab', 'cmpr', 'eqt', 'rl', 'sprl'], 34 | 'definiteness': ['def', 'indf', 'nspec', 'spec'], 35 | 'deixis': ['abv', 'bel', 'even', 'med', 'noref', 'nvis', 'phor', 'prox', 'ref1', 'ref2', 'remt', 36 | 'vis'], 37 | 'evidentiality': ['assum', 'aud', 'drct', 'fh', 'hrsy', 'infer', 'nfh', 'nvsen', 'quot', 'rprt', 38 | 'sen'], 39 | 'finiteness': ['fin', 'nfin'], 40 | 'gender': ['bantu1-23', 'fem', 'masc', 'nakh1-8', 'neut'], 41 | 'information_structure': ['foc', 'top'], 42 | 'interrogativity': ['decl', 'int'], 43 | 'language-specific_features': ['lgspec1', 'lgspec2'], 44 | 'mood': ['adm', 'aunprp', 'auprp', 'cond', 'deb', 'ded', 'imp', 'ind', 'inten', 'irr', 'lkly', 45 | 'oblig', 'opt', 'perm', 'pot', 'purp', 'real', 'sbjv', 'sim'], 46 | 'number': ['du', 'gpauc', 'grpl', 'invn', 'pauc', 'pl', 'sg', 'tri'], 47 | 'part_of_speech': ['adj', 'adp', 'adv', 'art', 'aux', 'clf', 'comp', 'conj', 'det', 'intj', 'n', 48 | 'num', 'part', 
'pro', 'propn', 'v', 'v.cvb', 'v.msdr', 'v.ptcp'], 49 | 'person': ['0', '1', '2', '3', '4', 'excl', 'incl', 'obv', 'prx'], 50 | 'polarity': ['pos', 'neg'], 51 | 'politeness': ['avoid', 'col', 'elev', 'foreg', 'form', 'high', 'humb', 'infm', 'lit', 'low', 'pol', 52 | 'stelev', 'stsupr'], 53 | 'possession': ['aln', 'naln', 'pss1d', 'pss1de', 'pss1di', 'pss1p', 'pss1pe', 'pss1pi', 'pss1s', 54 | 'pss2d', 'pss2df', 'pss2dm', 'pss2p', 'pss2pf', 'pss2pm', 'pss2s', 'pss2sf', 55 | 'pss2sform', 'pss2sinfm', 'pss2sm', 'pss3d', 'pss3df', 'pss3dm', 'pss3p', 'pss3pf', 56 | 'pss3pm', 'pss3s', 'pss3sf', 'pss3sm', 'pssd', 'psss', 'pssp', 'pss', 'pss3', 'pss1'], 57 | 'switch-reference': ['mn', 'ds', 'dsadv', 'log', 'or', 'seqma', 'simma', 'ss', 'ssadv'], 58 | 'tense': ['1day', 'fut', 'hod', 'immed', 'prs', 'pst', 'rct', 'rmt'], 59 | 'valency': ['appl', 'caus', 'ditr', 'imprs', 'intr', 'recp', 'refl', 'tr'], 60 | 'voice': ['acfoc', 'act', 'agfoc', 'antip', 'bfoc', 'cfoc', 'dir', 'ifoc', 'inv', 'lfoc', 'mid', 61 | 'pass', 'pfoc'] 62 | # Extra Undefined Labels: "arg*", "pss*", "dist", "prontype", "number", "mood" 63 | } 64 | 65 | 66 | @DatasetReader.register("udify_sigmorphon_2019_task_2") 67 | class Sigmorphon2019Task2DatasetReader(DatasetReader): 68 | def __init__(self, 69 | token_indexers: Dict[str, TokenIndexer] = None, 70 | lazy: bool = False) -> None: 71 | super().__init__(lazy) 72 | self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} 73 | 74 | self.label_to_dimension = {} 75 | for dimension, labels in unimorph_schema.items(): 76 | for label in labels: 77 | self.label_to_dimension[label] = dimension 78 | 79 | @overrides 80 | def _read(self, file_path: str): 81 | # if `file_path` is a URL, redirect to the cache 82 | file_path = cached_path(file_path) 83 | 84 | with open(file_path, 'r') as conllu_file: 85 | logger.info("Reading UD instances from conllu dataset at: %s", file_path) 86 | 87 | for annotation in parse_incr(conllu_file): 88 | # CoNLLU annotations sometimes add back in words that have been elided 89 | # in the original sentence; we remove these, as we're just predicting 90 | # lemmas and morphological features for the original sentence. 91 | # We filter by None here as elided words have a non-integer word id, 92 | # and we replace these word ids with None in process_multiword_tokens.
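# (For example, a "1-2" multi-word token row survives only in
# multiword_tokens below, while a "3.1" elided-token row is dropped
# from the annotation entirely.)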
93 | annotation = process_multiword_tokens(annotation) 94 | multiword_tokens = [x for x in annotation if x["multi_id"] is not None] 95 | annotation = [x for x in annotation if x["id"] is not None] 96 | 97 | if len(annotation) == 0: 98 | continue 99 | 100 | def get_field(tag: str, map_fn: Callable[[Any], Any] = None) -> List[Any]: 101 | map_fn = map_fn if map_fn is not None else lambda x: x 102 | return [map_fn(x[tag]) if x[tag] is not None else "_" for x in annotation if tag in x] 103 | 104 | # Extract multiword token rows (not used for prediction, purely for evaluation) 105 | ids = [x["id"] for x in annotation] 106 | multiword_ids = [x["multi_id"] for x in multiword_tokens] 107 | multiword_forms = [x["form"] for x in multiword_tokens] 108 | 109 | words = get_field("form") 110 | lemmas = get_field("lemma") 111 | lemma_rules = [gen_lemma_rule(word, lemma) 112 | if lemma != "_" else "_" 113 | for word, lemma in zip(words, lemmas)] 114 | feats = get_field("feats") 115 | 116 | yield self.text_to_instance(words, lemmas, lemma_rules, feats, ids, multiword_ids, multiword_forms) 117 | 118 | @overrides 119 | def text_to_instance(self, # type: ignore 120 | words: List[str], 121 | lemmas: List[str] = None, 122 | lemma_rules: List[str] = None, 123 | feats: List[str] = None, 124 | ids: List[str] = None, 125 | multiword_ids: List[str] = None, 126 | multiword_forms: List[str] = None) -> Instance: 127 | fields: Dict[str, Field] = {} 128 | 129 | tokens = TextField([Token(w) for w in words], self._token_indexers) 130 | fields["tokens"] = tokens 131 | 132 | if lemma_rules: 133 | fields["lemmas"] = SequenceLabelField(lemma_rules, tokens, label_namespace="lemmas") 134 | 135 | if feats: 136 | fields["feats"] = SequenceLabelField(feats, tokens, label_namespace="feats") 137 | 138 | # TODO: parameter to turn this off 139 | feature_seq = [] 140 | 141 | for feat in feats: 142 | features = feat.lower().split(";") if feat != "_" else "_" 143 | dimensions = {dimension: "_" for dimension in unimorph_schema} 144 | 145 | if feat != "_": 146 | for label in features: 147 | # Use regex to handle special cases where multi-labels are contained inside "{}" 148 | first_label = re.findall(r"(?#{)([a-zA-Z0-9.\-_]+)(?#\+|\/|})", label) 149 | first_label = first_label[0] if first_label else label 150 | 151 | if first_label not in self.label_to_dimension: 152 | if first_label.startswith("arg"): 153 | # TODO: support argument_marking dimension 154 | continue 155 | elif first_label in ["dist", "prontype", "number", "mood"]: 156 | # TODO: unknown labels 157 | continue 158 | elif first_label.startswith("pss"): 159 | dimension = "possession" 160 | else: 161 | raise KeyError(first_label) 162 | else: 163 | dimension = self.label_to_dimension[first_label] 164 | 165 | dimensions[dimension] = label 166 | 167 | feature_seq.append(dimensions) 168 | 169 | for dimension in unimorph_schema: 170 | labels = [f[dimension] for f in feature_seq] 171 | fields[dimension] = SequenceLabelField(labels, tokens, label_namespace=dimension) 172 | 173 | fields["metadata"] = MetadataField({ 174 | "words": words, 175 | "feats": feats, 176 | "lemmas": lemmas, 177 | "lemma_rules": lemma_rules, 178 | "ids": ids, 179 | "multiword_ids": multiword_ids, 180 | "multiword_forms": multiword_forms 181 | }) 182 | 183 | return Instance(fields) -------------------------------------------------------------------------------- /udify/dataset_readers/universal_dependencies.py: -------------------------------------------------------------------------------- 1 | """ 2 | A Dataset
Reader for Universal Dependencies, with support for multiword tokens and special handling for NULL "_" tokens 3 | """ 4 | 5 | from typing import Dict, Tuple, List, Any, Callable 6 | 7 | from overrides import overrides 8 | from udify.dataset_readers.parser import parse_line, DEFAULT_FIELDS, process_multiword_tokens 9 | from conllu import parse_incr 10 | 11 | from allennlp.common.file_utils import cached_path 12 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 13 | from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField 14 | from allennlp.data.instance import Instance 15 | from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer 16 | from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter, WordSplitter 17 | from allennlp.data.tokenizers import Token 18 | 19 | from udify.dataset_readers.lemma_edit import gen_lemma_rule 20 | 21 | import logging 22 | 23 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 24 | 25 | 26 | @DatasetReader.register("udify_universal_dependencies") 27 | class UniversalDependenciesDatasetReader(DatasetReader): 28 | def __init__(self, 29 | token_indexers: Dict[str, TokenIndexer] = None, 30 | lazy: bool = False) -> None: 31 | super().__init__(lazy) 32 | self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()} 33 | 34 | @overrides 35 | def _read(self, file_path: str): 36 | # if `file_path` is a URL, redirect to the cache 37 | file_path = cached_path(file_path) 38 | 39 | with open(file_path, 'r') as conllu_file: 40 | logger.info("Reading UD instances from conllu dataset at: %s", file_path) 41 | 42 | for annotation in parse_incr(conllu_file): 43 | # CoNLLU annotations sometimes add back in words that have been elided 44 | # in the original sentence; we remove these, as we're just predicting 45 | # dependencies for the original sentence. 46 | # We filter by None here as elided words have a non-integer word id, 47 | # and we replace these word ids with None in process_multiword_tokens. 
48 | annotation = process_multiword_tokens(annotation) 49 | multiword_tokens = [x for x in annotation if x["multi_id"] is not None] 50 | annotation = [x for x in annotation if x["id"] is not None] 51 | 52 | if len(annotation) == 0: 53 | continue 54 | 55 | def get_field(tag: str, map_fn: Callable[[Any], Any] = None) -> List[Any]: 56 | map_fn = map_fn if map_fn is not None else lambda x: x 57 | return [map_fn(x[tag]) if x[tag] is not None else "_" for x in annotation if tag in x] 58 | 59 | # Extract multiword token rows (not used for prediction, purely for evaluation) 60 | ids = [x["id"] for x in annotation] 61 | multiword_ids = [x["multi_id"] for x in multiword_tokens] 62 | multiword_forms = [x["form"] for x in multiword_tokens] 63 | 64 | words = get_field("form") 65 | lemmas = get_field("lemma") 66 | lemma_rules = [gen_lemma_rule(word, lemma) 67 | if lemma != "_" else "_" 68 | for word, lemma in zip(words, lemmas)] 69 | upos_tags = get_field("upostag") 70 | xpos_tags = get_field("xpostag") 71 | feats = get_field("feats", lambda x: "|".join(k + "=" + v for k, v in x.items()) 72 | if hasattr(x, "items") else "_") 73 | heads = get_field("head") 74 | dep_rels = get_field("deprel") 75 | dependencies = list(zip(dep_rels, heads)) 76 | 77 | yield self.text_to_instance(words, lemmas, lemma_rules, upos_tags, xpos_tags, 78 | feats, dependencies, ids, multiword_ids, multiword_forms) 79 | 80 | @overrides 81 | def text_to_instance(self, # type: ignore 82 | words: List[str], 83 | lemmas: List[str] = None, 84 | lemma_rules: List[str] = None, 85 | upos_tags: List[str] = None, 86 | xpos_tags: List[str] = None, 87 | feats: List[str] = None, 88 | dependencies: List[Tuple[str, int]] = None, 89 | ids: List[str] = None, 90 | multiword_ids: List[str] = None, 91 | multiword_forms: List[str] = None) -> Instance: 92 | fields: Dict[str, Field] = {} 93 | 94 | tokens = TextField([Token(w) for w in words], self._token_indexers) 95 | fields["tokens"] = tokens 96 | 97 | names = ["upos", "xpos", "feats", "lemmas"] 98 | all_tags = [upos_tags, xpos_tags, feats, lemma_rules] 99 | for name, field in zip(names, all_tags): 100 | if field: 101 | fields[name] = SequenceLabelField(field, tokens, label_namespace=name) 102 | 103 | if dependencies is not None: 104 | # We don't want to expand the label namespace with an additional dummy token, so we'll 105 | # always give the 'ROOT_HEAD' token a label of 'root'. 
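# (Illustrative: for dependencies = [("nsubj", 2), ("root", 0)], the
# "head_tags" field holds ["nsubj", "root"] and "head_indices" holds
# [2, 0].)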
106 | fields["head_tags"] = SequenceLabelField([x[0] for x in dependencies], 107 | tokens, 108 | label_namespace="head_tags") 109 | fields["head_indices"] = SequenceLabelField([int(x[1]) for x in dependencies], 110 | tokens, 111 | label_namespace="head_index_tags") 112 | 113 | fields["metadata"] = MetadataField({ 114 | "words": words, 115 | "upos_tags": upos_tags, 116 | "xpos_tags": xpos_tags, 117 | "feats": feats, 118 | "lemmas": lemmas, 119 | "lemma_rules": lemma_rules, 120 | "ids": ids, 121 | "multiword_ids": multiword_ids, 122 | "multiword_forms": multiword_forms 123 | }) 124 | 125 | return Instance(fields) 126 | 127 | 128 | @DatasetReader.register("udify_universal_dependencies_raw") 129 | class UniversalDependenciesRawDatasetReader(DatasetReader): 130 | """Like UniversalDependenciesDatasetReader, but reads raw sentences and tokenizes them first.""" 131 | 132 | def __init__(self, 133 | dataset_reader: DatasetReader, 134 | tokenizer: WordSplitter = None) -> None: 135 | super().__init__(lazy=dataset_reader.lazy) 136 | self.dataset_reader = dataset_reader 137 | if tokenizer: 138 | self.tokenizer = tokenizer 139 | else: 140 | self.tokenizer = SpacyWordSplitter(language="xx_ent_wiki_sm") 141 | 142 | @overrides 143 | def _read(self, file_path: str): 144 | # if `file_path` is a URL, redirect to the cache 145 | file_path = cached_path(file_path) 146 | 147 | with open(file_path, 'r') as conllu_file: 148 | for sentence in conllu_file: 149 | if sentence: 150 | words = [word.text for word in self.tokenizer.split_words(sentence)] 151 | yield self.text_to_instance(words) 152 | 153 | @overrides 154 | def text_to_instance(self, words: List[str]) -> Instance: 155 | return self.dataset_reader.text_to_instance(words) -------------------------------------------------------------------------------- /udify/models/__init__.py: -------------------------------------------------------------------------------- 1 | from udify.models.udify_model import UdifyModel 2 | from udify.models.dependency_decoder import DependencyDecoder 3 | from udify.models.tag_decoder import TagDecoder 4 | -------------------------------------------------------------------------------- /udify/models/dependency_decoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Decodes dependency trees given a list of contextualized word embeddings 3 | """ 4 | 5 | from typing import Dict, Optional, Tuple, Any, List 6 | import logging 7 | import copy 8 | 9 | from overrides import overrides 10 | import torch 11 | import torch.nn.functional as F 12 | import numpy 13 | 14 | from allennlp.common.checks import check_dimensions_match 15 | from allennlp.data import Vocabulary 16 | from allennlp.modules import Embedding, InputVariationalDropout, Seq2SeqEncoder 17 | from allennlp.modules.matrix_attention.bilinear_matrix_attention import BilinearMatrixAttention 18 | from allennlp.modules import FeedForward 19 | from allennlp.models.model import Model 20 | from allennlp.nn import InitializerApplicator, RegularizerApplicator, Activation 21 | from allennlp.nn.util import get_range_vector 22 | from allennlp.nn.util import get_device_of, masked_log_softmax, get_lengths_from_binary_sequence_mask 23 | from allennlp.nn.chu_liu_edmonds import decode_mst 24 | from allennlp.training.metrics import AttachmentScores 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | POS_TO_IGNORE = {'``', "''", ':', ',', '.', 'PU', 'PUNCT', 'SYM'} 29 | 30 | 31 | @Model.register("udify_dependency_decoder") 32 | class 
DependencyDecoder(Model): 33 | """ 34 | Modifies BiaffineDependencyParser, removing the input TextFieldEmbedder dependency to allow the model to 35 | essentially act as a decoder when given intermediate word embeddings instead of as a standalone model. 36 | """ 37 | 38 | def __init__(self, 39 | vocab: Vocabulary, 40 | encoder: Seq2SeqEncoder, 41 | tag_representation_dim: int, 42 | arc_representation_dim: int, 43 | pos_embed_dim: int = None, 44 | tag_feedforward: FeedForward = None, 45 | arc_feedforward: FeedForward = None, 46 | use_mst_decoding_for_validation: bool = True, 47 | dropout: float = 0.0, 48 | initializer: InitializerApplicator = InitializerApplicator(), 49 | regularizer: Optional[RegularizerApplicator] = None) -> None: 50 | super(DependencyDecoder, self).__init__(vocab, regularizer) 51 | 52 | self.pos_tag_embedding = None 53 | if pos_embed_dim is not None: 54 | self.pos_tag_embedding = Embedding(self.vocab.get_vocab_size("upos"), pos_embed_dim) 55 | 56 | self.dropout = torch.nn.Dropout(p=dropout) 57 | 58 | self.encoder = encoder 59 | encoder_output_dim = encoder.get_output_dim() 60 | 61 | self.head_arc_feedforward = arc_feedforward or \ 62 | FeedForward(encoder_output_dim, 1, 63 | arc_representation_dim, 64 | Activation.by_name("elu")()) 65 | self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward) 66 | 67 | self.arc_attention = BilinearMatrixAttention(arc_representation_dim, 68 | arc_representation_dim, 69 | use_input_biases=True) 70 | 71 | num_labels = self.vocab.get_vocab_size("head_tags") 72 | 73 | self.head_tag_feedforward = tag_feedforward or \ 74 | FeedForward(encoder_output_dim, 1, 75 | tag_representation_dim, 76 | Activation.by_name("elu")()) 77 | self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward) 78 | 79 | self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim, 80 | tag_representation_dim, 81 | num_labels) 82 | 83 | self._dropout = InputVariationalDropout(dropout) 84 | self._head_sentinel = torch.nn.Parameter(torch.randn([1, 1, encoder_output_dim])) 85 | 86 | check_dimensions_match(tag_representation_dim, self.head_tag_feedforward.get_output_dim(), 87 | "tag representation dim", "tag feedforward output dim") 88 | check_dimensions_match(arc_representation_dim, self.head_arc_feedforward.get_output_dim(), 89 | "arc representation dim", "arc feedforward output dim") 90 | 91 | self.use_mst_decoding_for_validation = use_mst_decoding_for_validation 92 | 93 | tags = self.vocab.get_token_to_index_vocabulary("pos") 94 | punctuation_tag_indices = {tag: index for tag, index in tags.items() if tag in POS_TO_IGNORE} 95 | self._pos_to_ignore = set(punctuation_tag_indices.values()) 96 | logger.info(f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. 
" 97 | "Ignoring words with these POS tags for evaluation.") 98 | 99 | self._attachment_scores = AttachmentScores() 100 | initializer(self) 101 | 102 | @overrides 103 | def forward(self, # type: ignore 104 | # words: Dict[str, torch.LongTensor], 105 | encoded_text: torch.FloatTensor, 106 | mask: torch.LongTensor, 107 | pos_logits: torch.LongTensor = None, # predicted 108 | head_tags: torch.LongTensor = None, 109 | head_indices: torch.LongTensor = None, 110 | metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: 111 | 112 | batch_size, _, _ = encoded_text.size() 113 | 114 | pos_tags = None 115 | if pos_logits is not None and self.pos_tag_embedding is not None: 116 | # Embed the predicted POS tags and concatenate the embeddings to the input 117 | num_pos_classes = pos_logits.size(-1) 118 | pos_logits = pos_logits.view(-1, num_pos_classes) 119 | _, pos_tags = pos_logits.max(-1) 120 | 121 | pos_embed_size = self.pos_tag_embedding.get_output_dim() 122 | embedded_pos_tags = self.dropout(self.pos_tag_embedding(pos_tags)) 123 | embedded_pos_tags = embedded_pos_tags.view(batch_size, -1, pos_embed_size) 124 | encoded_text = torch.cat([encoded_text, embedded_pos_tags], -1) 125 | 126 | encoded_text = self.encoder(encoded_text, mask) 127 | 128 | batch_size, _, encoding_dim = encoded_text.size() 129 | 130 | head_sentinel = self._head_sentinel.expand(batch_size, 1, encoding_dim) 131 | # Concatenate the head sentinel onto the sentence representation. 132 | encoded_text = torch.cat([head_sentinel, encoded_text], 1) 133 | mask = torch.cat([mask.new_ones(batch_size, 1), mask], 1) 134 | if head_indices is not None: 135 | head_indices = torch.cat([head_indices.new_zeros(batch_size, 1), head_indices], 1) 136 | if head_tags is not None: 137 | head_tags = torch.cat([head_tags.new_zeros(batch_size, 1), head_tags], 1) 138 | float_mask = mask.float() 139 | encoded_text = self._dropout(encoded_text) 140 | 141 | # shape (batch_size, sequence_length, arc_representation_dim) 142 | head_arc_representation = self._dropout(self.head_arc_feedforward(encoded_text)) 143 | child_arc_representation = self._dropout(self.child_arc_feedforward(encoded_text)) 144 | 145 | # shape (batch_size, sequence_length, tag_representation_dim) 146 | head_tag_representation = self._dropout(self.head_tag_feedforward(encoded_text)) 147 | child_tag_representation = self._dropout(self.child_tag_feedforward(encoded_text)) 148 | # shape (batch_size, sequence_length, sequence_length) 149 | attended_arcs = self.arc_attention(head_arc_representation, 150 | child_arc_representation) 151 | 152 | minus_inf = -1e8 153 | minus_mask = (1 - float_mask) * minus_inf 154 | attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1) 155 | 156 | if self.training or not self.use_mst_decoding_for_validation: 157 | predicted_heads, predicted_head_tags = self._greedy_decode(head_tag_representation, 158 | child_tag_representation, 159 | attended_arcs, 160 | mask) 161 | else: 162 | predicted_heads, predicted_head_tags = self._mst_decode(head_tag_representation, 163 | child_tag_representation, 164 | attended_arcs, 165 | mask) 166 | if head_indices is not None and head_tags is not None: 167 | 168 | arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation, 169 | child_tag_representation=child_tag_representation, 170 | attended_arcs=attended_arcs, 171 | head_indices=head_indices, 172 | head_tags=head_tags, 173 | mask=mask) 174 | loss = arc_nll + tag_nll 175 | 176 | evaluation_mask = 
self._get_mask_for_eval(mask[:, 1:], pos_tags) 177 | # We calculate attachment scores for the whole sentence 178 | # but excluding the symbolic ROOT token at the start, 179 | # which is why we start from the second element in the sequence. 180 | self._attachment_scores(predicted_heads[:, 1:], 181 | predicted_head_tags[:, 1:], 182 | head_indices[:, 1:], 183 | head_tags[:, 1:], 184 | evaluation_mask) 185 | else: 186 | arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation, 187 | child_tag_representation=child_tag_representation, 188 | attended_arcs=attended_arcs, 189 | head_indices=predicted_heads.long(), 190 | head_tags=predicted_head_tags.long(), 191 | mask=mask) 192 | loss = arc_nll + tag_nll 193 | 194 | output_dict = { 195 | "heads": predicted_heads, 196 | "head_tags": predicted_head_tags, 197 | "arc_loss": arc_nll, 198 | "tag_loss": tag_nll, 199 | "loss": loss, 200 | "mask": mask, 201 | "words": [meta["words"] for meta in metadata], 202 | # "pos": [meta["pos"] for meta in metadata] 203 | } 204 | 205 | return output_dict 206 | 207 | @overrides 208 | def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 209 | 210 | head_tags = output_dict.pop("head_tags").cpu().detach().numpy() 211 | heads = output_dict.pop("heads").cpu().detach().numpy() 212 | mask = output_dict.pop("mask") 213 | lengths = get_lengths_from_binary_sequence_mask(mask) 214 | head_tag_labels = [] 215 | head_indices = [] 216 | for instance_heads, instance_tags, length in zip(heads, head_tags, lengths): 217 | instance_heads = list(instance_heads[1:length]) 218 | instance_tags = instance_tags[1:length] 219 | labels = [self.vocab.get_token_from_index(label, "head_tags") 220 | for label in instance_tags] 221 | head_tag_labels.append(labels) 222 | head_indices.append(instance_heads) 223 | 224 | output_dict["predicted_dependencies"] = head_tag_labels 225 | output_dict["predicted_heads"] = head_indices 226 | return output_dict 227 | 228 | def _construct_loss(self, 229 | head_tag_representation: torch.Tensor, 230 | child_tag_representation: torch.Tensor, 231 | attended_arcs: torch.Tensor, 232 | head_indices: torch.Tensor, 233 | head_tags: torch.Tensor, 234 | mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 235 | """ 236 | Computes the arc and tag loss for a sequence given gold head indices and tags. 237 | Parameters 238 | ---------- 239 | head_tag_representation : ``torch.Tensor``, required. 240 | A tensor of shape (batch_size, sequence_length, tag_representation_dim), 241 | which will be used to generate predictions for the dependency tags 242 | for the given arcs. 243 | child_tag_representation : ``torch.Tensor``, required 244 | A tensor of shape (batch_size, sequence_length, tag_representation_dim), 245 | which will be used to generate predictions for the dependency tags 246 | for the given arcs. 247 | attended_arcs : ``torch.Tensor``, required. 248 | A tensor of shape (batch_size, sequence_length, sequence_length) used to generate 249 | a distribution over attachments of a given word to all other words. 250 | head_indices : ``torch.Tensor``, required. 251 | A tensor of shape (batch_size, sequence_length). 252 | The indices of the heads for every word. 253 | head_tags : ``torch.Tensor``, required. 254 | A tensor of shape (batch_size, sequence_length). 255 | The dependency labels of the heads for every word. 256 | mask : ``torch.Tensor``, required. 257 | A mask of shape (batch_size, sequence_length), denoting unpadded 258 | elements in the sequence. 
259 |         Returns
260 |         -------
261 |         arc_nll : ``torch.Tensor``, required.
262 |             The negative log likelihood from the arc loss.
263 |         tag_nll : ``torch.Tensor``, required.
264 |             The negative log likelihood from the arc tag loss.
265 |         """
266 |         float_mask = mask.float()
267 |         batch_size, sequence_length, _ = attended_arcs.size()
268 |         # shape (batch_size, 1)
269 |         range_vector = get_range_vector(batch_size, get_device_of(attended_arcs)).unsqueeze(1)
270 |         # shape (batch_size, sequence_length, sequence_length)
271 |         normalised_arc_logits = masked_log_softmax(attended_arcs,
272 |                                                    mask) * float_mask.unsqueeze(2) * float_mask.unsqueeze(1)
273 | 
274 |         # shape (batch_size, sequence_length, num_head_tags)
275 |         head_tag_logits = self._get_head_tags(head_tag_representation, child_tag_representation, head_indices)
276 |         normalised_head_tag_logits = masked_log_softmax(head_tag_logits,
277 |                                                         mask.unsqueeze(-1)) * float_mask.unsqueeze(-1)
278 |         # index matrix with shape (batch, sequence_length)
279 |         timestep_index = get_range_vector(sequence_length, get_device_of(attended_arcs))
280 |         child_index = timestep_index.view(1, sequence_length).expand(batch_size, sequence_length).long()
281 |         # shape (batch_size, sequence_length)
282 |         arc_loss = normalised_arc_logits[range_vector, child_index, head_indices]
283 |         tag_loss = normalised_head_tag_logits[range_vector, child_index, head_tags]
284 |         # We don't care about predictions for the symbolic ROOT token's head,
285 |         # so we remove it from the loss.
286 |         arc_loss = arc_loss[:, 1:]
287 |         tag_loss = tag_loss[:, 1:]
288 | 
289 |         # The number of valid positions is equal to the number of unmasked elements minus
290 |         # 1 per sequence in the batch, to account for the symbolic HEAD token.
291 |         valid_positions = mask.sum() - batch_size
292 | 
293 |         arc_nll = -arc_loss.sum() / valid_positions.float()
294 |         tag_nll = -tag_loss.sum() / valid_positions.float()
295 |         return arc_nll, tag_nll
296 | 
297 |     def _greedy_decode(self,
298 |                        head_tag_representation: torch.Tensor,
299 |                        child_tag_representation: torch.Tensor,
300 |                        attended_arcs: torch.Tensor,
301 |                        mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
302 |         """
303 |         Decodes the head and head tag predictions by decoding the unlabeled arcs
304 |         independently for each word and then predicting the head tags of these
305 |         greedily chosen arcs, again independently. Note that this method of decoding
306 |         is not guaranteed to produce trees (i.e. there may be multiple roots,
307 |         or cycles when children are attached to their parents).
308 |         Parameters
309 |         ----------
310 |         head_tag_representation : ``torch.Tensor``, required.
311 |             A tensor of shape (batch_size, sequence_length, tag_representation_dim),
312 |             which will be used to generate predictions for the dependency tags
313 |             for the given arcs.
314 |         child_tag_representation : ``torch.Tensor``, required
315 |             A tensor of shape (batch_size, sequence_length, tag_representation_dim),
316 |             which will be used to generate predictions for the dependency tags
317 |             for the given arcs.
318 |         attended_arcs : ``torch.Tensor``, required.
319 |             A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
320 |             a distribution over attachments of a given word to all other words.
321 |         Returns
322 |         -------
323 |         heads : ``torch.Tensor``
324 |             A tensor of shape (batch_size, sequence_length) representing the
325 |             greedily decoded heads of each word.
326 |         head_tags : ``torch.Tensor``
327 |             A tensor of shape (batch_size, sequence_length) representing the
328 |             dependency tags of the greedily decoded heads of each word.
329 |         """
330 |         # Mask the diagonal, because the head of a word can't be itself.
331 |         attended_arcs = attended_arcs + torch.diag(attended_arcs.new(mask.size(1)).fill_(-numpy.inf))
332 |         # Mask padded tokens, because we only want to consider actual words as heads.
333 |         if mask is not None:
334 |             minus_mask = (1 - mask).bool().unsqueeze(2)
335 |             attended_arcs.masked_fill_(minus_mask, -numpy.inf)
336 | 
337 |         # Compute the heads greedily.
338 |         # shape (batch_size, sequence_length)
339 |         _, heads = attended_arcs.max(dim=2)
340 | 
341 |         # Given the greedily predicted heads, decode their dependency tags.
342 |         # shape (batch_size, sequence_length, num_head_tags)
343 |         head_tag_logits = self._get_head_tags(head_tag_representation,
344 |                                               child_tag_representation,
345 |                                               heads)
346 |         _, head_tags = head_tag_logits.max(dim=2)
347 |         return heads, head_tags
348 | 
349 |     def _mst_decode(self,
350 |                     head_tag_representation: torch.Tensor,
351 |                     child_tag_representation: torch.Tensor,
352 |                     attended_arcs: torch.Tensor,
353 |                     mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
354 |         """
355 |         Decodes the head and head tag predictions using the Chu-Liu/Edmonds algorithm
356 |         for finding maximum spanning trees (arborescences) on directed graphs. Nodes in the
357 |         graph are the words in the sentence, and between each pair of nodes,
358 |         there is an edge in each direction, where the weight of the edge corresponds
359 |         to the most likely dependency label probability for that arc. The maximum
360 |         spanning tree is then generated from this directed graph.
361 |         Parameters
362 |         ----------
363 |         head_tag_representation : ``torch.Tensor``, required.
364 |             A tensor of shape (batch_size, sequence_length, tag_representation_dim),
365 |             which will be used to generate predictions for the dependency tags
366 |             for the given arcs.
367 |         child_tag_representation : ``torch.Tensor``, required
368 |             A tensor of shape (batch_size, sequence_length, tag_representation_dim),
369 |             which will be used to generate predictions for the dependency tags
370 |             for the given arcs.
371 |         attended_arcs : ``torch.Tensor``, required.
372 |             A tensor of shape (batch_size, sequence_length, sequence_length) used to generate
373 |             a distribution over attachments of a given word to all other words.
374 |         Returns
375 |         -------
376 |         heads : ``torch.Tensor``
377 |             A tensor of shape (batch_size, sequence_length) representing the
378 |             optimally decoded heads of each word.
379 |         head_tags : ``torch.Tensor``
380 |             A tensor of shape (batch_size, sequence_length) representing the
381 |             dependency tags of the optimally decoded heads of each word.
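        Example (an illustrative sketch, not executed; the values are hypothetical)::

            # batch_energy: (1, num_head_tags, 3, 3) for ROOT + 2 words, where
            # energy[0, k, i, j] scores "node i heads node j with label k"
            heads, head_tags = self._run_mst_decoding(batch_energy, lengths)
            # heads might come back as [[0, 2, 0]]: word 1 attaches to word 2,
            # word 2 attaches to ROOT, and the ROOT slot itself is pinned to 0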
382 |         """
383 |         batch_size, sequence_length, tag_representation_dim = head_tag_representation.size()
384 | 
385 |         lengths = mask.data.sum(dim=1).long().cpu().numpy()
386 | 
387 |         expanded_shape = [batch_size, sequence_length, sequence_length, tag_representation_dim]
388 |         head_tag_representation = head_tag_representation.unsqueeze(2)
389 |         head_tag_representation = head_tag_representation.expand(*expanded_shape).contiguous()
390 |         child_tag_representation = child_tag_representation.unsqueeze(1)
391 |         child_tag_representation = child_tag_representation.expand(*expanded_shape).contiguous()
392 |         # Shape (batch_size, sequence_length, sequence_length, num_head_tags)
393 |         pairwise_head_logits = self.tag_bilinear(head_tag_representation, child_tag_representation)
394 | 
395 |         # Note that this log_softmax is over the tag dimension, and we don't consider pairs
396 |         # of tags which are invalid (e.g. are a pair which includes a padded element) anyway below.
397 |         # Shape (batch, num_labels, sequence_length, sequence_length)
398 |         normalized_pairwise_head_logits = F.log_softmax(pairwise_head_logits, dim=3).permute(0, 3, 1, 2)
399 | 
400 |         # Mask padded tokens, because we only want to consider actual words as heads.
401 |         minus_inf = -1e8
402 |         minus_mask = (1 - mask.float()) * minus_inf
403 |         attended_arcs = attended_arcs + minus_mask.unsqueeze(2) + minus_mask.unsqueeze(1)
404 | 
405 |         # Shape (batch_size, sequence_length, sequence_length)
406 |         normalized_arc_logits = F.log_softmax(attended_arcs, dim=2).transpose(1, 2)
407 | 
408 |         # Shape (batch_size, num_head_tags, sequence_length, sequence_length)
409 |         # This energy tensor expresses the following relation:
410 |         # energy[i,j] = "Score that i is the head of j". In this
411 |         # case, we have heads pointing to their children.
412 |         batch_energy = torch.exp(normalized_arc_logits.unsqueeze(1) + normalized_pairwise_head_logits)
413 |         return self._run_mst_decoding(batch_energy, lengths)
414 | 
415 |     @staticmethod
416 |     def _run_mst_decoding(batch_energy: torch.Tensor, lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
417 |         heads = []
418 |         head_tags = []
419 |         for energy, length in zip(batch_energy.detach().cpu(), lengths):
420 |             scores, tag_ids = energy.max(dim=0)
421 |             # Although we need to include the root node so that the MST includes it,
422 |             # we do not want any word to be the parent of the root node.
423 |             # Here, we enforce this by setting the scores for all word -> ROOT
424 |             # edges to be 0.
425 |             scores[0, :] = 0
426 |             # Decode the heads. Because we modify the scores to prevent
427 |             # adding in word -> ROOT edges, we need to find the labels ourselves.
428 |             instance_heads, _ = decode_mst(scores.numpy(), length, has_labels=False)
429 | 
430 |             # Find the labels which correspond to the edges in the max spanning tree.
431 |             instance_head_tags = []
432 |             for child, parent in enumerate(instance_heads):
433 |                 instance_head_tags.append(tag_ids[parent, child].item())
434 |             # We don't care what the head or tag is for the root token, but by default it's
435 |             # not necessarily the same in the batched vs unbatched case, which is annoying.
436 |             # Here we'll just set them to zero.
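            # e.g. (illustrative): for ROOT + 2 words, decode_mst might return
            # instance_heads = [1, 0, 1]; the leading ROOT entry is arbitrary,
            # so pinning it (and its tag) to 0 below yields [0, 0, 1].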
437 |             instance_heads[0] = 0
438 |             instance_head_tags[0] = 0
439 |             heads.append(instance_heads)
440 |             head_tags.append(instance_head_tags)
441 |         return torch.from_numpy(numpy.stack(heads)), torch.from_numpy(numpy.stack(head_tags))
442 | 
443 |     def _get_head_tags(self,
444 |                        head_tag_representation: torch.Tensor,
445 |                        child_tag_representation: torch.Tensor,
446 |                        head_indices: torch.Tensor) -> torch.Tensor:
447 |         """
448 |         Decodes the head tags given the head and child tag representations
449 |         and a tensor of head indices to compute tags for. Note that these are
450 |         either gold or predicted heads, depending on whether this function is
451 |         being called to compute the loss, or during inference.
452 |         Parameters
453 |         ----------
454 |         head_tag_representation : ``torch.Tensor``, required.
455 |             A tensor of shape (batch_size, sequence_length, tag_representation_dim),
456 |             which will be used to generate predictions for the dependency tags
457 |             for the given arcs.
458 |         child_tag_representation : ``torch.Tensor``, required
459 |             A tensor of shape (batch_size, sequence_length, tag_representation_dim),
460 |             which will be used to generate predictions for the dependency tags
461 |             for the given arcs.
462 |         head_indices : ``torch.Tensor``, required.
463 |             A tensor of shape (batch_size, sequence_length). The indices of the heads
464 |             for every word.
465 |         Returns
466 |         -------
467 |         head_tag_logits : ``torch.Tensor``
468 |             A tensor of shape (batch_size, sequence_length, num_head_tags),
469 |             representing logits for predicting a distribution over tags
470 |             for each arc.
471 |         """
472 |         batch_size = head_tag_representation.size(0)
473 |         # shape (batch_size,)
474 |         range_vector = get_range_vector(batch_size, get_device_of(head_tag_representation)).unsqueeze(1)
475 | 
476 |         # This next statement is quite a complex piece of indexing, which you really
477 |         # need to read the docs to understand. See here:
478 |         # https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.indexing.html#advanced-indexing
479 |         # In effect, we are selecting the indices corresponding to the heads of each word from the
480 |         # sequence length dimension for each element in the batch.
481 | 
482 |         # shape (batch_size, sequence_length, tag_representation_dim)
483 |         selected_head_tag_representations = head_tag_representation[range_vector, head_indices]
484 |         selected_head_tag_representations = selected_head_tag_representations.contiguous()
485 |         # shape (batch_size, sequence_length, num_head_tags)
486 |         head_tag_logits = self.tag_bilinear(selected_head_tag_representations,
487 |                                             child_tag_representation)
488 |         return head_tag_logits
489 | 
490 |     def _get_mask_for_eval(self,
491 |                            mask: torch.LongTensor,
492 |                            pos_tags: torch.LongTensor) -> torch.LongTensor:
493 |         """
494 |         Dependency evaluation excludes words that are punctuation.
495 |         Here, we create a new mask to exclude word indices which
496 |         have a "punctuation-like" part of speech tag.
497 |         Parameters
498 |         ----------
499 |         mask : ``torch.LongTensor``, required.
500 |             The original mask.
501 |         pos_tags : ``torch.LongTensor``, required.
502 |             The pos tags for the sequence.
503 |         Returns
504 |         -------
505 |         A new mask, where any indices equal to labels
506 |         we should be ignoring are masked.
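        For example (illustrative), if ``pos_tags`` decodes to ``[DET, NOUN, PUNCT]``
        and ``PUNCT`` is among the ignored tags, a mask of ``[1, 1, 1]`` becomes
        ``[1, 1, 0]``.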
507 |         """
508 |         new_mask = mask.detach()
509 |         for label in self._pos_to_ignore:
510 |             label_mask = pos_tags.eq(label).long()
511 |             new_mask = new_mask * (1 - label_mask)
512 |         return new_mask
513 | 
514 |     @overrides
515 |     def get_metrics(self, reset: bool = False) -> Dict[str, float]:
516 |         return {f".run/deps/{metric_name}": metric
517 |                 for metric_name, metric in self._attachment_scores.get_metric(reset).items()}
518 | 
-------------------------------------------------------------------------------- /udify/models/tag_decoder.py: --------------------------------------------------------------------------------
1 | """
2 | Decodes sequences of tags, e.g., POS tags, given a list of contextualized word embeddings
3 | """
4 | 
5 | from typing import Optional, Any, Dict, List
6 | from overrides import overrides
7 | 
8 | import numpy
9 | import torch
10 | from torch.nn.modules.linear import Linear
11 | from torch.nn.modules.adaptive import AdaptiveLogSoftmaxWithLoss
12 | import torch.nn.functional as F
13 | 
14 | from allennlp.data import Vocabulary
15 | from allennlp.modules import TimeDistributed, Seq2SeqEncoder
16 | from allennlp.models.model import Model
17 | from allennlp.nn import InitializerApplicator, RegularizerApplicator
18 | from allennlp.nn.util import sequence_cross_entropy_with_logits
19 | from allennlp.training.metrics import CategoricalAccuracy
20 | 
21 | from udify.dataset_readers.lemma_edit import apply_lemma_rule
22 | 
23 | 
24 | def sequence_cross_entropy(log_probs: torch.FloatTensor,
25 |                            targets: torch.LongTensor,
26 |                            weights: torch.FloatTensor,
27 |                            average: str = "batch",
28 |                            label_smoothing: float = None) -> torch.FloatTensor:
29 |     if average not in {None, "token", "batch"}:
30 |         raise ValueError(f"Got average {average}, expected one of "
31 |                          "None, 'token', or 'batch'")
32 |     # shape : (batch * sequence_length, num_classes)
33 |     log_probs_flat = log_probs.view(-1, log_probs.size(2))
34 |     # shape : (batch * max_len, 1)
35 |     targets_flat = targets.view(-1, 1).long()
36 | 
37 |     if label_smoothing is not None and label_smoothing > 0.0:
38 |         num_classes = log_probs.size(-1)
39 |         smoothing_value = label_smoothing / num_classes
40 |         # Fill all the correct indices with 1 - smoothing value.
41 |         one_hot_targets = torch.zeros_like(log_probs_flat).scatter_(-1, targets_flat, 1.0 - label_smoothing)
42 |         smoothed_targets = one_hot_targets + smoothing_value
43 |         negative_log_likelihood_flat = - log_probs_flat * smoothed_targets
44 |         negative_log_likelihood_flat = negative_log_likelihood_flat.sum(-1, keepdim=True)
45 |     else:
46 |         # Contribution to the negative log likelihood only comes from the exact indices
47 |         # of the targets, as the target distributions are one-hot. Here we use torch.gather
48 |         # to extract the indices of the num_classes dimension which contribute to the loss.
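        # Illustrative values (not part of the computation): with
        # log_probs_flat = [[-0.1, -2.3], [-1.6, -0.2]] and targets_flat = [[0], [1]],
        # gather selects [[-0.1], [-0.2]], and negating yields the per-token NLL.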
49 | # shape : (batch * sequence_length, 1) 50 | negative_log_likelihood_flat = - torch.gather(log_probs_flat, dim=1, index=targets_flat) 51 | # shape : (batch, sequence_length) 52 | negative_log_likelihood = negative_log_likelihood_flat.view(*targets.size()) 53 | # shape : (batch, sequence_length) 54 | negative_log_likelihood = negative_log_likelihood * weights.float() 55 | 56 | if average == "batch": 57 | # shape : (batch_size,) 58 | per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13) 59 | num_non_empty_sequences = ((weights.sum(1) > 0).float().sum() + 1e-13) 60 | return per_batch_loss.sum() / num_non_empty_sequences 61 | elif average == "token": 62 | return negative_log_likelihood.sum() / (weights.sum().float() + 1e-13) 63 | else: 64 | # shape : (batch_size,) 65 | per_batch_loss = negative_log_likelihood.sum(1) / (weights.sum(1).float() + 1e-13) 66 | return per_batch_loss 67 | 68 | 69 | @Model.register("udify_tag_decoder") 70 | class TagDecoder(Model): 71 | """ 72 | A basic sequence tagger that decodes from inputs of word embeddings 73 | """ 74 | def __init__(self, 75 | vocab: Vocabulary, 76 | task: str, 77 | encoder: Seq2SeqEncoder, 78 | label_smoothing: float = 0.0, 79 | dropout: float = 0.0, 80 | adaptive: bool = False, 81 | features: List[str] = None, 82 | initializer: InitializerApplicator = InitializerApplicator(), 83 | regularizer: Optional[RegularizerApplicator] = None) -> None: 84 | super(TagDecoder, self).__init__(vocab, regularizer) 85 | 86 | self.task = task 87 | self.encoder = encoder 88 | self.output_dim = encoder.get_output_dim() 89 | self.label_smoothing = label_smoothing 90 | self.num_classes = self.vocab.get_vocab_size(task) 91 | self.adaptive = adaptive 92 | self.features = features if features else [] 93 | 94 | self.metrics = { 95 | "acc": CategoricalAccuracy(), 96 | # "acc3": CategoricalAccuracy(top_k=3) 97 | } 98 | 99 | if self.adaptive: 100 | # TODO 101 | adaptive_cutoffs = [round(self.num_classes / 15), 3 * round(self.num_classes / 15)] 102 | self.task_output = AdaptiveLogSoftmaxWithLoss(self.output_dim, 103 | self.num_classes, 104 | cutoffs=adaptive_cutoffs, 105 | div_value=4.0) 106 | else: 107 | self.task_output = TimeDistributed(Linear(self.output_dim, self.num_classes)) 108 | 109 | self.feature_outputs = torch.nn.ModuleDict() 110 | self.features_metrics = {} 111 | for feature in self.features: 112 | self.feature_outputs[feature] = TimeDistributed(Linear(self.output_dim, 113 | vocab.get_vocab_size(feature))) 114 | self.features_metrics[feature] = { 115 | "acc": CategoricalAccuracy(), 116 | } 117 | 118 | initializer(self) 119 | 120 | @overrides 121 | def forward(self, 122 | encoded_text: torch.FloatTensor, 123 | mask: torch.LongTensor, 124 | gold_tags: Dict[str, torch.LongTensor], 125 | metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]: 126 | hidden = encoded_text 127 | hidden = self.encoder(hidden, mask) 128 | 129 | batch_size, sequence_length, _ = hidden.size() 130 | output_dim = [batch_size, sequence_length, self.num_classes] 131 | 132 | loss_fn = self._adaptive_loss if self.adaptive else self._loss 133 | 134 | output_dict = loss_fn(hidden, mask, gold_tags.get(self.task, None), output_dim) 135 | self._features_loss(hidden, mask, gold_tags, output_dict) 136 | 137 | return output_dict 138 | 139 | def _adaptive_loss(self, hidden, mask, gold_tags, output_dim): 140 | logits = hidden 141 | reshaped_log_probs = logits.view(-1, logits.size(2)) 142 | 143 | class_probabilities = 
self.task_output.log_prob(reshaped_log_probs).view(output_dim) 144 | 145 | output_dict = {"logits": logits, "class_probabilities": class_probabilities} 146 | 147 | if gold_tags is not None: 148 | output_dict["loss"] = sequence_cross_entropy(class_probabilities, 149 | gold_tags, 150 | mask, 151 | label_smoothing=self.label_smoothing) 152 | for metric in self.metrics.values(): 153 | metric(class_probabilities, gold_tags, mask.float()) 154 | 155 | return output_dict 156 | 157 | def _loss(self, hidden, mask, gold_tags, output_dim): 158 | logits = self.task_output(hidden) 159 | reshaped_log_probs = logits.view(-1, self.num_classes) 160 | class_probabilities = F.softmax(reshaped_log_probs, dim=-1).view(output_dim) 161 | 162 | output_dict = {"logits": logits, "class_probabilities": class_probabilities} 163 | 164 | if gold_tags is not None: 165 | output_dict["loss"] = sequence_cross_entropy_with_logits(logits, 166 | gold_tags, 167 | mask, 168 | label_smoothing=self.label_smoothing) 169 | for metric in self.metrics.values(): 170 | metric(logits, gold_tags, mask.float()) 171 | 172 | return output_dict 173 | 174 | def _features_loss(self, hidden, mask, gold_tags, output_dict): 175 | if gold_tags is None: 176 | return 177 | 178 | for feature in self.features: 179 | logits = self.feature_outputs[feature](hidden) 180 | loss = sequence_cross_entropy_with_logits(logits, 181 | gold_tags[feature], 182 | mask, 183 | label_smoothing=self.label_smoothing) 184 | loss /= len(self.features) 185 | output_dict["loss"] += loss 186 | 187 | for metric in self.features_metrics[feature].values(): 188 | metric(logits, gold_tags[feature], mask.float()) 189 | 190 | @overrides 191 | def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 192 | all_words = output_dict["words"] 193 | 194 | all_predictions = output_dict["class_probabilities"][self.task].cpu().data.numpy() 195 | if all_predictions.ndim == 3: 196 | predictions_list = [all_predictions[i] for i in range(all_predictions.shape[0])] 197 | else: 198 | predictions_list = [all_predictions] 199 | all_tags = [] 200 | for predictions, words in zip(predictions_list, all_words): 201 | argmax_indices = numpy.argmax(predictions, axis=-1) 202 | tags = [self.vocab.get_token_from_index(x, namespace=self.task) 203 | for x in argmax_indices] 204 | 205 | # TODO: specific task 206 | if self.task == "lemmas": 207 | def decode_lemma(word, rule): 208 | if rule == "_": 209 | return "_" 210 | if rule == "@@UNKNOWN@@": 211 | return word 212 | return apply_lemma_rule(word, rule) 213 | tags = [decode_lemma(word, rule) for word, rule in zip(words, tags)] 214 | 215 | all_tags.append(tags) 216 | output_dict[self.task] = all_tags 217 | 218 | return output_dict 219 | 220 | @overrides 221 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 222 | main_metrics = { 223 | f".run/{self.task}/{metric_name}": metric.get_metric(reset) 224 | for metric_name, metric in self.metrics.items() 225 | } 226 | 227 | features_metrics = { 228 | f"_run/{self.task}/{feature}/{metric_name}": metric.get_metric(reset) 229 | for feature in self.features 230 | for metric_name, metric in self.features_metrics[feature].items() 231 | } 232 | 233 | return {**main_metrics, **features_metrics} 234 | -------------------------------------------------------------------------------- /udify/models/udify_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | The base UDify model for training and prediction 3 | """ 4 | 5 | from typing import 
Optional, Any, Dict, List, Tuple
6 | from overrides import overrides
7 | import logging
8 | 
9 | import torch
10 | 
11 | from pytorch_pretrained_bert.tokenization import BertTokenizer
12 | 
13 | from allennlp.common.checks import check_dimensions_match, ConfigurationError
14 | from allennlp.data import Vocabulary
15 | from allennlp.modules import Seq2SeqEncoder, TextFieldEmbedder
16 | from allennlp.models.model import Model
17 | from allennlp.nn import InitializerApplicator, RegularizerApplicator
18 | from allennlp.nn.util import get_text_field_mask
19 | 
20 | from udify.modules.scalar_mix import ScalarMixWithDropout
21 | 
22 | logger = logging.getLogger(__name__)
23 | 
24 | 
25 | @Model.register("udify_model")
26 | class UdifyModel(Model):
27 |     """
28 |     The UDify model base class. Applies a sequence of shared encoders before decoding in a multi-task configuration.
29 |     Uses TagDecoder and DependencyDecoder to decode each UD task.
30 |     """
31 | 
32 |     def __init__(self,
33 |                  vocab: Vocabulary,
34 |                  tasks: List[str],
35 |                  text_field_embedder: TextFieldEmbedder,
36 |                  encoder: Seq2SeqEncoder,
37 |                  decoders: Dict[str, Model],
38 |                  post_encoder_embedder: TextFieldEmbedder = None,
39 |                  dropout: float = 0.0,
40 |                  word_dropout: float = 0.0,
41 |                  mix_embedding: int = None,
42 |                  layer_dropout: float = 0.0,
43 |                  initializer: InitializerApplicator = InitializerApplicator(),
44 |                  regularizer: Optional[RegularizerApplicator] = None) -> None:
45 |         super(UdifyModel, self).__init__(vocab, regularizer)
46 | 
47 |         self.tasks = sorted(tasks)
48 |         self.vocab = vocab
49 |         self.bert_vocab = BertTokenizer.from_pretrained("config/archive/bert-base-multilingual-cased/vocab.txt").vocab
50 |         self.text_field_embedder = text_field_embedder
51 |         self.post_encoder_embedder = post_encoder_embedder
52 |         self.shared_encoder = encoder
53 |         self.word_dropout = word_dropout
54 |         self.dropout = torch.nn.Dropout(p=dropout)
55 |         self.decoders = torch.nn.ModuleDict(decoders)
56 | 
57 |         if mix_embedding:
58 |             self.scalar_mix = torch.nn.ModuleDict({
59 |                 task: ScalarMixWithDropout(mix_embedding,
60 |                                            do_layer_norm=False,
61 |                                            dropout=layer_dropout)
62 |                 for task in self.decoders
63 |             })
64 |         else:
65 |             self.scalar_mix = None
66 | 
67 |         self.metrics = {}
68 | 
69 |         for task in self.tasks:
70 |             if task not in self.decoders:
71 |                 raise ConfigurationError(f"Task {task} has no corresponding decoder. 
Make sure their names match.") 72 | 73 | check_dimensions_match(text_field_embedder.get_output_dim(), encoder.get_input_dim(), 74 | "text field embedding dim", "encoder input dim") 75 | 76 | initializer(self) 77 | self._count_params() 78 | 79 | @overrides 80 | def forward(self, 81 | tokens: Dict[str, torch.LongTensor], 82 | metadata: List[Dict[str, Any]] = None, 83 | **kwargs: Dict[str, torch.LongTensor]) -> Dict[str, torch.Tensor]: 84 | if "track_epoch" in kwargs: 85 | track_epoch = kwargs.pop("track_epoch") 86 | 87 | gold_tags = kwargs 88 | 89 | if "tokens" in self.tasks: 90 | # Model is predicting tokens, so add them to the gold tags 91 | gold_tags["tokens"] = tokens["tokens"] 92 | 93 | mask = get_text_field_mask(tokens) 94 | self._apply_token_dropout(tokens) 95 | 96 | embedded_text_input = self.text_field_embedder(tokens) 97 | 98 | if self.post_encoder_embedder: 99 | post_embeddings = self.post_encoder_embedder(tokens) 100 | 101 | encoded_text = self.shared_encoder(embedded_text_input, mask) 102 | 103 | logits = {} 104 | class_probabilities = {} 105 | output_dict = {"logits": logits, 106 | "class_probabilities": class_probabilities} 107 | loss = 0 108 | 109 | # Run through each of the tasks on the shared encoder and save predictions 110 | for task in self.tasks: 111 | if self.scalar_mix: 112 | decoder_input = self.scalar_mix[task](encoded_text, mask) 113 | else: 114 | decoder_input = encoded_text 115 | 116 | if self.post_encoder_embedder: 117 | decoder_input = decoder_input + post_embeddings 118 | 119 | if task == "deps": 120 | tag_logits = logits["upos"] if "upos" in logits else None 121 | pred_output = self.decoders[task](decoder_input, mask, tag_logits, 122 | gold_tags.get("head_tags", None), gold_tags.get("head_indices", None), metadata) 123 | for key in ["heads", "head_tags", "arc_loss", "tag_loss", "mask"]: 124 | output_dict[key] = pred_output[key] 125 | else: 126 | pred_output = self.decoders[task](decoder_input, mask, gold_tags, metadata) 127 | 128 | logits[task] = pred_output["logits"] 129 | class_probabilities[task] = pred_output["class_probabilities"] 130 | 131 | if task in gold_tags or task == "deps" and "head_tags" in gold_tags: 132 | # Keep track of the loss if we have the gold tags available 133 | loss += pred_output["loss"] 134 | 135 | if gold_tags: 136 | output_dict["loss"] = loss 137 | 138 | if metadata is not None: 139 | output_dict["words"] = [x["words"] for x in metadata] 140 | output_dict["ids"] = [x["ids"] for x in metadata if "ids" in x] 141 | output_dict["multiword_ids"] = [x["multiword_ids"] for x in metadata if "multiword_ids" in x] 142 | output_dict["multiword_forms"] = [x["multiword_forms"] for x in metadata if "multiword_forms" in x] 143 | 144 | return output_dict 145 | 146 | def _apply_token_dropout(self, tokens): 147 | # Word dropout 148 | if "tokens" in tokens: 149 | oov_token = self.vocab.get_token_index(self.vocab._oov_token) 150 | ignore_tokens = [self.vocab.get_token_index(self.vocab._padding_token)] 151 | tokens["tokens"] = self.token_dropout(tokens["tokens"], 152 | oov_token=oov_token, 153 | padding_tokens=ignore_tokens, 154 | p=self.word_dropout, 155 | training=self.training) 156 | 157 | # BERT token dropout 158 | if "bert" in tokens: 159 | oov_token = self.bert_vocab["[MASK]"] 160 | ignore_tokens = [self.bert_vocab["[PAD]"], self.bert_vocab["[CLS]"], self.bert_vocab["[SEP]"]] 161 | tokens["bert"] = self.token_dropout(tokens["bert"], 162 | oov_token=oov_token, 163 | padding_tokens=ignore_tokens, 164 | p=self.word_dropout, 165 | 
                                               training=self.training)
166 | 
167 |     @staticmethod
168 |     def token_dropout(tokens: torch.LongTensor,
169 |                       oov_token: int,
170 |                       padding_tokens: List[int],
171 |                       p: float = 0.2,
172 |                       training: bool = True) -> torch.LongTensor:
173 |         """
174 |         During training, randomly replaces some of the non-padding tokens with a mask token with probability ``p``
175 | 
176 |         :param tokens: The current batch of padded sentences with word ids
177 |         :param oov_token: The mask token
178 |         :param padding_tokens: The tokens for padding the input batch
179 |         :param p: The probability a word gets mapped to the mask token
180 |         :param training: Applies the dropout if set to ``True``
181 |         :return: A copy of the input batch with token dropout applied
182 |         """
183 |         if training and p > 0:
184 |             # Ensure that the tensors run on the same device
185 |             device = tokens.device
186 | 
187 |             # This creates a mask that only considers unpadded tokens for mapping to oov
188 |             padding_mask = torch.ones(tokens.size(), dtype=torch.bool).to(device)
189 |             for pad in padding_tokens:
190 |                 padding_mask &= (tokens != pad)
191 | 
192 |             # Create a uniformly random mask selecting either the original words or OOV tokens
193 |             dropout_mask = (torch.empty(tokens.size()).uniform_() < p).to(device)
194 |             oov_mask = dropout_mask & padding_mask
195 | 
196 |             oov_fill = torch.empty(tokens.size(), dtype=torch.long).fill_(oov_token).to(device)
197 | 
198 |             result = torch.where(oov_mask, oov_fill, tokens)
199 | 
200 |             return result
201 |         else:
202 |             return tokens
203 | 
204 |     @overrides
205 |     def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
206 |         for task in self.tasks:
207 |             self.decoders[task].decode(output_dict)
208 | 
209 |         return output_dict
210 | 
211 |     @overrides
212 |     def get_metrics(self, reset: bool = False) -> Dict[str, float]:
213 |         metrics = {name: task_metric
214 |                    for task in self.tasks
215 |                    for name, task_metric in self.decoders[task].get_metrics(reset).items()}
216 | 
217 |         # The "sum" metric summing all tracked metrics keeps a good measure of patience for early stopping and saving
218 |         metrics_to_track = {"upos", "xpos", "feats", "lemmas", "LAS", "UAS"}
219 |         metrics[".run/.sum"] = sum(metric
220 |                                    for name, metric in metrics.items()
221 |                                    if not name.startswith("_") and set(name.split("/")).intersection(metrics_to_track))
222 | 
223 |         return metrics
224 | 
225 |     def _count_params(self):
226 |         self.total_params = sum(p.numel() for p in self.parameters())
227 |         self.total_train_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
228 | 
229 |         logger.info(f"Total number of parameters: {self.total_params}")
230 |         logger.info(f"Total number of trainable parameters: {self.total_train_params}")
231 | 
-------------------------------------------------------------------------------- /udify/modules/__init__.py: --------------------------------------------------------------------------------
1 | from udify.modules.bert_pretrained import UdifyPretrainedBertEmbedder, WordpieceIndexer, PretrainedBertIndexer, BertEmbedder
2 | from udify.modules.residual_rnn import ResidualRNN
3 | from udify.modules.scalar_mix import ScalarMixWithDropout
4 | from udify.modules.text_field_embedder import UdifyTextFieldEmbedder
5 | from udify.modules.token_characters_encoder import UdifyTokenCharactersEncoder
6 | 
-------------------------------------------------------------------------------- /udify/modules/residual_rnn.py: --------------------------------------------------------------------------------
1 | """
2 | Defines a "Residual RNN", which adds each RNN layer's input to its output; this tends to work better than
a "Residual RNN", which adds the input of the RNNs to the output, which tends to work better than conventional 3 | RNN layers. 4 | """ 5 | 6 | from overrides import overrides 7 | import torch 8 | from allennlp.common.checks import ConfigurationError 9 | from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder 10 | from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper 11 | 12 | 13 | @Seq2SeqEncoder.register("udify_residual_rnn") 14 | class ResidualRNN(Seq2SeqEncoder): 15 | """ 16 | Instead of using AllenNLP's default PyTorch wrapper for seq2seq layers, we would like 17 | to apply intermediate logic between each layer, so we create a new wrapper for each layer, 18 | with residual connections between. 19 | """ 20 | 21 | def __init__(self, 22 | input_size: int, 23 | hidden_size: int, 24 | num_layers: int = 1, 25 | dropout: float = 0.0, 26 | residual: bool = True, 27 | rnn_type: str = "lstm") -> None: 28 | super(ResidualRNN, self).__init__() 29 | 30 | self._input_size = input_size 31 | self._hidden_size = hidden_size 32 | self._dropout = torch.nn.Dropout(p=dropout) 33 | self._residual = residual 34 | 35 | rnn_type = rnn_type.lower() 36 | if rnn_type == "lstm": 37 | rnn_cell = torch.nn.LSTM 38 | elif rnn_type == "gru": 39 | rnn_cell = torch.nn.GRU 40 | else: 41 | raise ConfigurationError(f"Unknown RNN cell type {rnn_type}") 42 | 43 | layers = [] 44 | for layer_index in range(num_layers): 45 | # Use hidden size on later layers so that the first layer projects and all other layers are residual 46 | input_ = input_size if layer_index == 0 else hidden_size 47 | rnn = rnn_cell(input_, hidden_size, bidirectional=True, batch_first=True) 48 | layer = PytorchSeq2SeqWrapper(rnn) 49 | layers.append(layer) 50 | self.add_module("rnn_layer_{}".format(layer_index), layer) 51 | self._layers = layers 52 | 53 | @overrides 54 | def get_input_dim(self) -> int: 55 | return self._input_size 56 | 57 | @overrides 58 | def get_output_dim(self) -> int: 59 | return self._hidden_size 60 | 61 | @overrides 62 | def is_bidirectional(self): 63 | return True 64 | 65 | def forward(self, inputs: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: # pylint: disable=arguments-differ 66 | hidden = inputs 67 | for i, layer in enumerate(self._layers): 68 | encoded = layer(hidden, mask) 69 | # Sum the backward and forward states to allow residual connections 70 | encoded = encoded[:, :, :self._hidden_size] + encoded[:, :, self._hidden_size:] 71 | 72 | projecting = i == 0 and self._input_size != self._hidden_size 73 | if self._residual and not projecting: 74 | hidden = hidden + self._dropout(encoded) 75 | else: 76 | hidden = self._dropout(encoded) 77 | 78 | return hidden 79 | -------------------------------------------------------------------------------- /udify/modules/scalar_mix.py: -------------------------------------------------------------------------------- 1 | """ 2 | The dot-product "Layer Attention" that is applied to the layers of BERT, along with layer dropout to reduce overfitting 3 | """ 4 | 5 | from typing import List 6 | 7 | import torch 8 | from torch.nn import ParameterList, Parameter 9 | 10 | from allennlp.common.checks import ConfigurationError 11 | 12 | 13 | class ScalarMixWithDropout(torch.nn.Module): 14 | """ 15 | Computes a parameterised scalar mixture of N tensors, ``mixture = gamma * sum(s_k * tensor_k)`` 16 | where ``s = softmax(w)``, with ``w`` and ``gamma`` scalar parameters. 17 | 18 | If ``do_layer_norm=True`` then apply layer normalization to each tensor before weighting. 
19 | 
20 |     If ``dropout > 0``, then for each scalar weight, set its softmax weight mass to 0 with
21 |     the dropout probability (i.e., by setting the unnormalized weight to -inf). This effectively
22 |     redistributes the dropped probability mass to all other weights.
23 |     """
24 |     def __init__(self,
25 |                  mixture_size: int,
26 |                  do_layer_norm: bool = False,
27 |                  initial_scalar_parameters: List[float] = None,
28 |                  trainable: bool = True,
29 |                  dropout: float = None,
30 |                  dropout_value: float = -1e20) -> None:
31 |         super(ScalarMixWithDropout, self).__init__()
32 |         self.mixture_size = mixture_size
33 |         self.do_layer_norm = do_layer_norm
34 |         self.dropout = dropout
35 | 
36 |         if initial_scalar_parameters is None:
37 |             initial_scalar_parameters = [0.0] * mixture_size
38 |         elif len(initial_scalar_parameters) != mixture_size:
39 |             raise ConfigurationError("Length of initial_scalar_parameters {} differs "
40 |                                      "from mixture_size {}".format(
41 |                                          initial_scalar_parameters, mixture_size))
42 | 
43 |         self.scalar_parameters = ParameterList(
44 |                 [Parameter(torch.FloatTensor([initial_scalar_parameters[i]]),
45 |                            requires_grad=trainable) for i
46 |                  in range(mixture_size)])
47 |         self.gamma = Parameter(torch.FloatTensor([1.0]), requires_grad=trainable)
48 | 
49 |         if self.dropout:
50 |             dropout_mask = torch.zeros(len(self.scalar_parameters))
51 |             dropout_fill = torch.empty(len(self.scalar_parameters)).fill_(dropout_value)
52 |             self.register_buffer("dropout_mask", dropout_mask)
53 |             self.register_buffer("dropout_fill", dropout_fill)
54 | 
55 |     def forward(self, tensors: List[torch.Tensor],  # pylint: disable=arguments-differ
56 |                 mask: torch.Tensor = None) -> torch.Tensor:
57 |         """
58 |         Compute a weighted average of the ``tensors``. The input tensors can be any shape
59 |         with at least two dimensions, but must all be the same shape.
60 | 
61 |         When ``do_layer_norm=True``, the ``mask`` is a required input. If the ``tensors`` are
62 |         dimensioned ``(dim_0, ..., dim_{n-1}, dim_n)``, then the ``mask`` is dimensioned
63 |         ``(dim_0, ..., dim_{n-1})``, as in the typical case with ``tensors`` of shape
64 |         ``(batch_size, timesteps, dim)`` and ``mask`` of shape ``(batch_size, timesteps)``.
65 | 
66 |         When ``do_layer_norm=False`` the ``mask`` is ignored.
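        Example (an illustrative sketch, assuming ``do_layer_norm=False``)::

            mix = ScalarMixWithDropout(mixture_size=3, dropout=0.1)
            layers = [torch.randn(2, 5, 8) for _ in range(3)]
            mixed = mix(layers)  # shape (2, 5, 8)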
67 |         """
68 |         if len(tensors) != self.mixture_size:
69 |             raise ConfigurationError("{} tensors were passed, but the module was initialized to "
70 |                                      "mix {} tensors.".format(len(tensors), self.mixture_size))
71 | 
72 |         def _do_layer_norm(tensor, broadcast_mask, num_elements_not_masked):
73 |             tensor_masked = tensor * broadcast_mask
74 |             mean = torch.sum(tensor_masked) / num_elements_not_masked
75 |             variance = torch.sum(((tensor_masked - mean) * broadcast_mask)**2) / num_elements_not_masked
76 |             return (tensor - mean) / torch.sqrt(variance + 1E-12)
77 | 
78 |         weights = torch.cat([parameter for parameter in self.scalar_parameters])
79 | 
80 |         if self.dropout:
81 |             weights = torch.where(self.dropout_mask.uniform_() > self.dropout, weights, self.dropout_fill)
82 | 
83 |         normed_weights = torch.nn.functional.softmax(weights, dim=0)
84 |         normed_weights = torch.split(normed_weights, split_size_or_sections=1)
85 | 
86 |         if not self.do_layer_norm:
87 |             pieces = []
88 |             for weight, tensor in zip(normed_weights, tensors):
89 |                 pieces.append(weight * tensor)
90 |             return self.gamma * sum(pieces)
91 | 
92 |         else:
93 |             mask_float = mask.float()
94 |             broadcast_mask = mask_float.unsqueeze(-1)
95 |             input_dim = tensors[0].size(-1)
96 |             num_elements_not_masked = torch.sum(mask_float) * input_dim
97 | 
98 |             pieces = []
99 |             for weight, tensor in zip(normed_weights, tensors):
100 |                 pieces.append(weight * _do_layer_norm(tensor,
101 |                                                       broadcast_mask, num_elements_not_masked))
102 |             return self.gamma * sum(pieces)
103 | 
-------------------------------------------------------------------------------- /udify/modules/text_field_embedder.py: --------------------------------------------------------------------------------
1 | """
2 | A modification to AllenNLP's TextFieldEmbedder
3 | """
4 | 
5 | from typing import Dict, List, Optional
6 | import warnings
7 | 
8 | import torch
9 | from overrides import overrides
10 | from torch.nn.modules.linear import Linear
11 | 
12 | from allennlp.common import Params
13 | from allennlp.common.checks import ConfigurationError
14 | from allennlp.data import Vocabulary
15 | from allennlp.modules.text_field_embedders import TextFieldEmbedder
16 | from allennlp.modules.time_distributed import TimeDistributed
17 | from allennlp.modules.token_embedders.token_embedder import TokenEmbedder
18 | 
19 | 
20 | @TextFieldEmbedder.register("udify_embedder")
21 | class UdifyTextFieldEmbedder(TextFieldEmbedder):
22 |     """
23 |     This is a ``TextFieldEmbedder`` that, unlike ``BasicTextFieldEmbedder``, can sum
24 |     subsets of its embeddings together (see ``sum_embeddings``) instead of concatenating
25 |     all of them, and it optionally applies dropout to its output.
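
    For example (illustrative): with ``sum_embeddings = ["bert", "char"]``, the
    "bert" and "char" vectors (which must share an output dimension) are added
    elementwise, and the result is concatenated with any remaining embedders.
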
    See AllenNLP's basic_text_field_embedder.py,
26 |     https://github.com/allenai/allennlp/blob/a75cb0a9ddedb23801db74629c3a3017dafb375e/
27 |     allennlp/modules/text_field_embedders/
28 |     basic_text_field_embedder.py
29 |     """
30 | 
31 |     def __init__(self,
32 |                  token_embedders: Dict[str, TokenEmbedder],
33 |                  output_dim: Optional[int] = None,
34 |                  sum_embeddings: List[str] = None,
35 |                  embedder_to_indexer_map: Dict[str, List[str]] = None,
36 |                  allow_unmatched_keys: bool = False,
37 |                  dropout: float = 0.0) -> None:
38 |         super(UdifyTextFieldEmbedder, self).__init__()
39 |         self._output_dim = output_dim
40 |         self._token_embedders = token_embedders
41 |         self._embedder_to_indexer_map = embedder_to_indexer_map
42 |         for key, embedder in token_embedders.items():
43 |             name = 'token_embedder_%s' % key
44 |             self.add_module(name, embedder)
45 |         self._allow_unmatched_keys = allow_unmatched_keys
46 |         self._dropout = torch.nn.Dropout(p=dropout) if dropout > 0 else lambda x: x
47 |         self._sum_embeddings = sum_embeddings if sum_embeddings is not None else []
48 | 
49 |         hidden_dim = 0
50 |         for embedder in self._token_embedders.values():
51 |             hidden_dim += embedder.get_output_dim()
52 | 
53 |         if len(self._sum_embeddings) > 1:
54 |             for key in self._sum_embeddings[1:]:
55 |                 hidden_dim -= self._token_embedders[key].get_output_dim()
56 | 
57 |         if self._output_dim is None:
58 |             self._projection_layer = None
59 |             self._output_dim = hidden_dim
60 |         else:
61 |             self._projection_layer = Linear(hidden_dim, self._output_dim)
62 | 
63 |     @overrides
64 |     def get_output_dim(self) -> int:
65 |         return self._output_dim
66 | 
67 |     def forward(self, text_field_input: Dict[str, torch.Tensor], num_wrapping_dims: int = 0) -> torch.Tensor:
68 |         embedder_keys = self._token_embedders.keys()
69 |         input_keys = text_field_input.keys()
70 | 
71 |         # Check for unmatched keys
72 |         if not self._allow_unmatched_keys:
73 |             if embedder_keys < input_keys:
74 |                 # token embedder keys are a strict subset of text field input keys.
75 |                 message = (f"Your text field is generating more keys ({list(input_keys)}) "
76 |                            f"than you have token embedders ({list(embedder_keys)}). "
77 |                            f"If you are using a token embedder that requires multiple keys "
78 |                            f"(for example, the OpenAI Transformer embedder or the BERT embedder) "
79 |                            f"you need to add allow_unmatched_keys = True "
80 |                            f"(and likely an embedder_to_indexer_map) to your "
81 |                            f"BasicTextFieldEmbedder configuration. "
82 |                            f"Otherwise, you should check that there is a 1:1 embedding "
83 |                            f"between your token indexers and token embedders.")
84 |                 raise ConfigurationError(message)
85 | 
86 |             elif self._token_embedders.keys() != text_field_input.keys():
87 |                 # some other mismatch
88 |                 message = "Mismatched token keys: %s and %s" % (str(self._token_embedders.keys()),
89 |                                                                 str(text_field_input.keys()))
90 |                 raise ConfigurationError(message)
91 | 
92 |         def embed(key):
93 |             # If we pre-specified a mapping explicitly, use that.
94 |             if self._embedder_to_indexer_map is not None:
95 |                 tensors = [text_field_input[indexer_key] for
96 |                            indexer_key in self._embedder_to_indexer_map[key]]
97 |             else:
98 |                 # otherwise, we assume the mapping between indexers and embedders
99 |                 # is bijective and just use the key directly.
100 |                 tensors = [text_field_input[key]]
101 |             # Note: need to use getattr here so that the pytorch voodoo
102 |             # with submodules works with multiple GPUs.
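            # e.g. the key "bert" resolves to the submodule registered as
            # "token_embedder_bert" in __init__ above.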
103 | embedder = getattr(self, 'token_embedder_{}'.format(key)) 104 | for _ in range(num_wrapping_dims): 105 | embedder = TimeDistributed(embedder) 106 | token_vectors = embedder(*tensors) 107 | 108 | return token_vectors 109 | 110 | embedded_representations = [] 111 | keys = sorted(embedder_keys) 112 | 113 | sum_embed = [] 114 | for key in self._sum_embeddings: 115 | token_vectors = embed(key) 116 | sum_embed.append(token_vectors) 117 | keys.remove(key) 118 | 119 | if sum_embed: 120 | embedded_representations.append(sum(sum_embed)) 121 | 122 | for key in keys: 123 | token_vectors = embed(key) 124 | embedded_representations.append(token_vectors) 125 | 126 | combined_embeddings = self._dropout(torch.cat(embedded_representations, dim=-1)) 127 | 128 | if self._projection_layer is not None: 129 | combined_embeddings = self._dropout(self._projection_layer(combined_embeddings)) 130 | 131 | return combined_embeddings 132 | 133 | # This is some unusual logic, it needs a custom from_params. 134 | @classmethod 135 | def from_params(cls, vocab: Vocabulary, params: Params) -> 'UdifyTextFieldEmbedder': # type: ignore 136 | # pylint: disable=arguments-differ,bad-super-call 137 | 138 | # The original `from_params` for this class was designed in a way that didn't agree 139 | # with the constructor. The constructor wants a 'token_embedders' parameter that is a 140 | # `Dict[str, TokenEmbedder]`, but the original `from_params` implementation expected those 141 | # key-value pairs to be top-level in the params object. 142 | # 143 | # This breaks our 'configuration wizard' and configuration checks. Hence, going forward, 144 | # the params need a 'token_embedders' key so that they line up with what the constructor wants. 145 | # For now, the old behavior is still supported, but produces a DeprecationWarning. 146 | 147 | embedder_to_indexer_map = params.pop("embedder_to_indexer_map", None) 148 | if embedder_to_indexer_map is not None: 149 | embedder_to_indexer_map = embedder_to_indexer_map.as_dict(quiet=True) 150 | allow_unmatched_keys = params.pop_bool("allow_unmatched_keys", False) 151 | 152 | token_embedder_params = params.pop('token_embedders', None) 153 | 154 | dropout = params.pop_float("dropout", 0.0) 155 | 156 | output_dim = params.pop_int("output_dim", None) 157 | sum_embeddings = params.pop("sum_embeddings", None) 158 | 159 | if token_embedder_params is not None: 160 | # New way: explicitly specified, so use it. 
161 |             token_embedders = {
162 |                 name: TokenEmbedder.from_params(subparams, vocab=vocab)
163 |                 for name, subparams in token_embedder_params.items()
164 |             }
165 | 
166 |         else:
167 |             # Warn that the original behavior is deprecated
168 |             warnings.warn(DeprecationWarning("the token embedders for BasicTextFieldEmbedder should now "
169 |                                              "be specified as a dict under the 'token_embedders' key, "
170 |                                              "not as top-level key-value pairs"))
171 | 
172 |             token_embedders = {}
173 |             keys = list(params.keys())
174 |             for key in keys:
175 |                 embedder_params = params.pop(key)
176 |                 token_embedders[key] = TokenEmbedder.from_params(vocab=vocab, params=embedder_params)
177 | 
178 |         params.assert_empty(cls.__name__)
179 |         return cls(token_embedders, output_dim, sum_embeddings, embedder_to_indexer_map, allow_unmatched_keys, dropout)
-------------------------------------------------------------------------------- /udify/modules/token_characters_encoder.py: --------------------------------------------------------------------------------
1 | """
2 | A modification to AllenNLP's TokenCharactersEncoder
3 | """
4 | 
5 | import torch
6 | 
7 | from allennlp.common import Params
8 | from allennlp.data.vocabulary import Vocabulary
9 | from allennlp.modules.token_embedders.embedding import Embedding
10 | from allennlp.modules.seq2vec_encoders.seq2vec_encoder import Seq2VecEncoder
11 | from allennlp.modules.time_distributed import TimeDistributed
12 | from allennlp.modules.token_embedders import TokenEmbedder
13 | 
14 | 
15 | @TokenEmbedder.register("udify_character_encoding")
16 | class UdifyTokenCharactersEncoder(TokenEmbedder):
17 |     """
18 |     Like AllenNLP's TokenCharactersEncoder, but applies dropout to the embeddings.
19 |     https://github.com/allenai/allennlp/blob/7dbd7d34a2f1390d1ff01f2e9ed6f8bdaaef77eb/
20 |     allennlp/modules/token_embedders/token_characters_encoder.py
21 |     """
22 |     def __init__(self, embedding: Embedding, encoder: Seq2VecEncoder, dropout: float = 0.0) -> None:
23 |         super(UdifyTokenCharactersEncoder, self).__init__()
24 |         self._embedding = TimeDistributed(embedding)
25 |         self._encoder = TimeDistributed(encoder)
26 |         if dropout > 0:
27 |             self._dropout = torch.nn.Dropout(p=dropout)
28 |         else:
29 |             self._dropout = lambda x: x
30 | 
31 |     def get_output_dim(self) -> int:
32 |         return self._encoder._module.get_output_dim()  # pylint: disable=protected-access
33 | 
34 |     def forward(self, token_characters: torch.Tensor) -> torch.Tensor:  # pylint: disable=arguments-differ
35 |         mask = (token_characters != 0).long()
36 |         return self._encoder(self._dropout(self._embedding(token_characters)), mask)
37 | 
38 |     # The setdefault requires a custom from_params
39 |     @classmethod
40 |     def from_params(cls, vocab: Vocabulary, params: Params) -> 'UdifyTokenCharactersEncoder':  # type: ignore
41 |         # pylint: disable=arguments-differ
42 |         embedding_params: Params = params.pop("embedding")
43 |         # Embedding.from_params() uses "tokens" as the default namespace, but we need to change
44 |         # that to be "token_characters" by default.
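        # e.g. (illustrative) an "embedding" params blob of {"embedding_dim": 64}
        # will now embed the "token_characters" namespace unless one is given.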
45 |         embedding_params.setdefault("vocab_namespace", "token_characters")
46 |         embedding = Embedding.from_params(vocab, embedding_params)
47 |         encoder_params: Params = params.pop("encoder")
48 |         encoder = Seq2VecEncoder.from_params(encoder_params)
49 |         dropout = params.pop_float("dropout", 0.0)
50 |         params.assert_empty(cls.__name__)
51 |         return cls(embedding, encoder, dropout)
-------------------------------------------------------------------------------- /udify/optimizers/__init__.py: --------------------------------------------------------------------------------
1 | from udify.optimizers.ulmfit_sqrt import UlmfitSqrtLR
2 | 
-------------------------------------------------------------------------------- /udify/optimizers/ulmfit_sqrt.py: --------------------------------------------------------------------------------
1 | """
2 | Special LR scheduler for fine-tuning Transformer networks
3 | """
4 | 
5 | from overrides import overrides
6 | import torch
7 | import logging
8 | 
9 | from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler
10 | 
11 | logger = logging.getLogger(__name__)
12 | 
13 | 
14 | @LearningRateScheduler.register("ulmfit_sqrt")
15 | class UlmfitSqrtLR(LearningRateScheduler):
16 |     """Implements a combination of ULMFiT (slanted triangular) with Noam sqrt learning rate decay"""
17 |     def __init__(self,
18 |                  optimizer: torch.optim.Optimizer,
19 |                  model_size: int,
20 |                  warmup_steps: int,
21 |                  start_step: int = 0,
22 |                  factor: float = 100,
23 |                  steepness: float = 0.5,
24 |                  last_epoch: int = -1,
25 |                  gradual_unfreezing: bool = False,
26 |                  discriminative_fine_tuning: bool = False,
27 |                  decay_factor: float = 0.38) -> None:
28 |         self.warmup_steps = warmup_steps + start_step
29 |         self.start_step = start_step
30 |         self.factor = factor
31 |         self.steepness = steepness
32 |         self.model_size = model_size
33 |         self.gradual_unfreezing = gradual_unfreezing
34 |         self.freezing_current = self.gradual_unfreezing
35 | 
36 |         if self.gradual_unfreezing:
37 |             assert not optimizer.param_groups[-1]["params"], \
38 |                 "The default group should be empty."
39 |         if self.gradual_unfreezing or discriminative_fine_tuning:
40 |             assert len(optimizer.param_groups) > 2, \
41 |                 "There should be at least 3 param_groups (2 + empty default group)" \
42 |                 " for gradual unfreezing / discriminative fine-tuning to make sense."
43 | 
44 |         super().__init__(optimizer, last_epoch=last_epoch)
45 | 
46 |         if discriminative_fine_tuning:
47 |             # skip the last param_group if it has no parameters
48 |             exponent = 0
49 |             for i in range(len(self.base_values) - 1, -1, -1):
50 |                 param_group = optimizer.param_groups[i]
51 |                 if param_group['params']:
52 |                     param_group['lr'] = self.base_values[i] * decay_factor ** exponent
53 |                     self.base_values[i] = param_group['lr']
54 |                     exponent += 1
55 | 
56 |     @overrides
57 |     def step(self, metric: float = None, epoch: int = None) -> None:
58 |         if self.gradual_unfreezing:
59 |             # the method is called once when initialising before the
60 |             # first epoch (epoch -1) and then always at the end of each
61 |             # epoch; so the first time, with epoch id -1, we want to set
62 |             # up for epoch #1; the second time, with epoch id 0,
63 |             # we want to set up for epoch #2, etc.
64 |             num_layers_to_unfreeze = epoch + 2 if epoch > -1 else 1
65 |             if num_layers_to_unfreeze >= len(self.optimizer.param_groups) - 1:
66 |                 logger.info('Gradual unfreezing finished. Training all layers.')
67 |                 self.freezing_current = False
68 |             else:
69 |                 logger.info(f'Gradual unfreezing. Training only the top {num_layers_to_unfreeze} layers.')
70 |                 for i, param_group in enumerate(reversed(self.optimizer.param_groups)):
71 |                     for param in param_group["params"]:
72 |                         # i = 0 is the default group; we care about i > 0
73 |                         param.requires_grad = bool(i <= num_layers_to_unfreeze)
74 | 
75 |     def step_batch(self, batch_num_total: int = None) -> None:
76 |         if batch_num_total is None:
77 |             self.last_epoch += 1  # type: ignore
78 |         else:
79 |             self.last_epoch = batch_num_total
80 |         for param_group, learning_rate in zip(self.optimizer.param_groups, self.get_values()):
81 |             param_group['lr'] = learning_rate
82 | 
83 |     def get_values(self):
84 |         if self.freezing_current:
85 |             # If parameters are still frozen, keep the base learning rates
86 |             return self.base_values
87 | 
88 |         # Noam sqrt LR decay: scale = factor * model_size^(-0.5) * min(step^(-steepness), step * warmup_steps^(-steepness - 1))
89 |         step = max(self.last_epoch - self.start_step, 1)
90 |         scale = self.factor * (self.model_size ** (-0.5) *
91 |                                min(step ** (-self.steepness), step * self.warmup_steps ** (-self.steepness - 1)))
92 | 
93 |         return [scale * lr for lr in self.base_values]
94 | 
--------------------------------------------------------------------------------
/udify/predictors/__init__.py:
--------------------------------------------------------------------------------
1 | from udify.predictors.predictor import UdifyPredictor
2 | 
--------------------------------------------------------------------------------
/udify/predictors/predictor.py:
--------------------------------------------------------------------------------
1 | """
2 | The main UDify predictor to output conllu files
3 | """
4 | 
5 | from typing import List
6 | from overrides import overrides
7 | 
8 | from allennlp.common.util import JsonDict, sanitize
9 | from allennlp.data import DatasetReader, Instance
10 | from allennlp.models import Model
11 | from allennlp.predictors.predictor import Predictor
12 | 
13 | 
14 | @Predictor.register("udify_predictor")
15 | class UdifyPredictor(Predictor):
16 |     """
17 |     Predictor for a UDify model that takes in a sentence and returns
18 |     a single set of CoNLL-U annotations for it.
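    A minimal usage sketch (the archive path is hypothetical, shown for illustration only):

        >>> from allennlp.models.archival import load_archive
        >>> from allennlp.predictors.predictor import Predictor
        >>> archive = load_archive("logs/udify-model/model.tar.gz")
        >>> predictor = Predictor.from_archive(archive, "udify_predictor")
        >>> predictor.predict("The quick brown fox jumps over the lazy dog .")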
19 | """ 20 | def __init__(self, model: Model, dataset_reader: DatasetReader) -> None: 21 | super().__init__(model, dataset_reader) 22 | 23 | def predict(self, sentence: str) -> JsonDict: 24 | return self.predict_json({"sentence": sentence}) 25 | 26 | @overrides 27 | def predict_batch_instance(self, instances: List[Instance]) -> List[JsonDict]: 28 | if "@@UNKNOWN@@" not in self._model.vocab._token_to_index["lemmas"]: 29 | # Handle cases where the labels are present in the test set but not training set 30 | for instance in instances: 31 | self._predict_unknown(instance) 32 | outputs = self._model.forward_on_instances(instances) 33 | return sanitize(outputs) 34 | 35 | @overrides 36 | def predict_instance(self, instance: Instance) -> JsonDict: 37 | if "@@UNKNOWN@@" not in self._model.vocab._token_to_index["lemmas"]: 38 | # Handle cases where the labels are present in the test set but not training set 39 | self._predict_unknown(instance) 40 | outputs = self._model.forward_on_instance(instance) 41 | return sanitize(outputs) 42 | 43 | def _predict_unknown(self, instance: Instance): 44 | """ 45 | Maps each unknown label in each namespace to a default token 46 | :param instance: the instance containing a list of labels for each namespace 47 | """ 48 | def replace_tokens(instance: Instance, namespace: str, token: str): 49 | if namespace not in instance.fields: 50 | return 51 | 52 | instance.fields[namespace].labels = [label 53 | if label in self._model.vocab._token_to_index[namespace] 54 | else token 55 | for label in instance.fields[namespace].labels] 56 | 57 | replace_tokens(instance, "lemmas", "↓0;d¦") 58 | replace_tokens(instance, "feats", "_") 59 | replace_tokens(instance, "xpos", "_") 60 | replace_tokens(instance, "upos", "NOUN") 61 | replace_tokens(instance, "head_tags", "case") 62 | 63 | @overrides 64 | def _json_to_instance(self, json_dict: JsonDict) -> Instance: 65 | """ 66 | Expects JSON that looks like ``{"sentence": "..."}``. 67 | Runs the underlying model, and adds the ``"words"`` to the output. 
68 | """ 69 | sentence = json_dict["sentence"] 70 | tokens = sentence.split() 71 | return self._dataset_reader.text_to_instance(tokens) 72 | 73 | @overrides 74 | def dump_line(self, outputs: JsonDict) -> str: 75 | word_count = len([word for word in outputs["words"]]) 76 | lines = zip(*[outputs[k] if k in outputs else ["_"] * word_count 77 | for k in ["ids", "words", "lemmas", "upos", "xpos", "feats", 78 | "predicted_heads", "predicted_dependencies"]]) 79 | 80 | multiword_map = None 81 | if outputs["multiword_ids"]: 82 | multiword_ids = [[id] + [int(x) for x in id.split("-")] for id in outputs["multiword_ids"]] 83 | multiword_forms = outputs["multiword_forms"] 84 | multiword_map = {start: (id_, form) for (id_, start, end), form in zip(multiword_ids, multiword_forms)} 85 | 86 | output_lines = [] 87 | for i, line in enumerate(lines): 88 | line = [str(l) for l in line] 89 | 90 | # Handle multiword tokens 91 | if multiword_map and i+1 in multiword_map: 92 | id_, form = multiword_map[i+1] 93 | row = f"{id_}\t{form}" + "".join(["\t_"] * 8) 94 | output_lines.append(row) 95 | 96 | row = "\t".join(line) + "".join(["\t_"] * 2) 97 | output_lines.append(row) 98 | 99 | return "\n".join(output_lines) + "\n\n" 100 | -------------------------------------------------------------------------------- /udify/predictors/text_predictor.py: -------------------------------------------------------------------------------- 1 | """ 2 | The main UDify predictor to output conllu files 3 | """ 4 | 5 | from typing import List 6 | from overrides import overrides 7 | import json 8 | 9 | from allennlp.common.util import JsonDict, sanitize 10 | from allennlp.data import DatasetReader, Instance 11 | from allennlp.models import Model 12 | from allennlp.predictors.predictor import Predictor 13 | 14 | from udify.dataset_readers.universal_dependencies import UniversalDependenciesRawDatasetReader 15 | from udify.predictors.predictor import UdifyPredictor 16 | 17 | 18 | @Predictor.register("udify_text_predictor") 19 | class UdifyTextPredictor(Predictor): 20 | """ 21 | Predictor for a UDify model that takes in a raw sentence and outputs json. 22 | """ 23 | def __init__(self, 24 | model: Model, 25 | dataset_reader: DatasetReader, 26 | output_conllu: bool = False) -> None: 27 | super().__init__(model, dataset_reader) 28 | self._dataset_reader = UniversalDependenciesRawDatasetReader(self._dataset_reader) 29 | self.predictor = UdifyPredictor(model, dataset_reader) 30 | self.output_conllu = output_conllu 31 | 32 | def predict(self, sentence: str) -> JsonDict: 33 | return self.predict_json({"sentence": sentence}) 34 | 35 | @overrides 36 | def predict_batch_instance(self, instances: List[Instance]) -> List[JsonDict]: 37 | return self.predictor.predict_batch_instance(instances) 38 | 39 | @overrides 40 | def predict_instance(self, instance: Instance) -> JsonDict: 41 | return self.predictor.predict_instance(instance) 42 | 43 | def _predict_unknown(self, instance: Instance): 44 | """ 45 | Maps each unknown label in each namespace to a default token 46 | :param instance: the instance containing a list of labels for each namespace 47 | """ 48 | return self.predictor._predict_unknown(instance) 49 | 50 | @overrides 51 | def _json_to_instance(self, json_dict: JsonDict) -> Instance: 52 | """ 53 | Expects JSON that looks like ``{"sentence": "..."}``. 54 | Runs the underlying model, and adds the ``"words"`` to the output. 
55 | """ 56 | sentence = json_dict["sentence"] 57 | tokens = [word.text for word in self._dataset_reader.tokenizer.split_words(sentence)] 58 | return self._dataset_reader.text_to_instance(tokens) 59 | 60 | @overrides 61 | def dump_line(self, outputs: JsonDict) -> str: 62 | if self.output_conllu: 63 | return self.predictor.dump_line(outputs) 64 | else: 65 | return json.dumps(outputs) + "\n" 66 | -------------------------------------------------------------------------------- /udify/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | A collection of handy utilities 3 | """ 4 | 5 | from typing import List, Tuple, Dict, Any 6 | 7 | import os 8 | import glob 9 | import json 10 | import logging 11 | import tarfile 12 | import traceback 13 | 14 | import torch 15 | 16 | from allennlp.common.checks import ConfigurationError 17 | from allennlp.common import Params 18 | from allennlp.common.params import with_fallback 19 | from allennlp.commands.make_vocab import make_vocab_from_params 20 | from allennlp.commands.predict import _PredictManager 21 | from allennlp.common.checks import check_for_gpu 22 | from allennlp.models.archival import load_archive 23 | from allennlp.predictors.predictor import Predictor 24 | 25 | from udify.dataset_readers.evaluate_2019_task2 import read_conllu, input_pairs, manipulate_data 26 | 27 | from udify.dataset_readers.conll18_ud_eval import evaluate, load_conllu_file, UDError 28 | 29 | VOCAB_CONFIG_PATH = "config/create_vocab.json" 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | 34 | def merge_configs(configs: List[Params]) -> Params: 35 | """ 36 | Merges a list of configurations together, with items with duplicate keys closer to the front of the list 37 | overriding any keys of items closer to the rear. 38 | :param configs: a list of AllenNLP Params 39 | :return: a single merged Params object 40 | """ 41 | while len(configs) > 1: 42 | overrides, config = configs[-2:] 43 | configs = configs[:-2] 44 | 45 | if "udify_replace" in overrides: 46 | replacements = [replace.split(".") for replace in overrides.pop("udify_replace")] 47 | for replace in replacements: 48 | obj = config 49 | try: 50 | for key in replace[:-1]: 51 | obj = obj[key] 52 | except KeyError: 53 | raise ConfigurationError(f"Config does not have key {key}") 54 | obj.pop(replace[-1]) 55 | 56 | configs.append(Params(with_fallback(preferred=overrides.params, fallback=config.params))) 57 | 58 | return configs[0] 59 | 60 | 61 | def cache_vocab(params: Params, vocab_config_path: str = None): 62 | """ 63 | Caches the vocabulary given in the Params to the filesystem. Useful for large datasets that are run repeatedly. 
64 |     :param params: the AllenNLP Params
65 |     :param vocab_config_path: an optional config path for constructing the vocab
66 |     """
67 |     if "vocabulary" not in params or "directory_path" not in params["vocabulary"]:
68 |         return
69 | 
70 |     vocab_path = params["vocabulary"]["directory_path"]
71 | 
72 |     if os.path.exists(vocab_path):
73 |         if os.listdir(vocab_path):
74 |             return
75 | 
76 |         # Remove empty vocabulary directory to make AllenNLP happy
77 |         try:
78 |             os.rmdir(vocab_path)
79 |         except OSError:
80 |             pass
81 | 
82 |     vocab_config_path = vocab_config_path if vocab_config_path else VOCAB_CONFIG_PATH
83 | 
84 |     params = merge_configs([params, Params.from_file(vocab_config_path)])
85 |     params["vocabulary"].pop("directory_path", None)
86 |     make_vocab_from_params(params, os.path.split(vocab_path)[0])
87 | 
88 | 
89 | def get_ud_treebank_files(dataset_dir: str, treebanks: List[str] = None) -> Dict[str, Tuple[str, str, str]]:
90 |     """
91 |     Retrieves all treebank data paths in the given directory.
92 |     :param dataset_dir: the directory where all treebank directories are stored
93 |     :param treebanks: if not None or empty, retrieve just the subset of treebanks listed here
94 |     :return: a dictionary mapping each treebank name to a tuple of train, dev, and test conllu files
95 |     """
96 |     datasets = {}
97 |     treebanks = os.listdir(dataset_dir) if not treebanks else treebanks
98 |     for treebank in treebanks:
99 |         treebank_path = os.path.join(dataset_dir, treebank)
100 |         conllu_files = [file for file in sorted(os.listdir(treebank_path)) if file.endswith(".conllu")]
101 | 
102 |         train_file = [file for file in conllu_files if file.endswith("train.conllu")]
103 |         dev_file = [file for file in conllu_files if file.endswith("dev.conllu")]
104 |         test_file = [file for file in conllu_files if file.endswith("test.conllu")]
105 | 
106 |         train_file = os.path.join(treebank_path, train_file[0]) if train_file else None
107 |         dev_file = os.path.join(treebank_path, dev_file[0]) if dev_file else None
108 |         test_file = os.path.join(treebank_path, test_file[0]) if test_file else None
109 | 
110 |         datasets[treebank] = (train_file, dev_file, test_file)
111 |     return datasets
112 | 
113 | 
114 | def get_ud_treebank_names(dataset_dir: str) -> List[Tuple[str, str]]:
115 |     """
116 |     Retrieves all treebank names from the given directory.
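    For example, a directory containing ``UD_English-EWT`` with the file
    ``en_ewt-ud-test.conllu`` yields the pair ``("English-EWT", "en_ewt")``.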
117 |     :param dataset_dir: the directory where all treebank directories are stored
118 |     :return: a list of long and short treebank names
119 |     """
120 |     treebanks = os.listdir(dataset_dir)
121 |     short_names = []
122 | 
123 |     for treebank in treebanks:
124 |         treebank_path = os.path.join(dataset_dir, treebank)
125 |         conllu_files = [file for file in sorted(os.listdir(treebank_path)) if file.endswith(".conllu")]
126 | 
127 |         test_file = [file for file in conllu_files if file.endswith("test.conllu")]
128 |         test_file = test_file[0].split("-")[0] if test_file else None
129 | 
130 |         short_names.append(test_file)
131 | 
132 |     treebanks = ["_".join(treebank.split("_")[1:]) for treebank in treebanks]
133 | 
134 |     return list(zip(treebanks, short_names))
135 | 
136 | 
137 | def predict_model_with_archive(predictor: str, params: Params, archive: str,
138 |                                input_file: str, output_file: str, batch_size: int = 1):
139 |     cuda_device = params["trainer"]["cuda_device"]
140 | 
141 |     check_for_gpu(cuda_device)
142 |     archive = load_archive(archive,
143 |                            cuda_device=cuda_device)
144 | 
145 |     predictor = Predictor.from_archive(archive, predictor)
146 | 
147 |     manager = _PredictManager(predictor,
148 |                               input_file,
149 |                               output_file,
150 |                               batch_size,
151 |                               print_to_console=False,
152 |                               has_dataset_reader=True)
153 |     manager.run()
154 | 
155 | 
156 | def predict_and_evaluate_model_with_archive(predictor: str, params: Params, archive: str, gold_file: str,
157 |                                             pred_file: str, output_file: str, segment_file: str = None, batch_size: int = 1):
158 |     if not gold_file or not os.path.isfile(gold_file):
159 |         logger.warning(f"No file exists for {gold_file}")
160 |         return
161 | 
162 |     segment_file = segment_file if segment_file else gold_file
163 |     predict_model_with_archive(predictor, params, archive, segment_file, pred_file, batch_size)
164 | 
165 |     try:
166 |         evaluation = evaluate(load_conllu_file(gold_file), load_conllu_file(pred_file))
167 |         save_metrics(evaluation, output_file)
168 |     except UDError:
169 |         logger.warning(f"Failed to evaluate {pred_file}")
170 |         traceback.print_exc()
171 | 
172 | 
173 | def predict_model(predictor: str, params: Params, archive_dir: str,
174 |                   input_file: str, output_file: str, batch_size: int = 1):
175 |     """
176 |     Predict output annotations from the given model and input file and produce an output file.
177 |     :param predictor: the type of predictor to use, e.g., "udify_predictor"
178 |     :param params: the Params of the model
179 |     :param archive_dir: the directory containing the saved model archive
180 |     :param input_file: the input file to predict
181 |     :param output_file: the output file to save
182 |     :param batch_size: the batch size, set this higher to speed up GPU inference
183 |     """
184 |     archive = os.path.join(archive_dir, "model.tar.gz")
185 |     predict_model_with_archive(predictor, params, archive, input_file, output_file, batch_size)
186 | 
187 | 
188 | def predict_and_evaluate_model(predictor: str, params: Params, archive_dir: str, gold_file: str,
189 |                                pred_file: str, output_file: str, segment_file: str = None, batch_size: int = 1):
190 |     """
191 |     Predict output annotations from the given model and input file and evaluate the model.
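    Predictions are written to ``pred_file`` and the CoNLL 2018 scores are saved
    as JSON to ``output_file``. A sketch of a typical call (paths are illustrative):

        >>> params = Params.from_file("config/ud/multilingual/udify_bert_finetune_multilingual.json")
        >>> predict_and_evaluate_model("udify_predictor", params, "logs/multilingual",
        ...                            "data/test.conllu", "logs/pred.conllu", "logs/scores.json")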
192 |     :param predictor: the type of predictor to use, e.g., "udify_predictor"
193 |     :param params: the Params of the model
194 |     :param archive_dir: the directory containing the saved model archive
195 |     :param gold_file: the file with gold annotations
196 |     :param pred_file: the input file to predict
197 |     :param output_file: the output file to save
198 |     :param segment_file: an optional separate gold file that can be evaluated,
199 |         useful if it has alternate segmentation
200 |     :param batch_size: the batch size, set this higher to speed up GPU inference
201 |     """
202 |     archive = os.path.join(archive_dir, "model.tar.gz")
203 |     predict_and_evaluate_model_with_archive(predictor, params, archive, gold_file,
204 |                                             pred_file, output_file, segment_file, batch_size)
205 | 
206 | 
207 | def save_metrics(evaluation: Dict[str, Any], output_file: str):
208 |     """
209 |     Saves CoNLL 2018 evaluation as a JSON file.
210 |     :param evaluation: the evaluation dict calculated by the CoNLL 2018 evaluation script
211 |     :param output_file: the output file to save
212 |     """
213 |     evaluation_dict = {k: v.__dict__ for k, v in evaluation.items()}
214 | 
215 |     with open(output_file, "w") as f:
216 |         json.dump(evaluation_dict, f, indent=4)
217 | 
218 |     logger.info("Metric     | Precision |    Recall |  F1 Score | AlignedAcc")
219 |     logger.info("-----------+-----------+-----------+-----------+-----------")
220 |     for metric in ["Tokens", "Sentences", "Words", "UPOS", "XPOS", "UFeats",
221 |                    "AllTags", "Lemmas", "UAS", "LAS", "CLAS", "MLAS", "BLEX"]:
222 |         logger.info("{:11}|{:10.2f} |{:10.2f} |{:10.2f} |{}".format(
223 |             metric,
224 |             100 * evaluation[metric].precision,
225 |             100 * evaluation[metric].recall,
226 |             100 * evaluation[metric].f1,
227 |             "{:10.2f}".format(100 * evaluation[metric].aligned_accuracy)
228 |             if evaluation[metric].aligned_accuracy is not None else ""))
229 | 
230 | 
231 | def cleanup_training(serialization_dir: str, keep_archive: bool = False, keep_weights: bool = False):
232 |     """
233 |     Removes files generated from training.
234 |     :param serialization_dir: the directory to clean
235 |     :param keep_archive: whether to keep a copy of the model archive
236 |     :param keep_weights: whether to keep copies of the intermediate model checkpoints
237 |     """
238 |     if not keep_weights:
239 |         for file in glob.glob(os.path.join(serialization_dir, "*.th")):
240 |             os.remove(file)
241 |     if not keep_archive:
242 |         os.remove(os.path.join(serialization_dir, "model.tar.gz"))
243 | 
244 | 
245 | def archive_bert_model(serialization_dir: str, config_file: str, output_file: str = None):
246 |     """
247 |     Extracts BERT parameters from the given model and saves them to an archive.
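    The resulting archive bundles ``bert_config.json`` and ``pytorch_model.bin``,
    mirroring the layout of the BERT archives under ``config/archive``.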
248 | :param serialization_dir: the directory containing the saved model archive 249 | :param config_file: the configuration file of the model archive 250 | :param output_file: the output BERT archive name to save 251 | """ 252 | archive = load_archive(os.path.join(serialization_dir, "model.tar.gz")) 253 | 254 | model = archive.model 255 | model.eval() 256 | 257 | try: 258 | bert_model = model.text_field_embedder.token_embedder_bert.model 259 | except AttributeError: 260 | logger.warning(f"Could not find the BERT model inside the archive {serialization_dir}") 261 | traceback.print_exc() 262 | return 263 | 264 | weights_file = os.path.join(serialization_dir, "pytorch_model.bin") 265 | torch.save(bert_model.state_dict(), weights_file) 266 | 267 | if not output_file: 268 | output_file = os.path.join(serialization_dir, "bert-finetune.tar.gz") 269 | 270 | with tarfile.open(output_file, 'w:gz') as archive: 271 | archive.add(config_file, arcname="bert_config.json") 272 | archive.add(weights_file, arcname="pytorch_model.bin") 273 | 274 | os.remove(weights_file) 275 | 276 | 277 | def evaluate_sigmorphon_model(gold_file: str, pred_file: str, output_file: str): 278 | """ 279 | Evaluates the predicted file according to SIGMORPHON 2019 Task 2 280 | :param gold_file: the gold annotations 281 | :param pred_file: the predicted annotations 282 | :param output_file: a JSON file to save with the evaluation metrics 283 | """ 284 | results_keys = ["lemma_acc", "lemma_dist", "msd_acc", "msd_f1"] 285 | 286 | reference = read_conllu(gold_file) 287 | output = read_conllu(pred_file) 288 | results = manipulate_data(input_pairs(reference, output)) 289 | 290 | output_dict = {k: v for k, v in zip(results_keys, results)} 291 | 292 | with open(output_file, "w") as f: 293 | json.dump(output_dict, f, indent=4) 294 | --------------------------------------------------------------------------------
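A minimal sketch of how these evaluation utilities chain together (the data and log
paths below are illustrative, not shipped with the repository):

    from udify import util

    # Score SIGMORPHON 2019 Task 2 predictions against a gold file, saving
    # lemma accuracy/distance and MSD accuracy/F1 as JSON
    util.evaluate_sigmorphon_model("data/gold.conllu", "logs/pred.conllu", "logs/sigmorphon_metrics.json")

    # Extract the fine-tuned BERT weights from a trained UDify archive
    util.archive_bert_model("logs/multilingual", "config/archive/bert-base-multilingual-cased/bert_config.json")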